diff --git a/.dev_scripts/ci_container_test.sh b/.dev_scripts/ci_container_test.sh index fa5e4534..4fd2778f 100644 --- a/.dev_scripts/ci_container_test.sh +++ b/.dev_scripts/ci_container_test.sh @@ -1,26 +1,32 @@ -awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -pip install -r requirements/tests.txt +echo "Testing envs" +printenv +echo "ENV END" +if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then + awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + pip install -r requirements/tests.txt -git config --global --add safe.directory /Maas-lib -git config --global user.email tmp -git config --global user.name tmp.com + git config --global --add safe.directory /Maas-lib + git config --global user.email tmp + git config --global user.name tmp.com -# linter test -# use internal project for pre-commit due to the network problem -if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then - pre-commit run -c .pre-commit-config_local.yaml --all-files -fi -if [ $? -ne 0 ]; then - echo "linter test failed, please run 'pre-commit run --all-files' to check" - exit -1 + # linter test + # use internal project for pre-commit due to the network problem + if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then + pre-commit run -c .pre-commit-config_local.yaml --all-files + if [ $? -ne 0 ]; then + echo "linter test failed, please run 'pre-commit run --all-files' to check" + exit -1 + fi + fi + # test with install + python setup.py install +else + echo "Running case in release image, run case directly!" 
fi -# test with install -python setup.py install - if [ $# -eq 0 ]; then ci_command="python tests/run.py --subprocess" else diff --git a/.dev_scripts/dockerci.sh b/.dev_scripts/dockerci.sh index af94b211..c502175b 100644 --- a/.dev_scripts/dockerci.sh +++ b/.dev_scripts/dockerci.sh @@ -20,28 +20,52 @@ do # pull image if there are update docker pull ${IMAGE_NAME}:${IMAGE_VERSION} - docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ - --cpuset-cpus=${cpu_sets_arr[$gpu]} \ - --gpus="device=$gpu" \ - -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ - -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ - -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ - -v /home/admin/pre-commit:/home/admin/pre-commit \ - -e CI_TEST=True \ - -e TEST_LEVEL=$TEST_LEVEL \ - -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ - -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ - -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ - -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ - -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ - -e TEST_LEVEL=$TEST_LEVEL \ - -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ - -e MODEL_TAG_URL=$MODEL_TAG_URL \ - --workdir=$CODE_DIR_IN_CONTAINER \ - --net host \ - ${IMAGE_NAME}:${IMAGE_VERSION} \ - $CI_COMMAND - + if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then + docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ + --cpuset-cpus=${cpu_sets_arr[$gpu]} \ + --gpus="device=$gpu" \ + -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ + -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ + -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ + -v /home/admin/pre-commit:/home/admin/pre-commit \ + -e CI_TEST=True \ + -e TEST_LEVEL=$TEST_LEVEL \ + -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ + -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ + -e MODELSCOPE_SDK_DEBUG=True \ + -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ + -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ + -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ + -e TEST_LEVEL=$TEST_LEVEL \ + -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ + -e MODEL_TAG_URL=$MODEL_TAG_URL \ + --workdir=$CODE_DIR_IN_CONTAINER \ + --net host \ + ${IMAGE_NAME}:${IMAGE_VERSION} \ + $CI_COMMAND + else + docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ + --cpuset-cpus=${cpu_sets_arr[$gpu]} \ + --gpus="device=$gpu" \ + -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ + -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ + -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ + -v /home/admin/pre-commit:/home/admin/pre-commit \ + -e CI_TEST=True \ + -e TEST_LEVEL=$TEST_LEVEL \ + -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ + -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ + -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ + -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ + -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ + -e TEST_LEVEL=$TEST_LEVEL \ + -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ + -e MODEL_TAG_URL=$MODEL_TAG_URL \ + --workdir=$CODE_DIR_IN_CONTAINER \ + --net host \ + ${IMAGE_NAME}:${IMAGE_VERSION} \ + $CI_COMMAND + fi if [ $? -ne 0 ]; then echo "Running test case failed, please check the log!" 
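The two CI scripts above now branch on MODELSCOPE_SDK_DEBUG: when it is "True" the source tree is installed, linted and tested inside the container, otherwise the pre-built release image is exercised directly. A minimal Python sketch of reading the same flag from library code; only the environment-variable name comes from this patch, the helper function is hypothetical:

```python
import os

# Name matches the MODELSCOPE_SDK_DEBUG constant added to
# modelscope/hub/constants.py later in this patch.
MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'


def is_sdk_debug() -> bool:
    """Hypothetical helper: True when CI exported MODELSCOPE_SDK_DEBUG=True."""
    return os.environ.get(MODELSCOPE_SDK_DEBUG, '') == 'True'


if is_sdk_debug():
    print('debug mode: testing a source install')
else:
    print('release mode: testing the pre-built image')
```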
exit -1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c6290ff4..48fe7547 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://gitlab.com/pycqa/flake8.git - rev: 3.8.3 + rev: 4.0.0 hooks: - id: flake8 exclude: thirdparty/|examples/ diff --git a/.pre-commit-config_local.yaml b/.pre-commit-config_local.yaml index 138561e3..0b2e2f39 100644 --- a/.pre-commit-config_local.yaml +++ b/.pre-commit-config_local.yaml @@ -1,6 +1,6 @@ repos: - repo: /home/admin/pre-commit/flake8 - rev: 3.8.3 + rev: 4.0.0 hooks: - id: flake8 exclude: thirdparty/|examples/ diff --git a/data/test/audios/asr_example_8K.wav b/data/test/audios/asr_example_8K.wav new file mode 100644 index 00000000..956aad27 --- /dev/null +++ b/data/test/audios/asr_example_8K.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f +size 38954 diff --git a/data/test/audios/asr_example_cn_dialect.wav b/data/test/audios/asr_example_cn_dialect.wav new file mode 100644 index 00000000..e18fb05d --- /dev/null +++ b/data/test/audios/asr_example_cn_dialect.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4 +size 93676 diff --git a/data/test/audios/asr_example_cn_en.wav b/data/test/audios/asr_example_cn_en.wav new file mode 100644 index 00000000..8baf3193 --- /dev/null +++ b/data/test/audios/asr_example_cn_en.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb +size 472012 diff --git a/data/test/audios/asr_example_en.wav b/data/test/audios/asr_example_en.wav new file mode 100644 index 00000000..fa996eec --- /dev/null +++ b/data/test/audios/asr_example_en.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b +size 71244 diff --git a/data/test/audios/asr_example_es.wav b/data/test/audios/asr_example_es.wav new file mode 100644 index 00000000..95b22dc3 --- /dev/null +++ b/data/test/audios/asr_example_es.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61 +size 181964 diff --git a/data/test/audios/asr_example_id.wav b/data/test/audios/asr_example_id.wav new file mode 100644 index 00000000..54c30614 --- /dev/null +++ b/data/test/audios/asr_example_id.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425 +size 112078 diff --git a/data/test/audios/asr_example_ja.wav b/data/test/audios/asr_example_ja.wav new file mode 100644 index 00000000..e953fee2 --- /dev/null +++ b/data/test/audios/asr_example_ja.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281 +size 200556 diff --git a/data/test/audios/asr_example_ko.wav b/data/test/audios/asr_example_ko.wav new file mode 100644 index 00000000..0dad1be3 --- /dev/null +++ b/data/test/audios/asr_example_ko.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b +size 162636 diff --git a/data/test/audios/asr_example_ru.wav b/data/test/audios/asr_example_ru.wav new file mode 100644 index 00000000..b0cb8f2f --- /dev/null +++ b/data/test/audios/asr_example_ru.wav 
@@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61 +size 160204 diff --git a/data/test/audios/noise_2ch.wav b/data/test/audios/noise_2ch.wav new file mode 100644 index 00000000..c754e39a --- /dev/null +++ b/data/test/audios/noise_2ch.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d +size 2327764 diff --git a/data/test/audios/wake_word_with_label_xyxy.wav b/data/test/audios/wake_word_with_label_xyxy.wav new file mode 100644 index 00000000..b7999777 --- /dev/null +++ b/data/test/audios/wake_word_with_label_xyxy.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150 +size 68524 diff --git a/data/test/images/auto_demo.jpg b/data/test/images/auto_demo.jpg new file mode 100644 index 00000000..30393e53 --- /dev/null +++ b/data/test/images/auto_demo.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76bf84536edbaf192a8a699efc62ba2b06056bac12c426ecfcc2e003d91fbd32 +size 53219 diff --git a/data/test/images/card_detection.jpg b/data/test/images/card_detection.jpg new file mode 100644 index 00000000..86728c2c --- /dev/null +++ b/data/test/images/card_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecbc9d0827cfb92e93e7d75868b1724142685dc20d3b32023c3c657a7b688a9c +size 254845 diff --git a/data/test/images/face_detection2.jpeg b/data/test/images/face_detection2.jpeg new file mode 100644 index 00000000..7f6025fa --- /dev/null +++ b/data/test/images/face_detection2.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d510ab26ddc58ffea882c8ef850c1f9bd4444772f2bce7ebea3e76944536c3ae +size 48909 diff --git a/data/test/images/image_body_reshaping.jpg b/data/test/images/image_body_reshaping.jpg new file mode 100644 index 00000000..d78acb8f --- /dev/null +++ b/data/test/images/image_body_reshaping.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2c1119e3d521cf2e583b1e85fc9c9afd1d44954b433135039a98050a730932d +size 1127557 diff --git a/data/test/images/image_inpainting/image_inpainting.png b/data/test/images/image_inpainting/image_inpainting.png new file mode 100644 index 00000000..e141012d --- /dev/null +++ b/data/test/images/image_inpainting/image_inpainting.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46db348eae61448f1668ce282caec21375e96c3268d53da44aa67ec32cbf4fa5 +size 2747938 diff --git a/data/test/images/image_inpainting/image_inpainting_mask.png b/data/test/images/image_inpainting/image_inpainting_mask.png new file mode 100644 index 00000000..e30f67e7 --- /dev/null +++ b/data/test/images/image_inpainting/image_inpainting_mask.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:709c1828ed2d56badf2f19a40194da9a5e5e6db2fb73ef55d047407f49bc7a15 +size 27616 diff --git a/data/test/images/image_ocr_recognition.jpg b/data/test/images/image_ocr_recognition.jpg new file mode 100644 index 00000000..b41287cd --- /dev/null +++ b/data/test/images/image_ocr_recognition.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:772b19f76c98044e39330853928624f10e085106a4292b4dd19f865531080747 +size 959 diff --git a/data/test/images/keypoints_detect/body_keypoints_detection.jpg b/data/test/images/keypoints_detect/body_keypoints_detection.jpg deleted file mode 100644 index 
71ce7d7e..00000000 --- a/data/test/images/keypoints_detect/body_keypoints_detection.jpg +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:379e11d7fc3734d3ec95afd0d86460b4653fbf4bb1f57f993610d6a6fd30fd3d -size 1702339 diff --git a/data/test/images/keypoints_detect/img_test_wholebody.jpg b/data/test/images/keypoints_detect/img_test_wholebody.jpg new file mode 100644 index 00000000..40a9f3f8 --- /dev/null +++ b/data/test/images/keypoints_detect/img_test_wholebody.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dec0fbb931cb609bf481e56b89cd2fbbab79839f22832c3bbe69a8fae2769cdd +size 167407 diff --git a/data/test/regression/fill_mask_sbert_zh.bin b/data/test/regression/fill_mask_sbert_zh.bin index 812f7ba2..62581a26 100644 --- a/data/test/regression/fill_mask_sbert_zh.bin +++ b/data/test/regression/fill_mask_sbert_zh.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4fd6fa6b23c2fdaf876606a767d9b64b1924e1acddfc06ac42db73ba86083280 -size 119940 +oid sha256:4eae921001139d7e3c06331c9ef2213f8fc1c23512acd95751559866fb770e96 +size 121855 diff --git a/data/test/regression/fill_mask_veco_en.bin b/data/test/regression/fill_mask_veco_en.bin index be3fddc8..4d2dba7d 100644 --- a/data/test/regression/fill_mask_veco_en.bin +++ b/data/test/regression/fill_mask_veco_en.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4d37672a0e299a08d2daf5c7fc29bfce96bb15701fe5e5e68f068861ac2ee705 -size 119619 +oid sha256:f97d34d7450d17d0a93647129ab10d16b1f6e70c34a73b6f7687b79519ee4f71 +size 121563 diff --git a/data/test/regression/fill_mask_veco_zh.bin b/data/test/regression/fill_mask_veco_zh.bin index c0d27e20..a6eb5621 100644 --- a/data/test/regression/fill_mask_veco_zh.bin +++ b/data/test/regression/fill_mask_veco_zh.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c692e0753cfe349e520511427727a8252f141fa10e85f9a61562845e8d731f9a -size 119619 +oid sha256:a8355f27a3235209f206b5e75f4400353e5989e94cf4d71270b42ded8821d536 +size 121563 diff --git a/data/test/regression/sbert-base-tnews.bin b/data/test/regression/sbert-base-tnews.bin new file mode 100644 index 00000000..d2c63ab0 --- /dev/null +++ b/data/test/regression/sbert-base-tnews.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:344ef971bdf310b76c6571d1f4994ab6abc5edc659654d71a4f75b14a30960c2 +size 152926 diff --git a/data/test/regression/sbert_nli.bin b/data/test/regression/sbert_nli.bin index a5f680bb..52e31692 100644 --- a/data/test/regression/sbert_nli.bin +++ b/data/test/regression/sbert_nli.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62 -size 62231 +oid sha256:f0aeb07b6c9b40a0cfa7492e839431764e9bece93c906833a07c05e83520a399 +size 63161 diff --git a/data/test/regression/sbert_sen_sim.bin b/data/test/regression/sbert_sen_sim.bin index a59cbe0b..1c8efb81 100644 --- a/data/test/regression/sbert_sen_sim.bin +++ b/data/test/regression/sbert_sen_sim.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a -size 62235 +oid sha256:7aa5c7a2565ccf0d2eea4baf8adbd0e020dbe36a7159b31156c53141cc9b2df2 +size 63165 diff --git a/data/test/regression/sbert_ws_en.bin b/data/test/regression/sbert_ws_en.bin index 4eb562d6..3ad45356 100644 --- a/data/test/regression/sbert_ws_en.bin +++ b/data/test/regression/sbert_ws_en.bin @@ -1,3 +1,3 @@ version 
https://git-lfs.github.com/spec/v1 -oid sha256:9103ce2bc89212f67fb49ce70783b7667e376900d0f70fb8f5c4432eb74bc572 -size 60801 +oid sha256:cc6de82a8485fbfa008f6c2d5411cd07ba03e4a780bcb4e67efc6fba3c6ce92f +size 63597 diff --git a/data/test/regression/sbert_ws_zh.bin b/data/test/regression/sbert_ws_zh.bin index 555f640d..a85d787f 100644 --- a/data/test/regression/sbert_ws_zh.bin +++ b/data/test/regression/sbert_ws_zh.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2d4dee34c7e83b77db04fb2f0d1200bfd37c7c24954c58e185da5cb96445975c -size 60801 +oid sha256:7d98ac11a4e9e2744a7402a5cc912da991a41938bbc5dd60f15ee5c6b3196030 +size 63349 diff --git a/data/test/regression/sbert_zero_shot.bin b/data/test/regression/sbert_zero_shot.bin index 23d40946..04171523 100644 --- a/data/test/regression/sbert_zero_shot.bin +++ b/data/test/regression/sbert_zero_shot.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9e3ecc2c30d382641d561f84849b199c12bb1a9418e8099a191153f6f5275a85 -size 61589 +oid sha256:01f9b9bf6f8bbf9bb377d4cb6f399b2e5e065381f5b7332343e0db7b4fae72a5 +size 62519 diff --git a/data/test/videos/referring_video_object_segmentation_test_video.mp4 b/data/test/videos/referring_video_object_segmentation_test_video.mp4 new file mode 100644 index 00000000..529595a5 --- /dev/null +++ b/data/test/videos/referring_video_object_segmentation_test_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a49c9bc74a60860c360a4bf4509fe9db915279aaabd953f354f2c38e9be1e6cb +size 2924691 diff --git a/data/test/videos/test_realtime_vod.mp4 b/data/test/videos/test_realtime_vod.mp4 new file mode 100644 index 00000000..a0e44852 --- /dev/null +++ b/data/test/videos/test_realtime_vod.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f58df1d25590c158ae0a04b3999bd44b610cdaddb17d78afd84c34b3f00d4e87 +size 4068783 diff --git a/docker/Dockerfile.ubuntu b/docker/Dockerfile.ubuntu index a9a409b5..6dafbc3e 100644 --- a/docker/Dockerfile.ubuntu +++ b/docker/Dockerfile.ubuntu @@ -76,7 +76,7 @@ RUN pip install --no-cache-dir --upgrade pip && \ ENV SHELL=/bin/bash # install special package -RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 datasets==2.1.0 numpy==1.18.5 ipykernel fairseq +RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 datasets==2.1.0 numpy==1.18.5 ipykernel fairseq fasttext https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/xtcocotools-1.12-cp37-cp37m-linux_x86_64.whl RUN if [ "$USE_GPU" = "True" ] ; then \ pip install --no-cache-dir dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html; \ diff --git a/modelscope/__init__.py b/modelscope/__init__.py index 0746d0e6..81fdf505 100644 --- a/modelscope/__init__.py +++ b/modelscope/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .version import __version__ +from .version import __release_datetime__, __version__ -__all__ = ['__version__'] +__all__ = ['__version__', '__release_datetime__'] diff --git a/modelscope/exporters/base.py b/modelscope/exporters/base.py index f19d2bbb..c8b7900e 100644 --- a/modelscope/exporters/base.py +++ b/modelscope/exporters/base.py @@ -19,10 +19,13 @@ class Exporter(ABC): def from_model(cls, model: Model, **kwargs): """Build the Exporter instance. - @param model: A model instance. it will be used to output the generated file, + Args: + model: A Model instance. 
it will be used to generate the intermediate format file, and the configuration.json in its model_dir field will be used to create the exporter instance. - @param kwargs: Extra kwargs used to create the Exporter instance. - @return: The Exporter instance + kwargs: Extra kwargs used to create the Exporter instance. + + Returns: + The Exporter instance """ cfg = Config.from_file( os.path.join(model.model_dir, ModelFile.CONFIGURATION)) @@ -44,10 +47,13 @@ class Exporter(ABC): In some cases, several files may be generated, So please return a dict which contains the generated name with the file path. - @param opset: The version of the ONNX operator set to use. - @param outputs: The output dir. - @param kwargs: In this default implementation, - kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape). - @return: A dict contains the model name with the model file path. + Args: + opset: The version of the ONNX operator set to use. + outputs: The output dir. + kwargs: In this default implementation, + kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape). + + Returns: + A dict contains the model name with the model file path. """ pass diff --git a/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py b/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py index dc1e2b92..7cee331b 100644 --- a/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py +++ b/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py @@ -23,13 +23,18 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): def generate_dummy_inputs(self, shape: Tuple = None, + pair: bool = False, **kwargs) -> Dict[str, Any]: """Generate dummy inputs for model exportation to onnx or other formats by tracing. - @param shape: A tuple of input shape which should have at most two dimensions. - shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. - shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. - @return: Dummy inputs. + Args: + shape: A tuple of input shape which should have at most two dimensions. + shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. + shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. + pair(bool, `optional`): Whether to generate sentence pairs or single sentences. + + Returns: + Dummy inputs. 
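A hedged sketch of driving the new `pair` switch from user code; the model id is a placeholder, and `Model.from_pretrained` plus the inherited `from_model` builder are assumed to behave as elsewhere in the library:

```python
from modelscope.exporters.nlp.sbert_for_sequence_classification_exporter import \
    SbertForSequenceClassificationExporter
from modelscope.models import Model

# Placeholder model id; any structbert sequence-classification model would do.
model = Model.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base')
exporter = SbertForSequenceClassificationExporter.from_model(model)

# Dummy single-sentence inputs with an explicit shape (the previous behavior).
single = exporter.generate_dummy_inputs(shape=(8, 128))
# Sentence-pair inputs: now controlled by the caller via `pair`,
# not by the preprocessor's own `pair` attribute.
paired = exporter.generate_dummy_inputs(shape=(8, 128), pair=True)
```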
""" cfg = Config.from_file( @@ -55,7 +60,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): **sequence_length }) preprocessor: Preprocessor = build_preprocessor(cfg, field_name) - if preprocessor.pair: + if pair: first_sequence = preprocessor.tokenizer.unk_token second_sequence = preprocessor.tokenizer.unk_token else: diff --git a/modelscope/exporters/torch_model_exporter.py b/modelscope/exporters/torch_model_exporter.py index 98a23fe5..94ef277a 100644 --- a/modelscope/exporters/torch_model_exporter.py +++ b/modelscope/exporters/torch_model_exporter.py @@ -13,8 +13,8 @@ from modelscope.models import TorchModel from modelscope.pipelines.base import collate_fn from modelscope.utils.constant import ModelFile from modelscope.utils.logger import get_logger -from modelscope.utils.regress_test_utils import compare_arguments_nested -from modelscope.utils.tensor_utils import torch_nested_numpify +from modelscope.utils.regress_test_utils import (compare_arguments_nested, + numpify_tensor_nested) from .base import Exporter logger = get_logger(__name__) @@ -28,49 +28,61 @@ class TorchModelExporter(Exporter): and to provide implementations for generate_dummy_inputs/inputs/outputs methods. """ - def export_onnx(self, outputs: str, opset=11, **kwargs): + def export_onnx(self, output_dir: str, opset=13, **kwargs): """Export the model as onnx format files. In some cases, several files may be generated, So please return a dict which contains the generated name with the file path. - @param opset: The version of the ONNX operator set to use. - @param outputs: The output dir. - @param kwargs: In this default implementation, - you can pass the arguments needed by _torch_export_onnx, other unrecognized args - will be carried to generate_dummy_inputs as extra arguments (such as input shape). - @return: A dict containing the model key - model file path pairs. + Args: + opset: The version of the ONNX operator set to use. + output_dir: The output dir. + kwargs: + model: A model instance which will replace the exporting of self.model. + In this default implementation, + you can pass the arguments needed by _torch_export_onnx, other unrecognized args + will be carried to generate_dummy_inputs as extra arguments (such as input shape). + + Returns: + A dict containing the model key - model file path pairs. """ - model = self.model + model = self.model if 'model' not in kwargs else kwargs.pop('model') if not isinstance(model, nn.Module) and hasattr(model, 'model'): model = model.model - onnx_file = os.path.join(outputs, ModelFile.ONNX_MODEL_FILE) + onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE) self._torch_export_onnx(model, onnx_file, opset=opset, **kwargs) return {'model': onnx_file} - def export_torch_script(self, outputs: str, **kwargs): + def export_torch_script(self, output_dir: str, **kwargs): """Export the model as torch script files. In some cases, several files may be generated, So please return a dict which contains the generated name with the file path. - @param outputs: The output dir. - @param kwargs: In this default implementation, + Args: + output_dir: The output dir. + kwargs: + model: A model instance which will replace the exporting of self.model. + In this default implementation, you can pass the arguments needed by _torch_export_torch_script, other unrecognized args will be carried to generate_dummy_inputs as extra arguments (like input shape). - @return: A dict contains the model name with the model file path. 
+ + Returns: + A dict contains the model name with the model file path. """ - model = self.model + model = self.model if 'model' not in kwargs else kwargs.pop('model') if not isinstance(model, nn.Module) and hasattr(model, 'model'): model = model.model - ts_file = os.path.join(outputs, ModelFile.TS_MODEL_FILE) + ts_file = os.path.join(output_dir, ModelFile.TS_MODEL_FILE) # generate ts by tracing self._torch_export_torch_script(model, ts_file, **kwargs) return {'model': ts_file} def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]: """Generate dummy inputs for model exportation to onnx or other formats by tracing. - @return: Dummy inputs. + + Returns: + Dummy inputs. """ return None @@ -93,7 +105,7 @@ class TorchModelExporter(Exporter): def _torch_export_onnx(self, model: nn.Module, output: str, - opset: int = 11, + opset: int = 13, device: str = 'cpu', validation: bool = True, rtol: float = None, @@ -101,18 +113,27 @@ class TorchModelExporter(Exporter): **kwargs): """Export the model to an onnx format file. - @param model: A torch.nn.Module instance to export. - @param output: The output file. - @param opset: The version of the ONNX operator set to use. - @param device: The device used to forward. - @param validation: Whether validate the export file. - @param rtol: The rtol used to regress the outputs. - @param atol: The atol used to regress the outputs. + Args: + model: A torch.nn.Module instance to export. + output: The output file. + opset: The version of the ONNX operator set to use. + device: The device used to forward. + validation: Whether validate the export file. + rtol: The rtol used to regress the outputs. + atol: The atol used to regress the outputs. + kwargs: + dummy_inputs: A dummy inputs which will replace the calling of self.generate_dummy_inputs(). + inputs: An inputs structure which will replace the calling of self.inputs. + outputs: An outputs structure which will replace the calling of self.outputs. """ - dummy_inputs = self.generate_dummy_inputs(**kwargs) - inputs = self.inputs - outputs = self.outputs + dummy_inputs = self.generate_dummy_inputs( + **kwargs) if 'dummy_inputs' not in kwargs else kwargs.pop( + 'dummy_inputs') + inputs = self.inputs if 'inputs' not in kwargs else kwargs.pop( + 'inputs') + outputs = self.outputs if 'outputs' not in kwargs else kwargs.pop( + 'outputs') if dummy_inputs is None or inputs is None or outputs is None: raise NotImplementedError( 'Model property dummy_inputs,inputs,outputs must be set.') @@ -125,7 +146,7 @@ class TorchModelExporter(Exporter): if isinstance(dummy_inputs, Mapping): dummy_inputs = dict(dummy_inputs) - onnx_outputs = list(self.outputs.keys()) + onnx_outputs = list(outputs.keys()) with replace_call(): onnx_export( @@ -160,11 +181,13 @@ class TorchModelExporter(Exporter): outputs_origin = model.forward( *_decide_input_format(model, dummy_inputs)) if isinstance(outputs_origin, Mapping): - outputs_origin = torch_nested_numpify( + outputs_origin = numpify_tensor_nested( list(outputs_origin.values())) + elif isinstance(outputs_origin, (tuple, list)): + outputs_origin = numpify_tensor_nested(outputs_origin) outputs = ort_session.run( onnx_outputs, - torch_nested_numpify(dummy_inputs), + numpify_tensor_nested(dummy_inputs), ) tols = {} @@ -184,19 +207,26 @@ class TorchModelExporter(Exporter): validation: bool = True, rtol: float = None, atol: float = None, + strict: bool = True, **kwargs): """Export the model to a torch script file. - @param model: A torch.nn.Module instance to export. 
- @param output: The output file. - @param device: The device used to forward. - @param validation: Whether validate the export file. - @param rtol: The rtol used to regress the outputs. - @param atol: The atol used to regress the outputs. + Args: + model: A torch.nn.Module instance to export. + output: The output file. + device: The device used to forward. + validation: Whether validate the export file. + rtol: The rtol used to regress the outputs. + atol: The atol used to regress the outputs. + strict: strict mode in torch script tracing. + kwargs: + dummy_inputs: A dummy inputs which will replace the calling of self.generate_dummy_inputs(). """ model.eval() - dummy_inputs = self.generate_dummy_inputs(**kwargs) + dummy_param = 'dummy_inputs' not in kwargs + dummy_inputs = self.generate_dummy_inputs( + **kwargs) if dummy_param else kwargs.pop('dummy_inputs') if dummy_inputs is None: raise NotImplementedError( 'Model property dummy_inputs must be set.') @@ -207,7 +237,7 @@ class TorchModelExporter(Exporter): model.eval() with replace_call(): traced_model = torch.jit.trace( - model, dummy_inputs, strict=False) + model, dummy_inputs, strict=strict) torch.jit.save(traced_model, output) if validation: @@ -216,9 +246,9 @@ class TorchModelExporter(Exporter): model.eval() ts_model.eval() outputs = ts_model.forward(*dummy_inputs) - outputs = torch_nested_numpify(outputs) + outputs = numpify_tensor_nested(outputs) outputs_origin = model.forward(*dummy_inputs) - outputs_origin = torch_nested_numpify(outputs_origin) + outputs_origin = numpify_tensor_nested(outputs_origin) tols = {} if rtol is not None: tols['rtol'] = rtol @@ -240,7 +270,6 @@ def replace_call(): problems. Here we recover the call method to the default implementation of torch.nn.Module, and change it back after the tracing was done. """ - TorchModel.call_origin, TorchModel.__call__ = TorchModel.__call__, TorchModel._call_impl yield TorchModel.__call__ = TorchModel.call_origin diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 8dcfa5b0..00254f16 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -1,32 +1,47 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
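Wrapping up the exporter-facing changes above (the output-directory argument is now `output_dir`, the default ONNX opset is 13, `strict` is forwarded to tracing, and `model`/`dummy_inputs`/`inputs`/`outputs` can be overridden through kwargs), a minimal usage sketch. The model id and output directory are placeholders, the `Exporter` import path is assumed to be re-exported from the package, and the model is assumed to resolve to a `TorchModelExporter` subclass:

```python
import tempfile

from modelscope.exporters import Exporter  # assumed package-level re-export
from modelscope.models import Model

model = Model.from_pretrained('damo/nlp_structbert_nli_chinese-base')  # placeholder id
exporter = Exporter.from_model(model)

output_dir = tempfile.mkdtemp()
# Keyword renamed from `outputs` to `output_dir`; opset now defaults to 13.
onnx_files = exporter.export_onnx(output_dir=output_dir, opset=13)
# `strict` is passed through to torch.jit.trace and now defaults to True.
ts_files = exporter.export_torch_script(output_dir=output_dir, strict=True)
print(onnx_files, ts_files)
```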
+# yapf: disable +import datetime import os import pickle +import platform import shutil +import tempfile +import uuid from collections import defaultdict from http import HTTPStatus from http.cookiejar import CookieJar from os.path import expanduser -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import requests +from modelscope import __version__ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, API_RESPONSE_FIELD_EMAIL, API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, API_RESPONSE_FIELD_MESSAGE, API_RESPONSE_FIELD_USERNAME, - DEFAULT_CREDENTIALS_PATH) + DEFAULT_CREDENTIALS_PATH, + MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS, + Licenses, ModelVisibility) +from modelscope.hub.errors import (InvalidParameter, NotExistError, + NotLoginException, NoValidRevisionError, + RequestError, datahub_raise_on_error, + handle_http_post_error, + handle_http_response, is_ok, + raise_for_http_status, raise_on_error) +from modelscope.hub.git import GitCommandWrapper +from modelscope.hub.repository import Repository from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DEFAULT_MODEL_REVISION, - DatasetFormations, DatasetMetaFormats, - DownloadMode) + DEFAULT_REPOSITORY_REVISION, + MASTER_MODEL_BRANCH, DatasetFormations, + DatasetMetaFormats, DownloadMode, + ModelFile) from modelscope.utils.logger import get_logger -from .errors import (InvalidParameter, NotExistError, RequestError, - datahub_raise_on_error, handle_http_response, is_ok, - raise_on_error) -from .utils.utils import (get_dataset_hub_endpoint, get_endpoint, +from .utils.utils import (get_endpoint, get_release_datetime, model_id_to_group_owner_name) logger = get_logger() @@ -34,10 +49,9 @@ logger = get_logger() class HubApi: - def __init__(self, endpoint=None, dataset_endpoint=None): + def __init__(self, endpoint=None): self.endpoint = endpoint if endpoint is not None else get_endpoint() - self.dataset_endpoint = dataset_endpoint if dataset_endpoint is not None else get_dataset_hub_endpoint( - ) + self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} def login( self, @@ -57,8 +71,9 @@ class HubApi: """ path = f'{self.endpoint}/api/v1/login' - r = requests.post(path, json={'AccessToken': access_token}) - r.raise_for_status() + r = requests.post( + path, json={'AccessToken': access_token}, headers=self.headers) + raise_for_http_status(r) d = r.json() raise_on_error(d) @@ -105,17 +120,16 @@ class HubApi: path = f'{self.endpoint}/api/v1/models' owner_or_group, name = model_id_to_group_owner_name(model_id) + body = { + 'Path': owner_or_group, + 'Name': name, + 'ChineseName': chinese_name, + 'Visibility': visibility, # server check + 'License': license + } r = requests.post( - path, - json={ - 'Path': owner_or_group, - 'Name': name, - 'ChineseName': chinese_name, - 'Visibility': visibility, # server check - 'License': license - }, - cookies=cookies) - r.raise_for_status() + path, json=body, cookies=cookies, headers=self.headers) + handle_http_post_error(r, path, body) raise_on_error(r.json()) model_repo_url = f'{get_endpoint()}/{model_id}' return model_repo_url @@ -134,8 +148,8 @@ class HubApi: raise ValueError('Token does not exist, please login first.') path = f'{self.endpoint}/api/v1/models/{model_id}' - r = requests.delete(path, cookies=cookies) - r.raise_for_status() + r = requests.delete(path, cookies=cookies, headers=self.headers) + raise_for_http_status(r) raise_on_error(r.json()) def get_model_url(self, 
model_id): @@ -164,7 +178,7 @@ class HubApi: owner_or_group, name = model_id_to_group_owner_name(model_id) path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' - r = requests.get(path, cookies=cookies) + r = requests.get(path, cookies=cookies, headers=self.headers) handle_http_response(r, logger, cookies, model_id) if r.status_code == HTTPStatus.OK: if is_ok(r.json()): @@ -172,13 +186,108 @@ class HubApi: else: raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE]) else: - r.raise_for_status() + raise_for_http_status(r) + + def push_model(self, + model_id: str, + model_dir: str, + visibility: int = ModelVisibility.PUBLIC, + license: str = Licenses.APACHE_V2, + chinese_name: Optional[str] = None, + commit_message: Optional[str] = 'upload model', + revision: Optional[str] = DEFAULT_REPOSITORY_REVISION): + """ + Upload model from a given directory to given repository. A valid model directory + must contain a configuration.json file. - def list_model(self, - owner_or_group: str, - page_number=1, - page_size=10) -> dict: - """List model in owner or group. + This function upload the files in given directory to given repository. If the + given repository is not exists in remote, it will automatically create it with + given visibility, license and chinese_name parameters. If the revision is also + not exists in remote repository, it will create a new branch for it. + + This function must be called before calling HubApi's login with a valid token + which can be obtained from ModelScope's website. + + Args: + model_id (`str`): + The model id to be uploaded, caller must have write permission for it. + model_dir(`str`): + The Absolute Path of the finetune result. + visibility(`int`, defaults to `0`): + Visibility of the new created model(1-private, 5-public). If the model is + not exists in ModelScope, this function will create a new model with this + visibility and this parameter is required. You can ignore this parameter + if you make sure the model's existence. + license(`str`, defaults to `None`): + License of the new created model(see License). If the model is not exists + in ModelScope, this function will create a new model with this license + and this parameter is required. You can ignore this parameter if you + make sure the model's existence. + chinese_name(`str`, *optional*, defaults to `None`): + chinese name of the new created model. + commit_message(`str`, *optional*, defaults to `None`): + commit message of the push request. + revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION): + which branch to push. If the branch is not exists, It will create a new + branch and push to it. 
+ """ + if model_id is None: + raise InvalidParameter('model_id cannot be empty!') + if model_dir is None: + raise InvalidParameter('model_dir cannot be empty!') + if not os.path.exists(model_dir) or os.path.isfile(model_dir): + raise InvalidParameter('model_dir must be a valid directory.') + cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) + if not os.path.exists(cfg_file): + raise ValueError(f'{model_dir} must contain a configuration.json.') + cookies = ModelScopeConfig.get_cookies() + if cookies is None: + raise NotLoginException('Must login before upload!') + files_to_save = os.listdir(model_dir) + try: + self.get_model(model_id=model_id) + except Exception: + if visibility is None or license is None: + raise InvalidParameter( + 'visibility and license cannot be empty if want to create new repo' + ) + logger.info('Create new model %s' % model_id) + self.create_model( + model_id=model_id, + visibility=visibility, + license=license, + chinese_name=chinese_name) + tmp_dir = tempfile.mkdtemp() + git_wrapper = GitCommandWrapper() + try: + repo = Repository(model_dir=tmp_dir, clone_from=model_id) + branches = git_wrapper.get_remote_branches(tmp_dir) + if revision not in branches: + logger.info('Create new branch %s' % revision) + git_wrapper.new_branch(tmp_dir, revision) + git_wrapper.checkout(tmp_dir, revision) + for f in files_to_save: + if f[0] != '.': + src = os.path.join(model_dir, f) + if os.path.isdir(src): + shutil.copytree(src, os.path.join(tmp_dir, f)) + else: + shutil.copy(src, tmp_dir) + if not commit_message: + date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') + commit_message = '[automsg] push model %s to hub at %s' % ( + model_id, date) + repo.push(commit_message=commit_message, local_branch=revision, remote_branch=revision) + except Exception: + raise + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + def list_models(self, + owner_or_group: str, + page_number=1, + page_size=10) -> dict: + """List models in owner or group. Args: owner_or_group(`str`): owner or group. @@ -193,7 +302,8 @@ class HubApi: path, data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % (owner_or_group, page_number, page_size), - cookies=cookies) + cookies=cookies, + headers=self.headers) handle_http_response(r, logger, cookies, 'list_model') if r.status_code == HTTPStatus.OK: if is_ok(r.json()): @@ -202,7 +312,7 @@ class HubApi: else: raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) else: - r.raise_for_status() + raise_for_http_status(r) return None def _check_cookie(self, @@ -217,10 +327,70 @@ class HubApi: raise ValueError('Token does not exist, please login first.') return cookies + def list_model_revisions( + self, + model_id: str, + cutoff_timestamp: int = None, + use_cookies: Union[bool, CookieJar] = False) -> List[str]: + """Get model branch and tags. + + Args: + model_id (str): The model id + cutoff_timestamp (int): Tags created before the cutoff will be included. + The timestamp is represented by the seconds elasped from the epoch time. + use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will + will load cookie from local. Defaults to False. 
+ Returns: + Tuple[List[str], List[str]]: Return list of branch name and tags + """ + cookies = self._check_cookie(use_cookies) + if cutoff_timestamp is None: + cutoff_timestamp = get_release_datetime() + path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp + r = requests.get(path, cookies=cookies, headers=self.headers) + handle_http_response(r, logger, cookies, model_id) + d = r.json() + raise_on_error(d) + info = d[API_RESPONSE_FIELD_DATA] + # tags returned from backend are guaranteed to be ordered by create-time + tags = [x['Revision'] for x in info['RevisionMap']['Tags'] + ] if info['RevisionMap']['Tags'] else [] + return tags + + def get_valid_revision(self, model_id: str, revision=None, cookies: Optional[CookieJar] = None): + release_timestamp = get_release_datetime() + current_timestamp = int(round(datetime.datetime.now().timestamp())) + # for active development in library codes (non-release-branches), release_timestamp + # is set to be a far-away-time-in-the-future, to ensure that we shall + # get the master-HEAD version from model repo by default (when no revision is provided) + if release_timestamp > current_timestamp + ONE_YEAR_SECONDS: + branches, tags = self.get_model_branches_and_tags( + model_id, use_cookies=False if cookies is None else cookies) + if revision is None: + revision = MASTER_MODEL_BRANCH + logger.info('Model revision not specified, use default: %s in development mode' % revision) + if revision not in branches and revision not in tags: + raise NotExistError('The model: %s has no branch or tag : %s .' % revision) + else: + revisions = self.list_model_revisions( + model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies) + if revision is None: + if len(revisions) == 0: + raise NoValidRevisionError('The model: %s has no valid revision!' % model_id) + # tags (revisions) returned from backend are guaranteed to be ordered by create-time + # we shall obtain the latest revision created earlier than release version of this branch + revision = revisions[0] + logger.info('Model revision not specified, use the latest revision: %s' % revision) + else: + if revision not in revisions: + raise NotExistError( + 'The model: %s has no revision: %s !' % (model_id, revision)) + return revision + def get_model_branches_and_tags( self, model_id: str, - use_cookies: Union[bool, CookieJar] = False + use_cookies: Union[bool, CookieJar] = False, ) -> Tuple[List[str], List[str]]: """Get model branch and tags. 
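The two methods above implement the new revision policy: released builds resolve a missing revision to the newest tag created before the library's release datetime, while development builds fall back to the master branch. A short sketch with a placeholder model id:

```python
from modelscope.hub.api import HubApi

api = HubApi()
model_id = 'damo/some-public-model'  # placeholder id

# Tags created before this build's release datetime, newest first.
print(api.list_model_revisions(model_id))
# None is resolved to the latest compatible tag (or master in development mode).
print(api.get_valid_revision(model_id, revision=None))
```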
@@ -234,7 +404,7 @@ class HubApi: cookies = self._check_cookie(use_cookies) path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' - r = requests.get(path, cookies=cookies) + r = requests.get(path, cookies=cookies, headers=self.headers) handle_http_response(r, logger, cookies, model_id) d = r.json() raise_on_error(d) @@ -275,7 +445,11 @@ class HubApi: if root is not None: path = path + f'&Root={root}' - r = requests.get(path, cookies=cookies, headers=headers) + r = requests.get( + path, cookies=cookies, headers={ + **headers, + **self.headers + }) handle_http_response(r, logger, cookies, model_id) d = r.json() @@ -290,11 +464,10 @@ class HubApi: return files def list_datasets(self): - path = f'{self.dataset_endpoint}/api/v1/datasets' - headers = None + path = f'{self.endpoint}/api/v1/datasets' params = {} - r = requests.get(path, params=params, headers=headers) - r.raise_for_status() + r = requests.get(path, params=params, headers=self.headers) + raise_for_http_status(r) dataset_list = r.json()[API_RESPONSE_FIELD_DATA] return [x['Name'] for x in dataset_list] @@ -317,14 +490,14 @@ class HubApi: cache_dir): shutil.rmtree(cache_dir) os.makedirs(cache_dir, exist_ok=True) - datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}' + datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}' r = requests.get(datahub_url) resp = r.json() datahub_raise_on_error(datahub_url, resp) dataset_id = resp['Data']['Id'] dataset_type = resp['Data']['Type'] - datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' - r = requests.get(datahub_url) + datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' + r = requests.get(datahub_url, headers=self.headers) resp = r.json() datahub_raise_on_error(datahub_url, resp) file_list = resp['Data'] @@ -341,10 +514,10 @@ class HubApi: file_path = file_info['Path'] extension = os.path.splitext(file_path)[-1] if extension in dataset_meta_format: - datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ + datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ f'Revision={revision}&FilePath={file_path}' r = requests.get(datahub_url) - r.raise_for_status() + raise_for_http_status(r) local_path = os.path.join(cache_dir, file_path) if os.path.exists(local_path): logger.warning( @@ -365,7 +538,7 @@ class HubApi: namespace: str, revision: Optional[str] = DEFAULT_DATASET_REVISION): if file_name.endswith('.csv'): - file_name = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ + file_name = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' 
\ f'Revision={revision}&FilePath={file_name}' return file_name @@ -374,7 +547,7 @@ class HubApi: dataset_name: str, namespace: str, revision: Optional[str] = DEFAULT_DATASET_REVISION): - datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ + datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ f'ststoken?Revision={revision}' return self.datahub_remote_call(datahub_url) @@ -385,23 +558,39 @@ class HubApi: namespace: str, revision: Optional[str] = DEFAULT_DATASET_REVISION): - datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ + datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ f'ststoken?Revision={revision}' cookies = requests.utils.dict_from_cookiejar(cookies) - r = requests.get(url=datahub_url, cookies=cookies) + r = requests.get( + url=datahub_url, cookies=cookies, headers=self.headers) resp = r.json() raise_on_error(resp) return resp['Data'] + def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, + is_recursive, is_filter_dir, revision): + url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ + f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' + + cookies = ModelScopeConfig.get_cookies() + if cookies: + cookies = requests.utils.dict_from_cookiejar(cookies) + + resp = requests.get(url=url, cookies=cookies) + resp = resp.json() + raise_on_error(resp) + resp = resp['Data'] + return resp + def on_dataset_download(self, dataset_name: str, namespace: str) -> None: url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' - r = requests.post(url) - r.raise_for_status() + r = requests.post(url, headers=self.headers) + raise_for_http_status(r) @staticmethod def datahub_remote_call(url): - r = requests.get(url) + r = requests.get(url, headers={'user-agent': ModelScopeConfig.get_user_agent()}) resp = r.json() datahub_raise_on_error(url, resp) return resp['Data'] @@ -415,6 +604,7 @@ class ModelScopeConfig: COOKIES_FILE_NAME = 'cookies' GIT_TOKEN_FILE_NAME = 'git_token' USER_INFO_FILE_NAME = 'user' + USER_SESSION_ID_FILE_NAME = 'session' @staticmethod def make_sure_credential_path_exist(): @@ -443,6 +633,23 @@ class ModelScopeConfig: return cookies return None + @staticmethod + def get_user_session_id(): + session_path = os.path.join(ModelScopeConfig.path_credential, + ModelScopeConfig.USER_SESSION_ID_FILE_NAME) + session_id = '' + if os.path.exists(session_path): + with open(session_path, 'rb') as f: + session_id = str(f.readline().strip(), encoding='utf-8') + return session_id + if session_id == '' or len(session_id) != 32: + session_id = str(uuid.uuid4().hex) + ModelScopeConfig.make_sure_credential_path_exist() + with open(session_path, 'w+') as wf: + wf.write(session_id) + + return session_id + @staticmethod def save_token(token: str): ModelScopeConfig.make_sure_credential_path_exist() @@ -491,3 +698,32 @@ class ModelScopeConfig: except FileNotFoundError: pass return token + + @staticmethod + def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: + """Formats a user-agent string with basic info about a request. + + Args: + user_agent (`str`, `dict`, *optional*): + The user agent info in the form of a dictionary or a single string. + + Returns: + The formatted user-agent string. 
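The per-installation session id and the richer user-agent string are what every hub request now carries through `self.headers`. A small sketch of inspecting them, assuming nothing beyond the two static methods defined in this class:

```python
from modelscope.hub.api import ModelScopeConfig

# Generated once and cached under the credentials path; 32 hex characters.
print(ModelScopeConfig.get_user_session_id())
# modelscope/<version>; python/...; session_id/...; platform/...; processor/...; env/...
print(ModelScopeConfig.get_user_agent())
# A caller-supplied string is appended after a semicolon.
print(ModelScopeConfig.get_user_agent('caller/example'))
```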
+ """ + env = 'custom' + if MODELSCOPE_ENVIRONMENT in os.environ: + env = os.environ[MODELSCOPE_ENVIRONMENT] + + ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % ( + __version__, + platform.python_version(), + ModelScopeConfig.get_user_session_id(), + platform.platform(), + platform.processor(), + env, + ) + if isinstance(user_agent, dict): + ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += ';' + user_agent + return ua diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py index c8664597..730702c1 100644 --- a/modelscope/hub/constants.py +++ b/modelscope/hub/constants.py @@ -16,6 +16,9 @@ API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken' API_RESPONSE_FIELD_USERNAME = 'Username' API_RESPONSE_FIELD_EMAIL = 'Email' API_RESPONSE_FIELD_MESSAGE = 'Message' +MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' +MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG' +ONE_YEAR_SECONDS = 24 * 365 * 60 * 60 class Licenses(object): diff --git a/modelscope/hub/deploy.py b/modelscope/hub/deploy.py new file mode 100644 index 00000000..8cacde82 --- /dev/null +++ b/modelscope/hub/deploy.py @@ -0,0 +1,339 @@ +import urllib +from abc import ABC +from http import HTTPStatus +from typing import Optional + +import json +import requests +from attrs import asdict, define, field, validators + +from modelscope.hub.api import ModelScopeConfig +from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, + API_RESPONSE_FIELD_MESSAGE) +from modelscope.hub.errors import (NotLoginException, NotSupportError, + RequestError, handle_http_response, is_ok, + raise_for_http_status) +from modelscope.hub.utils.utils import get_endpoint +from modelscope.utils.logger import get_logger + +# yapf: enable + +logger = get_logger() + + +class Accelerator(object): + CPU = 'cpu' + GPU = 'gpu' + + +class Vendor(object): + EAS = 'eas' + + +class EASRegion(object): + beijing = 'cn-beijing' + hangzhou = 'cn-hangzhou' + + +class EASCpuInstanceType(object): + """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) + """ + tiny = 'ecs.c6.2xlarge' + small = 'ecs.c6.4xlarge' + medium = 'ecs.c6.6xlarge' + large = 'ecs.c6.8xlarge' + + +class EASGpuInstanceType(object): + """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) + """ + tiny = 'ecs.gn5-c28g1.7xlarge' + small = 'ecs.gn5-c8g1.4xlarge' + medium = 'ecs.gn6i-c24g1.12xlarge' + large = 'ecs.gn6e-c12g1.3xlarge' + + +def min_smaller_than_max(instance, attribute, value): + if value > instance.max_replica: + raise ValueError( + "'min_replica' value: %s has to be smaller than 'max_replica' value: %s!" + % (value, instance.max_replica)) + + +@define +class ServiceScalingConfig(object): + """Resource scaling config + Currently we ignore max_replica + Args: + max_replica: maximum replica + min_replica: minimum replica + """ + max_replica: int = field(default=1, validator=validators.ge(1)) + min_replica: int = field( + default=1, validator=[validators.ge(1), min_smaller_than_max]) + + +@define +class ServiceResourceConfig(object): + """Eas Resource request. + + Args: + accelerator: the accelerator(cpu|gpu) + instance_type: the instance type. + scaling: The instance scaling config. 
+ """ + instance_type: str + scaling: ServiceScalingConfig + accelerator: str = field( + default=Accelerator.CPU, + validator=validators.in_([Accelerator.CPU, Accelerator.GPU])) + + +@define +class ServiceProviderParameters(ABC): + pass + + +@define +class EASDeployParameters(ServiceProviderParameters): + """Parameters for EAS Deployment. + + Args: + resource_group: the resource group to deploy, current default. + region: The eas instance region(eg: cn-hangzhou). + access_key_id: The eas account access key id. + access_key_secret: The eas account access key secret. + vendor: must be 'eas' + """ + region: str + access_key_id: str + access_key_secret: str + resource_group: Optional[str] = None + vendor: str = field( + default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) + + +@define +class EASListParameters(ServiceProviderParameters): + """EAS instance list parameters. + + Args: + resource_group: the resource group to deploy, current default. + region: The eas instance region(eg: cn-hangzhou). + access_key_id: The eas account access key id. + access_key_secret: The eas account access key secret. + vendor: must be 'eas' + """ + access_key_id: str + access_key_secret: str + region: str = None + resource_group: str = None + vendor: str = field( + default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) + + +@define +class DeployServiceParameters(object): + """Deploy service parameters + + Args: + instance_name: the name of the service. + model_id: the modelscope model_id + revision: the modelscope model revision + resource: the resource requirement. + provider: the cloud service provider. + """ + instance_name: str + model_id: str + revision: str + resource: ServiceResourceConfig + provider: ServiceProviderParameters + + +class AttrsToQueryString(ABC): + """Convert the attrs class to json string. + + Args: + """ + + def to_query_str(self): + self_dict = asdict( + self.provider, filter=lambda attr, value: value is not None) + json_str = json.dumps(self_dict) + print(json_str) + safe_str = urllib.parse.quote_plus(json_str) + print(safe_str) + query_param = 'provider=%s' % safe_str + return query_param + + +@define +class ListServiceParameters(AttrsToQueryString): + provider: ServiceProviderParameters + skip: int = 0 + limit: int = 100 + + +@define +class GetServiceParameters(AttrsToQueryString): + provider: ServiceProviderParameters + + +@define +class DeleteServiceParameters(AttrsToQueryString): + provider: ServiceProviderParameters + + +class ServiceDeployer(object): + + def __init__(self, endpoint=None): + self.endpoint = endpoint if endpoint is not None else get_endpoint() + self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} + self.cookies = ModelScopeConfig.get_cookies() + if self.cookies is None: + raise NotLoginException( + 'Token does not exist, please login with HubApi first.') + + # deploy_model + def create(self, model_id: str, revision: str, instance_name: str, + resource: ServiceResourceConfig, + provider: ServiceProviderParameters): + """Deploy model to cloud, current we only support PAI EAS, this is an async API , + and the deployment could take a while to finish remotely. Please check deploy instance + status separately via checking the status. + + Args: + model_id (str): The deployed model id + revision (str): The model revision + instance_name (str): The deployed model instance name. + resource (ServiceResourceConfig): The service resource information. 
+ provider (ServiceProviderParameters): The service provider parameter + + Raises: + NotLoginException: To use this api, you need login first. + NotSupportError: Not supported platform. + RequestError: The server return error. + + Returns: + ServiceInstanceInfo: The information of the deployed service instance. + """ + if provider.vendor != Vendor.EAS: + raise NotSupportError( + 'Not support vendor: %s ,only support EAS current.' % + (provider.vendor)) + create_params = DeployServiceParameters( + instance_name=instance_name, + model_id=model_id, + revision=revision, + resource=resource, + provider=provider) + path = f'{self.endpoint}/api/v1/deployer/endpoint' + body = asdict(create_params) + r = requests.post( + path, json=body, cookies=self.cookies, headers=self.headers) + handle_http_response(r, logger, self.cookies, 'create_service') + if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: + if is_ok(r.json()): + data = r.json()[API_RESPONSE_FIELD_DATA] + return data + else: + raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) + else: + raise_for_http_status(r) + return None + + def get(self, instance_name: str, provider: ServiceProviderParameters): + """Query the specified instance information. + + Args: + instance_name (str): The deployed instance name. + provider (ServiceProviderParameters): The cloud provider information, for eas + need region(eg: ch-hangzhou), access_key_id and access_key_secret. + + Raises: + NotLoginException: To use this api, you need login first. + RequestError: The request is failed from server. + + Returns: + Dict: The information of the requested service instance. + """ + params = GetServiceParameters(provider=provider) + path = '%s/api/v1/deployer/endpoint/%s?%s' % ( + self.endpoint, instance_name, params.to_query_str()) + r = requests.get(path, cookies=self.cookies, headers=self.headers) + handle_http_response(r, logger, self.cookies, 'get_service') + if r.status_code == HTTPStatus.OK: + if is_ok(r.json()): + data = r.json()[API_RESPONSE_FIELD_DATA] + return data + else: + raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) + else: + raise_for_http_status(r) + return None + + def delete(self, instance_name: str, provider: ServiceProviderParameters): + """Delete deployed model, this api send delete command and return, it will take + some to delete, please check through the cloud console. + + Args: + instance_name (str): The instance name you want to delete. + provider (ServiceProviderParameters): The cloud provider information, for eas + need region(eg: ch-hangzhou), access_key_id and access_key_secret. + + Raises: + NotLoginException: To call this api, you need login first. + RequestError: The request is failed. + + Returns: + Dict: The deleted instance information. + """ + params = DeleteServiceParameters(provider=provider) + path = '%s/api/v1/deployer/endpoint/%s?%s' % ( + self.endpoint, instance_name, params.to_query_str()) + r = requests.delete(path, cookies=self.cookies, headers=self.headers) + handle_http_response(r, logger, self.cookies, 'delete_service') + if r.status_code == HTTPStatus.OK: + if is_ok(r.json()): + data = r.json()[API_RESPONSE_FIELD_DATA] + return data + else: + raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) + else: + raise_for_http_status(r) + return None + + def list(self, + provider: ServiceProviderParameters, + skip: int = 0, + limit: int = 100): + """List deployed model instances. 
+ + Args: + provider (ServiceProviderParameters): The cloud service provider parameter, + for eas, need access_key_id and access_key_secret. + skip: start of the list, current not support. + limit: maximum number of instances return, current not support + Raises: + NotLoginException: To use this api, you need login first. + RequestError: The request is failed from server. + + Returns: + List: List of instance information + """ + + params = ListServiceParameters( + provider=provider, skip=skip, limit=limit) + path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, + params.to_query_str()) + r = requests.get(path, cookies=self.cookies, headers=self.headers) + handle_http_response(r, logger, self.cookies, 'list_service_instances') + if r.status_code == HTTPStatus.OK: + if is_ok(r.json()): + data = r.json()[API_RESPONSE_FIELD_DATA] + return data + else: + raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) + else: + raise_for_http_status(r) + return None diff --git a/modelscope/hub/errors.py b/modelscope/hub/errors.py index c095a6ec..bfb55e6d 100644 --- a/modelscope/hub/errors.py +++ b/modelscope/hub/errors.py @@ -4,6 +4,18 @@ from http import HTTPStatus from requests.exceptions import HTTPError +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class NotSupportError(Exception): + pass + + +class NoValidRevisionError(Exception): + pass + class NotExistError(Exception): pass @@ -45,15 +57,25 @@ def is_ok(rsp): return rsp['Code'] == HTTPStatus.OK and rsp['Success'] +def handle_http_post_error(response, url, request_body): + try: + response.raise_for_status() + except HTTPError as error: + logger.error('Request %s with body: %s exception' % + (url, request_body)) + raise error + + def handle_http_response(response, logger, cookies, model_id): try: response.raise_for_status() - except HTTPError: + except HTTPError as error: if cookies is None: # code in [403] and logger.error( f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \ private. Please login first.') - raise + logger.error('Response details: %s' % response.content) + raise error def raise_on_error(rsp): @@ -81,3 +103,33 @@ def datahub_raise_on_error(url, rsp): raise RequestError( f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}" ) + + +def raise_for_http_status(rsp): + """ + Attempt to decode utf-8 first since some servers + localize reason strings, for invalid utf-8, fall back + to decoding with iso-8859-1. 
+ """ + http_error_msg = '' + if isinstance(rsp.reason, bytes): + try: + reason = rsp.reason.decode('utf-8') + except UnicodeDecodeError: + reason = rsp.reason.decode('iso-8859-1') + else: + reason = rsp.reason + + if 400 <= rsp.status_code < 500: + http_error_msg = u'%s Client Error: %s for url: %s' % (rsp.status_code, + reason, rsp.url) + + elif 500 <= rsp.status_code < 600: + http_error_msg = u'%s Server Error: %s for url: %s' % (rsp.status_code, + reason, rsp.url) + + if http_error_msg: + req = rsp.request + if req.method == 'POST': + http_error_msg = u'%s, body: %s' % (http_error_msg, req.body) + raise HTTPError(http_error_msg, response=rsp) diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index 1cc5645b..042ea6a6 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -2,29 +2,25 @@ import copy import os -import sys import tempfile from functools import partial from http.cookiejar import CookieJar from pathlib import Path from typing import Dict, Optional, Union -from uuid import uuid4 import requests -from filelock import FileLock from tqdm import tqdm from modelscope import __version__ +from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.utils.constant import DEFAULT_MODEL_REVISION from modelscope.utils.logger import get_logger -from .api import HubApi, ModelScopeConfig from .constants import FILE_HASH from .errors import FileDownloadError, NotExistError from .utils.caching import ModelFileSystemCache from .utils.utils import (file_integrity_validation, get_cache_dir, get_endpoint, model_id_to_group_owner_name) -SESSION_ID = uuid4().hex logger = get_logger() @@ -35,6 +31,7 @@ def model_file_download( cache_dir: Optional[str] = None, user_agent: Union[Dict, str, None] = None, local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, ) -> Optional[str]: # pragma: no cover """ Download from a given URL and cache it if it's not already present in the @@ -105,54 +102,47 @@ def model_file_download( " online, set 'local_files_only' to False.") _api = HubApi() - headers = {'user-agent': http_user_agent(user_agent=user_agent, )} - cookies = ModelScopeConfig.get_cookies() - branches, tags = _api.get_model_branches_and_tags( - model_id, use_cookies=False if cookies is None else cookies) + headers = { + 'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ) + } + if cookies is None: + cookies = ModelScopeConfig.get_cookies() + + revision = _api.get_valid_revision( + model_id, revision=revision, cookies=cookies) file_to_download_info = None - is_commit_id = False - if revision in branches or revision in tags: # The revision is version or tag, - # we need to confirm the version is up to date - # we need to get the file list to check if the lateast version is cached, if so return, otherwise download - model_files = _api.get_model_files( - model_id=model_id, - revision=revision, - recursive=True, - use_cookies=False if cookies is None else cookies) - - for model_file in model_files: - if model_file['Type'] == 'tree': - continue - - if model_file['Path'] == file_path: - if cache.exists(model_file): - return cache.get_file_by_info(model_file) - else: - file_to_download_info = model_file - break - - if file_to_download_info is None: - raise NotExistError('The file path: %s not exist in: %s' % - (file_path, model_id)) - else: # the revision is commit id. 
- cached_file_path = cache.get_file_by_path_and_commit_id( - file_path, revision) - if cached_file_path is not None: - file_name = os.path.basename(cached_file_path) - logger.info( - f'File {file_name} already in cache, skip downloading!') - return cached_file_path # the file is in cache. - is_commit_id = True + # we need to confirm the version is up-to-date + # we need to get the file list to check if the latest version is cached, if so return, otherwise download + model_files = _api.get_model_files( + model_id=model_id, + revision=revision, + recursive=True, + use_cookies=False if cookies is None else cookies) + + for model_file in model_files: + if model_file['Type'] == 'tree': + continue + + if model_file['Path'] == file_path: + if cache.exists(model_file): + logger.info( + f'File {model_file["Name"]} already in cache, skip downloading!' + ) + return cache.get_file_by_info(model_file) + else: + file_to_download_info = model_file + break + + if file_to_download_info is None: + raise NotExistError('The file path: %s not exist in: %s' % + (file_path, model_id)) + # we need to download again url_to_download = get_file_download_url(model_id, file_path, revision) file_to_download_info = { - 'Path': - file_path, - 'Revision': - revision if is_commit_id else file_to_download_info['Revision'], - FILE_HASH: - None if (is_commit_id or FILE_HASH not in file_to_download_info) else - file_to_download_info[FILE_HASH] + 'Path': file_path, + 'Revision': file_to_download_info['Revision'], + FILE_HASH: file_to_download_info[FILE_HASH] } temp_file_name = next(tempfile._get_candidate_names()) @@ -171,25 +161,6 @@ def model_file_download( os.path.join(temporary_cache_dir, temp_file_name)) -def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: - """Formats a user-agent string with basic info about a request. - - Args: - user_agent (`str`, `dict`, *optional*): - The user agent info in the form of a dictionary or a single string. - - Returns: - The formatted user-agent string. - """ - ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}' - - if isinstance(user_agent, dict): - ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) - elif isinstance(user_agent, str): - ua = user_agent - return ua - - def get_file_download_url(model_id: str, file_path: str, revision: str): """ Format file download url according to `model_id`, `revision` and `file_path`. 
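A hedged example of the reworked download path: the revision is now validated through `HubApi.get_valid_revision`, and callers may pass a `CookieJar` explicitly instead of always reading the cached login. The model id, file path and revision below are placeholders:

```python
# Placeholder model id / file path / revision; cookies may be None for public models.
from modelscope.hub.api import ModelScopeConfig
from modelscope.hub.file_download import model_file_download

cookies = ModelScopeConfig.get_cookies()
local_path = model_file_download(
    model_id='damo/some-public-model',
    file_path='configuration.json',
    revision='v1.0.0',
    cookies=cookies)
print(local_path)
```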
diff --git a/modelscope/hub/git.py b/modelscope/hub/git.py index 486f8df3..7943023b 100644 --- a/modelscope/hub/git.py +++ b/modelscope/hub/git.py @@ -3,10 +3,9 @@ import os import subprocess from typing import List -from xmlrpc.client import Boolean from modelscope.utils.logger import get_logger -from .api import ModelScopeConfig +from ..utils.constant import MASTER_MODEL_BRANCH from .errors import GitError logger = get_logger() @@ -131,6 +130,7 @@ class GitCommandWrapper(metaclass=Singleton): return response def add_user_info(self, repo_base_dir, repo_name): + from modelscope.hub.api import ModelScopeConfig user_name, user_email = ModelScopeConfig.get_user_info() if user_name and user_email: # config user.name and user.email if exist @@ -138,8 +138,8 @@ class GitCommandWrapper(metaclass=Singleton): repo_base_dir, repo_name, user_name) response = self._run_git_command(*config_user_name_args.split(' ')) logger.debug(response.stdout.decode('utf8')) - config_user_email_args = '-C %s/%s config user.name %s' % ( - repo_base_dir, repo_name, user_name) + config_user_email_args = '-C %s/%s config user.email %s' % ( + repo_base_dir, repo_name, user_email) response = self._run_git_command( *config_user_email_args.split(' ')) logger.debug(response.stdout.decode('utf8')) @@ -177,6 +177,18 @@ class GitCommandWrapper(metaclass=Singleton): cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] return self._run_git_command(*cmds) + def get_remote_branches(self, repo_dir: str): + cmds = ['-C', '%s' % repo_dir, 'branch', '-r'] + rsp = self._run_git_command(*cmds) + info = [ + line.strip() + for line in rsp.stdout.decode('utf8').strip().split(os.linesep) + ] + if len(info) == 1: + return ['/'.join(info[0].split('/')[1:])] + else: + return ['/'.join(line.split('/')[1:]) for line in info[1:]] + def pull(self, repo_dir: str): cmds = ['-C', repo_dir, 'pull'] return self._run_git_command(*cmds) @@ -216,3 +228,22 @@ class GitCommandWrapper(metaclass=Singleton): files.append(line.split(' ')[-1]) return files + + def tag(self, + repo_dir: str, + tag_name: str, + message: str, + ref: str = MASTER_MODEL_BRANCH): + cmd_args = [ + '-C', repo_dir, 'tag', tag_name, '-m', + '"%s"' % message, ref + ] + rsp = self._run_git_command(*cmd_args) + logger.debug(rsp.stdout.decode('utf8')) + return rsp + + def push_tag(self, repo_dir: str, tag_name): + cmd_args = ['-C', repo_dir, 'push', 'origin', tag_name] + rsp = self._run_git_command(*cmd_args) + logger.debug(rsp.stdout.decode('utf8')) + return rsp diff --git a/modelscope/hub/repository.py b/modelscope/hub/repository.py index d92089ed..6b116f79 100644 --- a/modelscope/hub/repository.py +++ b/modelscope/hub/repository.py @@ -5,9 +5,9 @@ from typing import Optional from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, - DEFAULT_MODEL_REVISION) + DEFAULT_REPOSITORY_REVISION, + MASTER_MODEL_BRANCH) from modelscope.utils.logger import get_logger -from .api import ModelScopeConfig from .git import GitCommandWrapper from .utils.utils import get_endpoint @@ -21,7 +21,7 @@ class Repository: def __init__(self, model_dir: str, clone_from: str, - revision: Optional[str] = DEFAULT_MODEL_REVISION, + revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, auth_token: Optional[str] = None, git_path: Optional[str] = None): """ @@ -47,6 +47,7 @@ class Repository: err_msg = 'a non-default value of revision cannot be empty.' 
raise InvalidParameter(err_msg) + from modelscope.hub.api import ModelScopeConfig if auth_token: self.auth_token = auth_token else: @@ -89,7 +90,8 @@ class Repository: def push(self, commit_message: str, - branch: Optional[str] = DEFAULT_MODEL_REVISION, + local_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, + remote_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, force: bool = False): """Push local files to remote, this method will do. git pull @@ -116,14 +118,48 @@ class Repository: url = self.git_wrapper.get_repo_remote_url(self.model_dir) self.git_wrapper.pull(self.model_dir) + self.git_wrapper.add(self.model_dir, all_files=True) self.git_wrapper.commit(self.model_dir, commit_message) self.git_wrapper.push( repo_dir=self.model_dir, token=self.auth_token, url=url, - local_branch=branch, - remote_branch=branch) + local_branch=local_branch, + remote_branch=remote_branch) + + def tag(self, tag_name: str, message: str, ref: str = MASTER_MODEL_BRANCH): + """Create a new tag. + Args: + tag_name (str): The name of the tag + message (str): The tag message. + ref (str): The tag reference, can be commit id or branch. + """ + if tag_name is None or tag_name == '': + msg = 'We use tag-based revision, therefore tag_name cannot be None or empty.' + raise InvalidParameter(msg) + if message is None or message == '': + msg = 'We use annotated tag, therefore message cannot None or empty.' + self.git_wrapper.tag( + repo_dir=self.model_dir, + tag_name=tag_name, + message=message, + ref=ref) + + def tag_and_push(self, + tag_name: str, + message: str, + ref: str = MASTER_MODEL_BRANCH): + """Create tag and push to remote + + Args: + tag_name (str): The name of the tag + message (str): The tag message. + ref (str, optional): The tag ref, can be commit id or branch. Defaults to MASTER_MODEL_BRANCH. + """ + self.tag(tag_name, message, ref) + + self.git_wrapper.push_tag(repo_dir=self.model_dir, tag_name=tag_name) class DatasetRepository: @@ -166,7 +202,7 @@ class DatasetRepository: err_msg = 'a non-default value of revision cannot be empty.' raise InvalidParameter(err_msg) self.revision = revision - + from modelscope.hub.api import ModelScopeConfig if auth_token: self.auth_token = auth_token else: diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index cde6ad34..4b81de44 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -2,16 +2,15 @@ import os import tempfile +from http.cookiejar import CookieJar from pathlib import Path from typing import Dict, Optional, Union +from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.utils.constant import DEFAULT_MODEL_REVISION from modelscope.utils.logger import get_logger -from .api import HubApi, ModelScopeConfig from .constants import FILE_HASH -from .errors import NotExistError -from .file_download import (get_file_download_url, http_get_file, - http_user_agent) +from .file_download import get_file_download_url, http_get_file from .utils.caching import ModelFileSystemCache from .utils.utils import (file_integrity_validation, get_cache_dir, model_id_to_group_owner_name) @@ -23,7 +22,8 @@ def snapshot_download(model_id: str, revision: Optional[str] = DEFAULT_MODEL_REVISION, cache_dir: Union[str, Path, None] = None, user_agent: Optional[Union[Dict, str]] = None, - local_files_only: Optional[bool] = False) -> str: + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None) -> str: """Download all files of a repo. 
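A hedged sketch of the new tagging workflow on `Repository` (method names come from the diff; the repo id, local directory and tag values are illustrative):

```python
# Illustrative values; assumes the remote model repo exists and the user is logged in.
from modelscope.hub.repository import Repository

repo = Repository(
    model_dir='/tmp/my-model',
    clone_from='your-namespace/your-model')
repo.push('add fine-tuned weights')  # pull + add + commit + push
repo.tag_and_push(tag_name='v1.0.1', message='first tagged release')
```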
Downloads a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from a repo, because you don't know which @@ -81,15 +81,15 @@ def snapshot_download(model_id: str, ) # we can not confirm the cached file is for snapshot 'revision' else: # make headers - headers = {'user-agent': http_user_agent(user_agent=user_agent, )} + headers = { + 'user-agent': + ModelScopeConfig.get_user_agent(user_agent=user_agent, ) + } _api = HubApi() - cookies = ModelScopeConfig.get_cookies() - # get file list from model repo - branches, tags = _api.get_model_branches_and_tags( - model_id, use_cookies=False if cookies is None else cookies) - if revision not in branches and revision not in tags: - raise NotExistError('The specified branch or tag : %s not exist!' - % revision) + if cookies is None: + cookies = ModelScopeConfig.get_cookies() + revision = _api.get_valid_revision( + model_id, revision=revision, cookies=cookies) snapshot_header = headers if 'CI_TEST' in os.environ else { **headers, @@ -110,7 +110,7 @@ def snapshot_download(model_id: str, for model_file in model_files: if model_file['Type'] == 'tree': continue - # check model_file is exist in cache, if exist, skip download, otherwise download + # check model_file is exist in cache, if existed, skip download, otherwise download if cache.exists(model_file): file_name = os.path.basename(model_file['Name']) logger.info( diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py index d84b78ea..a54f3413 100644 --- a/modelscope/hub/utils/utils.py +++ b/modelscope/hub/utils/utils.py @@ -2,12 +2,12 @@ import hashlib import os +from datetime import datetime from typing import Optional -from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DATA_ENDPOINT, - DEFAULT_MODELSCOPE_DOMAIN, +from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, DEFAULT_MODELSCOPE_GROUP, - MODEL_ID_SEPARATOR, + MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, MODELSCOPE_URL_SCHEME) from modelscope.hub.errors import FileIntegrityError from modelscope.utils.file_utils import get_default_cache_dir @@ -38,17 +38,24 @@ def get_cache_dir(model_id: Optional[str] = None): base_path, model_id + '/') +def get_release_datetime(): + if MODELSCOPE_SDK_DEBUG in os.environ: + rt = int(round(datetime.now().timestamp())) + else: + from modelscope import version + rt = int( + round( + datetime.strptime(version.__release_datetime__, + '%Y-%m-%d %H:%M:%S').timestamp())) + return rt + + def get_endpoint(): modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', DEFAULT_MODELSCOPE_DOMAIN) return MODELSCOPE_URL_SCHEME + modelscope_domain -def get_dataset_hub_endpoint(): - return os.environ.get('HUB_DATASET_ENDPOINT', - DEFAULT_MODELSCOPE_DATA_ENDPOINT) - - def compute_hash(file_path): BUFFER_SIZE = 1024 * 64 # 64k buffer size sha256_hash = hashlib.sha256() diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 33273502..7944d1ed 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -9,11 +9,14 @@ class Models(object): Model name should only contain model info but not task info. 
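For the `snapshot_download` changes above, a hedged call sketch (model id and revision are placeholders); the old branch/tag existence check is now delegated to `HubApi.get_valid_revision`:

```python
# Placeholder model id and revision; cookies fall back to the cached login when omitted.
from modelscope.hub.snapshot_download import snapshot_download

model_dir = snapshot_download('damo/some-public-model', revision='v1.0.0')
print(model_dir)
```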
""" + # tinynas models tinynas_detection = 'tinynas-detection' + tinynas_damoyolo = 'tinynas-damoyolo' # vision models detection = 'detection' realtime_object_detection = 'realtime-object-detection' + realtime_video_object_detection = 'realtime-video-object-detection' scrfd = 'scrfd' classification_model = 'ClassificationModel' nafnet = 'nafnet' @@ -27,11 +30,13 @@ class Models(object): face_2d_keypoints = 'face-2d-keypoints' panoptic_segmentation = 'swinL-panoptic-segmentation' image_reid_person = 'passvitb' + image_inpainting = 'FFTInpainting' video_summarization = 'pgl-video-summarization' swinL_semantic_segmentation = 'swinL-semantic-segmentation' vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' text_driven_segmentation = 'text-driven-segmentation' resnet50_bert = 'resnet50-bert' + referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' fer = 'fer' retinaface = 'retinaface' shop_segmentation = 'shop-segmentation' @@ -39,14 +44,18 @@ class Models(object): mtcnn = 'mtcnn' ulfd = 'ulfd' video_inpainting = 'video-inpainting' + human_wholebody_keypoint = 'human-wholebody-keypoint' hand_static = 'hand-static' face_human_hand_detection = 'face-human-hand-detection' face_emotion = 'face-emotion' product_segmentation = 'product-segmentation' + image_body_reshaping = 'image-body-reshaping' # EasyCV models yolox = 'YOLOX' segformer = 'Segformer' + hand_2d_keypoints = 'HRNet-Hand2D-Keypoints' + image_object_detection_auto = 'image-object-detection-auto' # nlp models bert = 'bert' @@ -58,18 +67,22 @@ class Models(object): space_dst = 'space-dst' space_intent = 'space-intent' space_modeling = 'space-modeling' - star = 'star' - star3 = 'star3' + space_T_en = 'space-T-en' + space_T_cn = 'space-T-cn' tcrf = 'transformer-crf' + tcrf_wseg = 'transformer-crf-for-word-segmentation' transformer_softmax = 'transformer-softmax' lcrf = 'lstm-crf' + lcrf_wseg = 'lstm-crf-for-word-segmentation' gcnncrf = 'gcnn-crf' bart = 'bart' gpt3 = 'gpt3' + gpt_neo = 'gpt-neo' plug = 'plug' bert_for_ds = 'bert-for-document-segmentation' ponet = 'ponet' T5 = 'T5' + bloom = 'bloom' # audio models sambert_hifigan = 'sambert-hifigan' @@ -88,6 +101,10 @@ class Models(object): team = 'team-multi-modal-similarity' video_clip = 'video-clip-multi-modal-embedding' + # science models + unifold = 'unifold' + unifold_symmetry = 'unifold-symmetry' + class TaskModels(object): # nlp task @@ -96,6 +113,7 @@ class TaskModels(object): information_extraction = 'information-extraction' fill_mask = 'fill-mask' feature_extraction = 'feature-extraction' + text_generation = 'text-generation' class Heads(object): @@ -111,6 +129,8 @@ class Heads(object): token_classification = 'token-classification' # extraction information_extraction = 'information-extraction' + # text gen + text_generation = 'text-generation' class Pipelines(object): @@ -144,6 +164,7 @@ class Pipelines(object): salient_detection = 'u2net-salient-detection' image_classification = 'image-classification' face_detection = 'resnet-face-detection-scrfd10gkps' + card_detection = 'resnet-card-detection-scrfd34gkps' ulfd_face_detection = 'manual-face-detection-ulfd' facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' retina_face_detection = 'resnet50-face-detection-retinaface' @@ -160,6 +181,7 @@ class Pipelines(object): face_image_generation = 'gan-face-image-generation' product_retrieval_embedding = 'resnet50-product-retrieval-embedding' realtime_object_detection = 'cspnet_realtime-object-detection_yolox' + 
realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo' face_recognition = 'ir101-face-recognition-cfglint' image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' image2image_translation = 'image-to-image-translation' @@ -168,6 +190,7 @@ class Pipelines(object): ocr_recognition = 'convnextTiny-ocr-recognition' image_portrait_enhancement = 'gpen-image-portrait-enhancement' image_to_image_generation = 'image-to-image-generation' + image_object_detection_auto = 'yolox_image-object-detection-auto' skin_retouching = 'unet-skin-retouching' tinynas_classification = 'tinynas-classification' tinynas_detection = 'tinynas-detection' @@ -178,21 +201,32 @@ class Pipelines(object): video_summarization = 'googlenet_pgl_video_summarization' image_semantic_segmentation = 'image-semantic-segmentation' image_reid_person = 'passvitb-image-reid-person' + image_inpainting = 'fft-inpainting' text_driven_segmentation = 'text-driven-segmentation' movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' shop_segmentation = 'shop-segmentation' video_inpainting = 'video-inpainting' + human_wholebody_keypoint = 'hrnetw48_human-wholebody-keypoint_image' pst_action_recognition = 'patchshift-action-recognition' hand_static = 'hand-static' face_human_hand_detection = 'face-human-hand-detection' face_emotion = 'face-emotion' product_segmentation = 'product-segmentation' + image_body_reshaping = 'flow-based-body-reshaping' + referring_video_object_segmentation = 'referring-video-object-segmentation' # nlp tasks + automatic_post_editing = 'automatic-post-editing' + translation_quality_estimation = 'translation-quality-estimation' + domain_classification = 'domain-classification' sentence_similarity = 'sentence-similarity' word_segmentation = 'word-segmentation' + multilingual_word_segmentation = 'multilingual-word-segmentation' + word_segmentation_thai = 'word-segmentation-thai' part_of_speech = 'part-of-speech' named_entity_recognition = 'named-entity-recognition' + named_entity_recognition_thai = 'named-entity-recognition-thai' + named_entity_recognition_viet = 'named-entity-recognition-viet' text_generation = 'text-generation' text2text_generation = 'text2text-generation' sentiment_analysis = 'sentiment-analysis' @@ -208,14 +242,18 @@ class Pipelines(object): zero_shot_classification = 'zero-shot-classification' text_error_correction = 'text-error-correction' plug_generation = 'plug-generation' + gpt3_generation = 'gpt3-generation' faq_question_answering = 'faq-question-answering' conversational_text_to_sql = 'conversational-text-to-sql' table_question_answering_pipeline = 'table-question-answering-pipeline' sentence_embedding = 'sentence-embedding' - passage_ranking = 'passage-ranking' + text_ranking = 'text-ranking' relation_extraction = 'relation-extraction' document_segmentation = 'document-segmentation' feature_extraction = 'feature-extraction' + translation_en_to_de = 'translation_en_to_de' # keep it underscore + translation_en_to_ro = 'translation_en_to_ro' # keep it underscore + translation_en_to_fr = 'translation_en_to_fr' # keep it underscore # audio tasks sambert_hifigan_tts = 'sambert-hifigan-tts' @@ -236,6 +274,10 @@ class Pipelines(object): text_to_image_synthesis = 'text-to-image-synthesis' video_multi_modal_embedding = 'video-multi-modal-embedding' image_text_retrieval = 'image-text-retrieval' + ofa_ocr_recognition = 'ofa-ocr-recognition' + + # science tasks + protein_structure = 'unifold-protein-structure' class Trainers(object): @@ 
-253,12 +295,16 @@ class Trainers(object): # multi-modal trainers clip_multi_modal_embedding = 'clip-multi-modal-embedding' + ofa = 'ofa' # cv trainers image_instance_segmentation = 'image-instance-segmentation' image_portrait_enhancement = 'image-portrait-enhancement' video_summarization = 'video-summarization' movie_scene_segmentation = 'movie-scene-segmentation' + face_detection_scrfd = 'face-detection-scrfd' + card_detection_scrfd = 'card-detection-scrfd' + image_inpainting = 'image-inpainting' # nlp trainers bert_sentiment_analysis = 'bert-sentiment-analysis' @@ -266,10 +312,11 @@ class Trainers(object): dialog_intent_trainer = 'dialog-intent-trainer' nlp_base_trainer = 'nlp-base-trainer' nlp_veco_trainer = 'nlp-veco-trainer' - nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer' + nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' # audio trainers speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' + speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' class Preprocessors(object): @@ -298,8 +345,12 @@ class Preprocessors(object): bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' text_gen_tokenizer = 'text-gen-tokenizer' text2text_gen_preprocessor = 'text2text-gen-preprocessor' + text_gen_jieba_tokenizer = 'text-gen-jieba-tokenizer' + text2text_translate_preprocessor = 'text2text-translate-preprocessor' token_cls_tokenizer = 'token-cls-tokenizer' ner_tokenizer = 'ner-tokenizer' + thai_ner_tokenizer = 'thai-ner-tokenizer' + viet_ner_tokenizer = 'viet-ner-tokenizer' nli_tokenizer = 'nli-tokenizer' sen_cls_tokenizer = 'sen-cls-tokenizer' dialog_intent_preprocessor = 'dialog-intent-preprocessor' @@ -309,9 +360,10 @@ class Preprocessors(object): zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' text_error_correction = 'text-error-correction' sentence_embedding = 'sentence-embedding' - passage_ranking = 'passage-ranking' + text_ranking = 'text-ranking' sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' + thai_wseg_tokenizer = 'thai-wseg-tokenizer' fill_mask = 'fill-mask' fill_mask_ponet = 'fill-mask-ponet' faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' @@ -320,6 +372,7 @@ class Preprocessors(object): re_tokenizer = 're-tokenizer' document_segmentation = 'document-segmentation' feature_extraction = 'feature-extraction' + sentence_piece = 'sentence-piece' # audio preprocessor linear_aec_fbank = 'linear-aec-fbank' @@ -331,6 +384,9 @@ class Preprocessors(object): ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' + # science preprocessor + unifold_preprocessor = 'unifold-preprocessor' + class Metrics(object): """ Names for different metrics. 
@@ -340,6 +396,9 @@ class Metrics(object): accuracy = 'accuracy' audio_noise_metric = 'audio-noise-metric' + # text gen + BLEU = 'bleu' + # metrics for image denoise task image_denoise_metric = 'image-denoise-metric' @@ -358,6 +417,10 @@ class Metrics(object): video_summarization_metric = 'video-summarization-metric' # metric for movie-scene-segmentation task movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' + # metric for inpainting task + image_inpainting_metric = 'image-inpainting-metric' + # metric for ocr + NED = 'ned' class Optimizers(object): @@ -399,6 +462,9 @@ class Hooks(object): IterTimerHook = 'IterTimerHook' EvaluationHook = 'EvaluationHook' + # Compression + SparsityHook = 'SparsityHook' + class LR_Schedulers(object): """learning rate scheduler is defined here @@ -413,7 +479,10 @@ class Datasets(object): """ Names for different datasets. """ ClsDataset = 'ClsDataset' - Face2dKeypointsDataset = 'Face2dKeypointsDataset' + Face2dKeypointsDataset = 'FaceKeypointDataset' + HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' + HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' SegDataset = 'SegDataset' DetDataset = 'DetDataset' DetImagesMixDataset = 'DetImagesMixDataset' + PairedDataset = 'PairedDataset' diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index d3975a2c..c022eaf4 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -17,6 +17,9 @@ if TYPE_CHECKING: from .token_classification_metric import TokenClassificationMetric from .video_summarization_metric import VideoSummarizationMetric from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric + from .accuracy_metric import AccuracyMetric + from .bleu_metric import BleuMetric + from .image_inpainting_metric import ImageInpaintingMetric else: _import_structure = { @@ -34,6 +37,9 @@ else: 'token_classification_metric': ['TokenClassificationMetric'], 'video_summarization_metric': ['VideoSummarizationMetric'], 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], + 'image_inpainting_metric': ['ImageInpaintingMetric'], + 'accuracy_metric': ['AccuracyMetric'], + 'bleu_metric': ['BleuMetric'], } import sys diff --git a/modelscope/metrics/accuracy_metric.py b/modelscope/metrics/accuracy_metric.py new file mode 100644 index 00000000..1761786e --- /dev/null +++ b/modelscope/metrics/accuracy_metric.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict + +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.outputs import OutputKeys +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy) +class AccuracyMetric(Metric): + """The metric computation class for classification classes. + + This metric class calculates accuracy for the whole input batches. 
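A hedged usage sketch of the new `AccuracyMetric` (the batch values are made up; both dicts must carry the same label key, and lists or numpy arrays are accepted):

```python
# Made-up batch; outputs holds predictions, inputs holds ground truth.
import numpy as np

from modelscope.metrics.accuracy_metric import AccuracyMetric
from modelscope.outputs import OutputKeys

metric = AccuracyMetric()
metric.add(
    outputs={OutputKeys.LABELS: np.array([1, 0, 1])},
    inputs={OutputKeys.LABELS: np.array([1, 1, 1])})
print(metric.evaluate())  # roughly {'accuracy': 0.667}
```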
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.preds = [] + self.labels = [] + + def add(self, outputs: Dict, inputs: Dict): + label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS + ground_truths = inputs[label_name] + eval_results = outputs[label_name] + assert type(ground_truths) == type(eval_results) + if isinstance(ground_truths, list): + self.preds.extend(eval_results) + self.labels.extend(ground_truths) + elif isinstance(ground_truths, np.ndarray): + self.preds.extend(eval_results.tolist()) + self.labels.extend(ground_truths.tolist()) + else: + raise 'only support list or np.ndarray' + + def evaluate(self): + assert len(self.preds) == len(self.labels) + return { + MetricKeys.ACCURACY: (np.asarray([ + pred == ref for pred, ref in zip(self.preds, self.labels) + ])).mean().item() + } diff --git a/modelscope/metrics/audio_noise_metric.py b/modelscope/metrics/audio_noise_metric.py index f26db46d..8555e95b 100644 --- a/modelscope/metrics/audio_noise_metric.py +++ b/modelscope/metrics/audio_noise_metric.py @@ -35,6 +35,8 @@ class AudioNoiseMetric(Metric): total_loss = avg_loss + avg_amp + avg_phase + avg_sisnr return { 'total_loss': total_loss.item(), - 'avg_sisnr': avg_sisnr.item(), + # model use opposite number of sisnr as a calculation shortcut. + # revert it in evaluation result + 'avg_sisnr': -avg_sisnr.item(), MetricKeys.AVERAGE_LOSS: avg_loss.item() } diff --git a/modelscope/metrics/base.py b/modelscope/metrics/base.py index 3a9d810f..955946b5 100644 --- a/modelscope/metrics/base.py +++ b/modelscope/metrics/base.py @@ -10,8 +10,8 @@ class Metric(ABC): complex metrics for a specific task with or without other Metric subclasses. """ - def __init__(self, trainer=None, *args, **kwargs): - self.trainer = trainer + def __init__(self, *args, **kwargs): + pass @abstractmethod def add(self, outputs: Dict, inputs: Dict): diff --git a/modelscope/metrics/bleu_metric.py b/modelscope/metrics/bleu_metric.py new file mode 100644 index 00000000..7c134b6a --- /dev/null +++ b/modelscope/metrics/bleu_metric.py @@ -0,0 +1,42 @@ +from itertools import zip_longest +from typing import Dict + +import sacrebleu + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + +EVAL_BLEU_ORDER = 4 + + +@METRICS.register_module(group_key=default_group, module_name=Metrics.BLEU) +class BleuMetric(Metric): + """The metric computation bleu for text generation classes. + + This metric class calculates accuracy for the whole input batches. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.eval_tokenized_bleu = kwargs.get('eval_tokenized_bleu', False) + self.hyp_name = kwargs.get('hyp_name', 'hyp') + self.ref_name = kwargs.get('ref_name', 'ref') + self.refs = list() + self.hyps = list() + + def add(self, outputs: Dict, inputs: Dict): + self.refs.extend(inputs[self.ref_name]) + self.hyps.extend(outputs[self.hyp_name]) + + def evaluate(self): + if self.eval_tokenized_bleu: + bleu = sacrebleu.corpus_bleu( + self.hyps, list(zip_longest(*self.refs)), tokenize='none') + else: + bleu = sacrebleu.corpus_bleu(self.hyps, + list(zip_longest(*self.refs))) + return { + MetricKeys.BLEU_4: bleu.score, + } diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 9e875cc4..da3b64c7 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -18,10 +18,12 @@ class MetricKeys(object): SSIM = 'ssim' AVERAGE_LOSS = 'avg_loss' FScore = 'fscore' + FID = 'fid' BLEU_1 = 'bleu-1' BLEU_4 = 'bleu-4' ROUGE_1 = 'rouge-1' ROUGE_L = 'rouge-l' + NED = 'ned' # ocr metric task_default_metrics = { @@ -31,6 +33,7 @@ task_default_metrics = { Tasks.sentiment_classification: [Metrics.seq_cls_metric], Tasks.token_classification: [Metrics.token_cls_metric], Tasks.text_generation: [Metrics.text_gen_metric], + Tasks.text_classification: [Metrics.seq_cls_metric], Tasks.image_denoising: [Metrics.image_denoise_metric], Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], Tasks.image_portrait_enhancement: @@ -39,6 +42,7 @@ task_default_metrics = { Tasks.image_captioning: [Metrics.text_gen_metric], Tasks.visual_question_answering: [Metrics.text_gen_metric], Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], + Tasks.image_inpainting: [Metrics.image_inpainting_metric], } diff --git a/modelscope/metrics/ciderD/__init__.py b/modelscope/metrics/ciderD/__init__.py new file mode 100755 index 00000000..3f7d85bb --- /dev/null +++ b/modelscope/metrics/ciderD/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' diff --git a/modelscope/metrics/ciderD/ciderD.py b/modelscope/metrics/ciderD/ciderD.py new file mode 100755 index 00000000..05c7eb23 --- /dev/null +++ b/modelscope/metrics/ciderD/ciderD.py @@ -0,0 +1,57 @@ +# Filename: ciderD.py +# +# Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric +# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) +# +# Creation Date: Sun Feb 8 14:16:54 2015 +# +# Authors: Ramakrishna Vedantam and Tsung-Yi Lin +from __future__ import absolute_import, division, print_function + +from .ciderD_scorer import CiderScorer + + +class CiderD: + """ + Main Class to compute the CIDEr metric + + """ + + def __init__(self, n=4, sigma=6.0, df='corpus'): + # set cider to sum over 1 to 4-grams + self._n = n + # set the standard deviation parameter for gaussian penalty + self._sigma = sigma + # set which where to compute document frequencies from + self._df = df + self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) + + def compute_score(self, gts, res): + """ + Main function to compute CIDEr score + :param hypo_for_image (dict) : dictionary with key and value + ref_for_image (dict) : dictionary with key and value + :return: cider (float) : computed CIDEr score for the corpus + """ # noqa + + # clear all the previous hypos and refs + tmp_cider_scorer = self.cider_scorer.copy_empty() + tmp_cider_scorer.clear() + for res_id in res: + + hypo = res_id['caption'] + ref = 
gts[res_id['image_id']] + + # Sanity check. + assert (type(hypo) is list) + assert (len(hypo) == 1) + assert (type(ref) is list) + assert (len(ref) > 0) + tmp_cider_scorer += (hypo[0], ref) + + (score, scores) = tmp_cider_scorer.compute_score() + + return score, scores + + def method(self): + return 'CIDEr-D' diff --git a/modelscope/metrics/ciderD/ciderD_scorer.py b/modelscope/metrics/ciderD/ciderD_scorer.py new file mode 100755 index 00000000..4157ec11 --- /dev/null +++ b/modelscope/metrics/ciderD/ciderD_scorer.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python +# Tsung-Yi Lin +# Ramakrishna Vedantam +from __future__ import absolute_import, division, print_function +import copy +import math +import os +import pdb +from collections import defaultdict + +import numpy as np +import six +from six.moves import cPickle + + +def precook(s, n=4, out=False): + """ + Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well. + :param s: string : sentence to be converted into ngrams + :param n: int : number of ngrams for which representation is calculated + :return: term frequency vector for occuring ngrams + """ + words = s.split() + counts = defaultdict(int) + for k in range(1, n + 1): + for i in range(len(words) - k + 1): + ngram = tuple(words[i:i + k]) + counts[ngram] += 1 + return counts + + +def cook_refs(refs, n=4): # lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them. + :param refs: list of string : reference sentences for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (list of dict) + ''' + return [precook(ref, n) for ref in refs] + + +def cook_test(test, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it. + :param test: list of string : hypothesis sentence for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (dict) + ''' + return precook(test, n, True) + + +class CiderScorer(object): + """CIDEr scorer. 
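A hedged usage sketch of the vendored `CiderD` scorer: `gts` maps an image id to its reference captions and each `res` entry carries exactly one hypothesis caption (all captions below are made up):

```python
# Made-up captions; with df='corpus' the idf statistics are computed from gts itself.
from modelscope.metrics.ciderD.ciderD import CiderD

gts = {
    0: ['a dog runs on the grass', 'a dog is running outside'],
    1: ['a man rides a bicycle'],
}
res = [
    {'image_id': 0, 'caption': ['a dog running on grass']},
    {'image_id': 1, 'caption': ['a man riding a bike']},
]
score, per_image_scores = CiderD(df='corpus').compute_score(gts, res)
print(score, per_image_scores)
```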
+ """ + + def copy(self): + ''' copy the refs.''' + new = CiderScorer(n=self.n) + new.ctest = copy.copy(self.ctest) + new.crefs = copy.copy(self.crefs) + return new + + def copy_empty(self): + new = CiderScorer(df_mode='corpus', n=self.n, sigma=self.sigma) + new.df_mode = self.df_mode + new.ref_len = self.ref_len + new.document_frequency = self.document_frequency + return new + + def __init__(self, df_mode='corpus', test=None, refs=None, n=4, sigma=6.0): + ''' singular instance ''' + self.n = n + self.sigma = sigma + self.crefs = [] + self.ctest = [] + self.df_mode = df_mode + self.ref_len = None + if self.df_mode != 'corpus': + pkl_file = cPickle.load( + open(df_mode, 'rb'), + **(dict(encoding='latin1') if six.PY3 else {})) + self.ref_len = np.log(float(pkl_file['ref_len'])) + self.document_frequency = pkl_file['document_frequency'] + else: + self.document_frequency = None + self.cook_append(test, refs) + + def clear(self): + self.crefs = [] + self.ctest = [] + + def cook_append(self, test, refs): + '''called by constructor and __iadd__ to avoid creating new instances.''' + + if refs is not None: + self.crefs.append(cook_refs(refs)) + if test is not None: + self.ctest.append(cook_test(test)) # N.B.: -1 + else: + self.ctest.append( + None) # lens of crefs and ctest have to match + + def size(self): + assert len(self.crefs) == len( + self.ctest), 'refs/test mismatch! %d<>%d' % (len( + self.crefs), len(self.ctest)) + return len(self.crefs) + + def __iadd__(self, other): + '''add an instance (e.g., from another sentence).''' + + if type(other) is tuple: + # avoid creating new CiderScorer instances + self.cook_append(other[0], other[1]) + else: + self.ctest.extend(other.ctest) + self.crefs.extend(other.crefs) + + return self + + def compute_doc_freq(self): + """ + Compute term frequency for reference data. + This will be used to compute idf (inverse document frequency later) + The term frequency is stored in the object + :return: None + """ + for refs in self.crefs: + # refs, k ref captions of one image + for ngram in set([ + ngram for ref in refs for (ngram, count) in ref.items() + ]): # noqa + self.document_frequency[ngram] += 1 + + def compute_cider(self): + + def counts2vec(cnts): + """ + Function maps counts of ngram to vector of tfidf weights. + The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. + The n-th entry of array denotes length of n-grams. + :param cnts: + :return: vec (array of dict), norm (array of float), length (int) + """ + vec = [defaultdict(float) for _ in range(self.n)] + length = 0 + norm = [0.0 for _ in range(self.n)] + for (ngram, term_freq) in cnts.items(): + # give word count 1 if it doesn't appear in reference corpus + df = np.log(max(1.0, self.document_frequency[ngram])) + # ngram index + n = len(ngram) - 1 + # tf (term_freq) * idf (precomputed idf) for n-grams + vec[n][ngram] = float(term_freq) * (self.ref_len - df) + # compute norm for the vector. the norm will be used for computing similarity + norm[n] += pow(vec[n][ngram], 2) + + if n == 1: + length += term_freq + norm = [np.sqrt(n) for n in norm] + return vec, norm, length + + def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): + ''' + Compute the cosine similarity of two vectors. 
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis + :param vec_ref: array of dictionary for vector corresponding to reference + :param norm_hyp: array of float for vector corresponding to hypothesis + :param norm_ref: array of float for vector corresponding to reference + :param length_hyp: int containing length of hypothesis + :param length_ref: int containing length of reference + :return: array of score for each n-grams cosine similarity + ''' + delta = float(length_hyp - length_ref) + # measure consine similarity + val = np.array([0.0 for _ in range(self.n)]) + for n in range(self.n): + # ngram + for (ngram, count) in vec_hyp[n].items(): + # vrama91 : added clipping + val[n] += min(vec_hyp[n][ngram], + vec_ref[n][ngram]) * vec_ref[n][ngram] + + if (norm_hyp[n] != 0) and (norm_ref[n] != 0): + val[n] /= (norm_hyp[n] * norm_ref[n]) + + assert (not math.isnan(val[n])) + # vrama91: added a length based gaussian penalty + val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2)) + return val + + # compute log reference length + if self.df_mode == 'corpus': + self.ref_len = np.log(float(len(self.crefs))) + # elif self.df_mode == "coco-val-df": + # if coco option selected, use length of coco-val set + # self.ref_len = np.log(float(40504)) + + scores = [] + for test, refs in zip(self.ctest, self.crefs): + # compute vector for test captions + vec, norm, length = counts2vec(test) + # compute vector for ref captions + score = np.array([0.0 for _ in range(self.n)]) + for ref in refs: + vec_ref, norm_ref, length_ref = counts2vec(ref) + score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) + # change by vrama91 - mean of ngram scores, instead of sum + score_avg = np.mean(score) + # divide by number of references + score_avg /= len(refs) + # multiply score by 10 + score_avg *= 10.0 + # append score of an image to the score list + scores.append(score_avg) + return scores + + def compute_score(self, option=None, verbose=0): + # compute idf + if self.df_mode == 'corpus': + self.document_frequency = defaultdict(float) + self.compute_doc_freq() + # assert to check document frequency + assert (len(self.ctest) >= max(self.document_frequency.values())) + # import json for now and write the corresponding files + # compute cider score + score = self.compute_cider() + # debug + # print score + return np.mean(np.array(score)), np.array(score) diff --git a/modelscope/metrics/image_denoise_metric.py b/modelscope/metrics/image_denoise_metric.py index 94ec9dc7..1692f299 100644 --- a/modelscope/metrics/image_denoise_metric.py +++ b/modelscope/metrics/image_denoise_metric.py @@ -1,12 +1,16 @@ +# ------------------------------------------------------------------------ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+# ------------------------------------------------------------------------ +# modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/metrics/psnr_ssim.py +# ------------------------------------------------------------------------ from typing import Dict +import cv2 import numpy as np -from skimage.metrics import peak_signal_noise_ratio, structural_similarity +import torch from modelscope.metainfo import Metrics from modelscope.utils.registry import default_group -from modelscope.utils.tensor_utils import (torch_nested_detach, - torch_nested_numpify) from .base import Metric from .builder import METRICS, MetricKeys @@ -20,26 +24,249 @@ class ImageDenoiseMetric(Metric): label_name = 'target' def __init__(self): + super(ImageDenoiseMetric, self).__init__() self.preds = [] self.labels = [] def add(self, outputs: Dict, inputs: Dict): ground_truths = outputs[ImageDenoiseMetric.label_name] eval_results = outputs[ImageDenoiseMetric.pred_name] - self.preds.append( - torch_nested_numpify(torch_nested_detach(eval_results))) - self.labels.append( - torch_nested_numpify(torch_nested_detach(ground_truths))) + self.preds.append(eval_results) + self.labels.append(ground_truths) def evaluate(self): psnr_list, ssim_list = [], [] for (pred, label) in zip(self.preds, self.labels): - psnr_list.append( - peak_signal_noise_ratio(label[0], pred[0], data_range=255)) - ssim_list.append( - structural_similarity( - label[0], pred[0], multichannel=True, data_range=255)) + psnr_list.append(calculate_psnr(label[0], pred[0], crop_border=0)) + ssim_list.append(calculate_ssim(label[0], pred[0], crop_border=0)) return { MetricKeys.PSNR: np.mean(psnr_list), MetricKeys.SSIM: np.mean(ssim_list) } + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" + ) + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def calculate_psnr(img1, img2, crop_border, input_order='HWC'): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + Args: + img1 (ndarray/tensor): Images with range [0, 255]/[0, 1]. + img2 (ndarray/tensor): Images with range [0, 255]/[0, 1]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, ( + f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. 
Supported input_orders are ' + '"HWC" and "CHW"') + if type(img1) == torch.Tensor: + if len(img1.shape) == 4: + img1 = img1.squeeze(0) + img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) + if type(img2) == torch.Tensor: + if len(img2.shape) == 4: + img2 = img2.squeeze(0) + img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) + + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + def _psnr(img1, img2): + + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + max_value = 1. if img1.max() <= 1 else 255. + return 20. * np.log10(max_value / np.sqrt(mse)) + + return _psnr(img1, img2) + + +def calculate_ssim(img1, img2, crop_border, input_order='HWC', ssim3d=True): + """Calculate SSIM (structural similarity). + Ref: + Image quality assessment: From error visibility to structural similarity + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + For three-channel images, SSIM is calculated for each channel and then + averaged. + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + Returns: + float: ssim result. + """ + + assert img1.shape == img2.shape, ( + f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are ' + '"HWC" and "CHW"') + + if type(img1) == torch.Tensor: + if len(img1.shape) == 4: + img1 = img1.squeeze(0) + img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) + if type(img2) == torch.Tensor: + if len(img2.shape) == 4: + img2 = img2.squeeze(0) + img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) + + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + def _cal_ssim(img1, img2): + ssims = [] + + max_value = 1 if img1.max() <= 1 else 255 + with torch.no_grad(): + final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim( + img1, img2, max_value) + ssims.append(final_ssim) + + return np.array(ssims).mean() + + return _cal_ssim(img1, img2) + + +def _ssim(img, img2, max_value): + """Calculate SSIM (structural similarity) for one channel images. + It is called by func:`calculate_ssim`. + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + Returns: + float: SSIM result. 
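A hedged sanity check of the new `calculate_psnr` helper on synthetic HWC data; identical inputs would give infinite PSNR, so one pixel is perturbed:

```python
# Synthetic HWC images in the [0, 255] range.
import numpy as np

from modelscope.metrics.image_denoise_metric import calculate_psnr

target = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
pred = target.copy()
pred[0, 0, 0] = (pred[0, 0, 0] + 10) % 256  # perturb a single pixel
print(calculate_psnr(target, pred, crop_border=0, input_order='HWC'))
```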
+ """ + + c1 = (0.01 * max_value)**2 + c2 = (0.03 * max_value)**2 + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, + 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2) + tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) + ssim_map = tmp1 / tmp2 + return ssim_map.mean() + + +def _3d_gaussian_calculator(img, conv3d): + out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) + return out + + +def _generate_3d_gaussian_kernel(): + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + kernel_3 = cv2.getGaussianKernel(11, 1.5) + kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) + conv3d = torch.nn.Conv3d( + 1, + 1, (11, 11, 11), + stride=1, + padding=(5, 5, 5), + bias=False, + padding_mode='replicate') + conv3d.weight.requires_grad = False + conv3d.weight[0, 0, :, :, :] = kernel + return conv3d + + +def _ssim_3d(img1, img2, max_value): + assert len(img1.shape) == 3 and len(img2.shape) == 3 + """Calculate SSIM (structural similarity) for one channel images. + It is called by func:`calculate_ssim`. + Args: + img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. + img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. + Returns: + float: ssim result. + """ + C1 = (0.01 * max_value)**2 + C2 = (0.03 * max_value)**2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + kernel = _generate_3d_gaussian_kernel().cuda() + + img1 = torch.tensor(img1).float().cuda() + img2 = torch.tensor(img2).float().cuda() + + mu1 = _3d_gaussian_calculator(img1, kernel) + mu2 = _3d_gaussian_calculator(img2, kernel) + + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq + sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq + sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2 + + tmp1 = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2) + tmp2 = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ssim_map = tmp1 / tmp2 + return float(ssim_map.mean()) diff --git a/modelscope/metrics/image_inpainting_metric.py b/modelscope/metrics/image_inpainting_metric.py new file mode 100644 index 00000000..954d4ca2 --- /dev/null +++ b/modelscope/metrics/image_inpainting_metric.py @@ -0,0 +1,210 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import Dict + +import numpy as np +import torch +import torch.nn.functional as F +from scipy import linalg + +from modelscope.metainfo import Metrics +from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3 +from modelscope.utils.registry import default_group +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) +from .base import Metric +from .builder import METRICS, MetricKeys + + +def fid_calculate_activation_statistics(act): + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def 
calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): + mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) + mu2, sigma2 = fid_calculate_activation_statistics(activations_target) + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) + - 2 * tr_covmean) + + +class FIDScore(torch.nn.Module): + + def __init__(self, dims=2048, eps=1e-6): + super().__init__() + if getattr(FIDScore, '_MODEL', None) is None: + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + FIDScore._MODEL = InceptionV3([block_idx]).eval() + self.model = FIDScore._MODEL + self.eps = eps + self.reset() + + def forward(self, pred_batch, target_batch, mask=None): + activations_pred = self._get_activations(pred_batch) + activations_target = self._get_activations(target_batch) + + self.activations_pred.append(activations_pred.detach().cpu()) + self.activations_target.append(activations_target.detach().cpu()) + + def get_value(self): + activations_pred, activations_target = (self.activations_pred, + self.activations_target) + activations_pred = torch.cat(activations_pred).cpu().numpy() + activations_target = torch.cat(activations_target).cpu().numpy() + + total_distance = calculate_frechet_distance( + activations_pred, activations_target, eps=self.eps) + + self.reset() + return total_distance + + def reset(self): + self.activations_pred = [] + self.activations_target = [] + + def _get_activations(self, batch): + activations = self.model(batch)[0] + if activations.shape[2] != 1 or activations.shape[3] != 1: + assert False, \ + 'We should not have got here, because Inception always scales inputs to 299x299' + activations = activations.squeeze(-1).squeeze(-1) + return activations + + +class SSIM(torch.nn.Module): + """SSIM. 
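A hedged numerical sanity check of `calculate_frechet_distance` defined above: identical activation sets should yield a Fréchet distance of (numerically) zero; the array shape is arbitrary and unrelated to the real InceptionV3 feature size:

```python
# Arbitrary activation matrix: 512 samples with 64 features each.
import numpy as np

from modelscope.metrics.image_inpainting_metric import calculate_frechet_distance

acts = np.random.randn(512, 64)
print(calculate_frechet_distance(acts, acts))  # ~0.0 up to numerical error
```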
Modified from: + https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py + """ + + def __init__(self, window_size=11, size_average=True): + super().__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.register_buffer('window', + self._create_window(window_size, self.channel)) + + def forward(self, img1, img2): + assert len(img1.shape) == 4 + + channel = img1.size()[1] + + if channel == self.channel and self.window.data.type( + ) == img1.data.type(): + window = self.window + else: + window = self._create_window(self.window_size, channel) + + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return self._ssim(img1, img2, window, self.window_size, channel, + self.size_average) + + def _gaussian(self, window_size, sigma): + gauss = torch.Tensor([ + np.exp(-(x - (window_size // 2))**2 / float(2 * sigma**2)) + for x in range(window_size) + ]) + return gauss / gauss.sum() + + def _create_window(self, window_size, channel): + _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm( + _1D_window.t()).float().unsqueeze(0).unsqueeze(0) + return _2D_window.expand(channel, 1, window_size, + window_size).contiguous() + + def _ssim(self, + img1, + img2, + window, + window_size, + channel, + size_average=True): + mu1 = F.conv2d( + img1, window, padding=(window_size // 2), groups=channel) + mu2 = F.conv2d( + img2, window, padding=(window_size // 2), groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d( + img1 * img1, window, padding=(window_size // 2), + groups=channel) - mu1_sq + sigma2_sq = F.conv2d( + img2 * img2, window, padding=(window_size // 2), + groups=channel) - mu2_sq + sigma12 = F.conv2d( + img1 * img2, window, padding=(window_size // 2), + groups=channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ + ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + + return ssim_map.mean(1).mean(1).mean(1) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + return + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.image_inpainting_metric) +class ImageInpaintingMetric(Metric): + """The metric computation class for image inpainting classes. 
+ """ + + def __init__(self): + self.preds = [] + self.targets = [] + self.SSIM = SSIM(window_size=11, size_average=False).eval() + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.FID = FIDScore().to(device) + + def add(self, outputs: Dict, inputs: Dict): + pred = outputs['inpainted'] + target = inputs['image'] + self.preds.append(torch_nested_detach(pred)) + self.targets.append(torch_nested_detach(target)) + + def evaluate(self): + ssim_list = [] + for (pred, target) in zip(self.preds, self.targets): + ssim_list.append(self.SSIM(pred, target)) + self.FID(pred, target) + ssim_list = torch_nested_numpify(ssim_list) + fid = self.FID.get_value() + return {MetricKeys.SSIM: np.mean(ssim_list), MetricKeys.FID: fid} diff --git a/modelscope/metrics/image_portrait_enhancement_metric.py b/modelscope/metrics/image_portrait_enhancement_metric.py index b8412b9e..7d94aade 100644 --- a/modelscope/metrics/image_portrait_enhancement_metric.py +++ b/modelscope/metrics/image_portrait_enhancement_metric.py @@ -1,5 +1,8 @@ +# Part of the implementation is borrowed and modified from BasicSR, publicly available at +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py from typing import Dict +import cv2 import numpy as np from modelscope.metainfo import Metrics @@ -35,6 +38,7 @@ class ImagePortraitEnhancementMetric(Metric): def add(self, outputs: Dict, inputs: Dict): ground_truths = outputs['target'] eval_results = outputs['pred'] + self.preds.extend(eval_results) self.targets.extend(ground_truths) diff --git a/modelscope/metrics/token_classification_metric.py b/modelscope/metrics/token_classification_metric.py index 05b72170..f8595fc1 100644 --- a/modelscope/metrics/token_classification_metric.py +++ b/modelscope/metrics/token_classification_metric.py @@ -34,17 +34,24 @@ class TokenClassificationMetric(Metric): self.labels.append( torch_nested_numpify(torch_nested_detach(ground_truths))) - def __init__(self, return_entity_level_metrics=False, *args, **kwargs): + def __init__(self, + return_entity_level_metrics=False, + label2id=None, + *args, + **kwargs): super().__init__(*args, **kwargs) self.return_entity_level_metrics = return_entity_level_metrics self.preds = [] self.labels = [] + self.label2id = label2id def evaluate(self): - self.id2label = { - id: label - for label, id in self.trainer.label2id.items() - } + label2id = self.label2id + if label2id is None: + assert hasattr(self, 'trainer') + label2id = self.trainer.label2id + + self.id2label = {id: label for label, id in label2id.items()} self.preds = np.concatenate(self.preds, axis=0) self.labels = np.concatenate(self.labels, axis=0) predictions = np.argmax(self.preds, axis=-1) diff --git a/modelscope/metrics/video_summarization_metric.py b/modelscope/metrics/video_summarization_metric.py index d1867600..40580382 100644 --- a/modelscope/metrics/video_summarization_metric.py +++ b/modelscope/metrics/video_summarization_metric.py @@ -1,3 +1,6 @@ +# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM + from typing import Dict import numpy as np diff --git a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py index 11accf0a..aebc6751 100644 --- a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py +++ b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ import os from typing import Any, Dict diff --git a/modelscope/models/audio/kws/farfield/model.py b/modelscope/models/audio/kws/farfield/model.py index fea82194..d63d1e2a 100644 --- a/modelscope/models/audio/kws/farfield/model.py +++ b/modelscope/models/audio/kws/farfield/model.py @@ -1,15 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os -from typing import Dict - -import torch +from typing import Dict, Optional from modelscope.metainfo import Models from modelscope.models import TorchModel from modelscope.models.base import Tensor from modelscope.models.builder import MODELS -from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.audio.audio_utils import update_conf +from modelscope.utils.constant import Tasks from .fsmn_sele_v2 import FSMNSeleNetV2 @@ -20,48 +19,38 @@ class FSMNSeleNetV2Decorator(TorchModel): MODEL_TXT = 'model.txt' SC_CONFIG = 'sound_connect.conf' - SC_CONF_ITEM_KWS_MODEL = '${kws_model}' - def __init__(self, model_dir: str, *args, **kwargs): + def __init__(self, + model_dir: str, + training: Optional[bool] = False, + *args, + **kwargs): """initialize the dfsmn model from the `model_dir` path. Args: model_dir (str): the model path. """ super().__init__(model_dir, *args, **kwargs) - sc_config_file = os.path.join(model_dir, self.SC_CONFIG) - model_txt_file = os.path.join(model_dir, self.MODEL_TXT) - model_bin_file = os.path.join(model_dir, - ModelFile.TORCH_MODEL_BIN_FILE) - self._model = None - if os.path.exists(model_bin_file): - kwargs.pop('device') - self._model = FSMNSeleNetV2(*args, **kwargs) - checkpoint = torch.load(model_bin_file) - self._model.load_state_dict(checkpoint, strict=False) - - self._sc = None - if os.path.exists(model_txt_file): - with open(sc_config_file) as f: - lines = f.readlines() - with open(sc_config_file, 'w') as f: - for line in lines: - if self.SC_CONF_ITEM_KWS_MODEL in line: - line = line.replace(self.SC_CONF_ITEM_KWS_MODEL, - model_txt_file) - f.write(line) - import py_sound_connect - self._sc = py_sound_connect.SoundConnect(sc_config_file) - self.size_in = self._sc.bytesPerBlockIn() - self.size_out = self._sc.bytesPerBlockOut() - - if self._model is None and self._sc is None: - raise Exception( - f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.' - ) + if training: + self.model = FSMNSeleNetV2(*args, **kwargs) + else: + sc_config_file = os.path.join(model_dir, self.SC_CONFIG) + model_txt_file = os.path.join(model_dir, self.MODEL_TXT) + self._sc = None + if os.path.exists(model_txt_file): + conf_dict = dict(mode=56542, kws_model=model_txt_file) + update_conf(sc_config_file, sc_config_file, conf_dict) + import py_sound_connect + self._sc = py_sound_connect.SoundConnect(sc_config_file) + self.size_in = self._sc.bytesPerBlockIn() + self.size_out = self._sc.bytesPerBlockOut() + else: + raise Exception( + f'Invalid model directory! Failed to load model file: {model_txt_file}.' + ) def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: - ... + return self.model.forward(input) def forward_decode(self, data: bytes): result = {'pcm': self._sc.process(data, self.size_out)} diff --git a/modelscope/models/audio/kws/generic_key_word_spotting.py b/modelscope/models/audio/kws/generic_key_word_spotting.py index c1b7a0e4..2f70327d 100644 --- a/modelscope/models/audio/kws/generic_key_word_spotting.py +++ b/modelscope/models/audio/kws/generic_key_word_spotting.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ import os from typing import Any, Dict diff --git a/modelscope/models/audio/tts/models/datasets/__init__.py b/modelscope/models/audio/tts/models/datasets/__init__.py old mode 100755 new mode 100644 diff --git a/modelscope/models/audio/tts/voice.py b/modelscope/models/audio/tts/voice.py index dc830db5..b7240088 100644 --- a/modelscope/models/audio/tts/voice.py +++ b/modelscope/models/audio/tts/voice.py @@ -2,6 +2,7 @@ import os import pickle as pkl +from threading import Lock import json import numpy as np @@ -27,6 +28,7 @@ class Voice: self.__am_config = AttrDict(**am_config) self.__voc_config = AttrDict(**voc_config) self.__model_loaded = False + self.__lock = Lock() if 'am' not in self.__am_config: raise TtsModelConfigurationException( 'modelscope error: am configuration invalid') @@ -71,34 +73,35 @@ class Voice: self.__generator.remove_weight_norm() def __am_forward(self, symbol_seq): - with torch.no_grad(): - inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( - symbol_seq) - inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( - self.__device) - inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( - self.__device) - inputs_syllable = torch.from_numpy(inputs_feat_lst[2]).long().to( - self.__device) - inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( - self.__device) - inputs_ling = torch.stack( - [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], - dim=-1).unsqueeze(0) - inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( - self.__device).unsqueeze(0) - inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( - self.__device).unsqueeze(0) - inputs_len = torch.zeros(1).to(self.__device).long( - ) + inputs_emo.size(1) - 1 # minus 1 for "~" - res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], - inputs_spk[:, :-1], inputs_len) - postnet_outputs = res['postnet_outputs'] - LR_length_rounded = res['LR_length_rounded'] - valid_length = int(LR_length_rounded[0].item()) - postnet_outputs = postnet_outputs[ - 0, :valid_length, :].cpu().numpy() - return postnet_outputs + with self.__lock: + with torch.no_grad(): + inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( + symbol_seq) + inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( + self.__device) + inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( + self.__device) + inputs_syllable = torch.from_numpy( + inputs_feat_lst[2]).long().to(self.__device) + inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( + self.__device) + inputs_ling = torch.stack( + [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], + dim=-1).unsqueeze(0) + inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( + self.__device).unsqueeze(0) + inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( + self.__device).unsqueeze(0) + inputs_len = torch.zeros(1).to(self.__device).long( + ) + inputs_emo.size(1) - 1 # minus 1 for "~" + res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], + inputs_spk[:, :-1], inputs_len) + postnet_outputs = res['postnet_outputs'] + LR_length_rounded = res['LR_length_rounded'] + valid_length = int(LR_length_rounded[0].item()) + postnet_outputs = postnet_outputs[ + 0, :valid_length, :].cpu().numpy() + return postnet_outputs def __vocoder_forward(self, melspec): dim0 = list(melspec.shape)[-1] @@ -118,14 +121,15 @@ class Voice: return audio def forward(self, symbol_seq): - if not self.__model_loaded: - torch.manual_seed(self.__am_config.seed) - if torch.cuda.is_available(): + with self.__lock: + if not self.__model_loaded: 
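+                # Lazy initialisation guarded by self.__lock: the acoustic model and
+                # vocoder are loaded once, on the first synthesis request.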
torch.manual_seed(self.__am_config.seed) - self.__device = torch.device('cuda') - else: - self.__device = torch.device('cpu') - self.__load_am() - self.__load_vocoder() - self.__model_loaded = True + if torch.cuda.is_available(): + torch.manual_seed(self.__am_config.seed) + self.__device = torch.device('cuda') + else: + self.__device = torch.device('cpu') + self.__load_am() + self.__load_vocoder() + self.__model_loaded = True return self.__vocoder_forward(self.__am_forward(symbol_seq)) diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index cdc71fcf..1246551e 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -5,11 +5,11 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Union from modelscope.hub.snapshot_download import snapshot_download -from modelscope.models.builder import build_model -from modelscope.utils.checkpoint import save_pretrained +from modelscope.models.builder import MODELS, build_model +from modelscope.utils.checkpoint import save_checkpoint, save_pretrained from modelscope.utils.config import Config -from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile -from modelscope.utils.device import device_placement, verify_device +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile, Tasks +from modelscope.utils.device import verify_device from modelscope.utils.logger import get_logger logger = get_logger() @@ -66,7 +66,6 @@ class Model(ABC): revision: Optional[str] = DEFAULT_MODEL_REVISION, cfg_dict: Config = None, device: str = None, - *model_args, **kwargs): """ Instantiate a model from local directory or remote model repo. Note that when loading from remote, the model revision can be specified. @@ -90,11 +89,11 @@ class Model(ABC): cfg = Config.from_file( osp.join(local_model_dir, ModelFile.CONFIGURATION)) task_name = cfg.task + if 'task' in kwargs: + task_name = kwargs.pop('task') model_cfg = cfg.model - if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): model_cfg.type = model_cfg.model_type - model_cfg.model_dir = local_model_dir for k, v in kwargs.items(): model_cfg[k] = v @@ -109,15 +108,19 @@ class Model(ABC): # dynamically add pipeline info to model for pipeline inference if hasattr(cfg, 'pipeline'): model.pipeline = cfg.pipeline + + if not hasattr(model, 'cfg'): + model.cfg = cfg return model def save_pretrained(self, target_folder: Union[str, os.PathLike], save_checkpoint_names: Union[str, List[str]] = None, - save_function: Callable = None, + save_function: Callable = save_checkpoint, config: Optional[dict] = None, **kwargs): - """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded + """save the pretrained model, its configuration and other related files to a directory, + so that it can be re-loaded Args: target_folder (Union[str, os.PathLike]): @@ -133,5 +136,10 @@ class Model(ABC): The config for the configuration.json, might not be identical with model.config """ + if config is None and hasattr(self, 'cfg'): + config = self.cfg + assert config is not None, 'Cannot save the model because the model config is empty.' 
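+        # A Config instance is converted to a plain dict here so that
+        # save_pretrained can write it out as the model configuration.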
+ if isinstance(config, Config): + config = config.to_dict() save_pretrained(self, target_folder, save_checkpoint_names, save_function, config, **kwargs) diff --git a/modelscope/models/builder.py b/modelscope/models/builder.py index 7a8e28f4..2804c6c7 100644 --- a/modelscope/models/builder.py +++ b/modelscope/models/builder.py @@ -1,12 +1,20 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from modelscope.utils.config import ConfigDict +from modelscope.utils.constant import Tasks +from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg MODELS = Registry('models') -BACKBONES = Registry('backbones') +BACKBONES = MODELS HEADS = Registry('heads') +modules = LazyImportModule.AST_INDEX[INDEX_KEY] +for module_index in list(modules.keys()): + if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES': + modules[(MODELS.name.upper(), module_index[1], + module_index[2])] = modules[module_index] + def build_model(cfg: ConfigDict, task_name: str = None, @@ -23,30 +31,27 @@ def build_model(cfg: ConfigDict, cfg, MODELS, group_key=task_name, default_args=default_args) -def build_backbone(cfg: ConfigDict, - field: str = None, - default_args: dict = None): +def build_backbone(cfg: ConfigDict, default_args: dict = None): """ build backbone given backbone config dict Args: cfg (:obj:`ConfigDict`): config dict for backbone object. - field (str, optional): field, such as CV, NLP's backbone default_args (dict, optional): Default initialization arguments. """ return build_from_cfg( - cfg, BACKBONES, group_key=field, default_args=default_args) + cfg, BACKBONES, group_key=Tasks.backbone, default_args=default_args) def build_head(cfg: ConfigDict, - group_key: str = None, + task_name: str = None, default_args: dict = None): """ build head given config dict Args: cfg (:obj:`ConfigDict`): config dict for head object. + task_name (str, optional): task name, refer to + :obj:`Tasks` for more details default_args (dict, optional): Default initialization arguments. """ - if group_key is None: - group_key = cfg[TYPE_NAME] return build_from_cfg( - cfg, HEADS, group_key=group_key, default_args=default_args) + cfg, HEADS, group_key=task_name, default_args=default_args) diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index f2798b59..64039863 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -4,14 +4,16 @@ from . 
import (action_recognition, animal_recognition, body_2d_keypoints, body_3d_keypoints, cartoon, cmdssl_video_embedding, crowd_counting, face_2d_keypoints, face_detection, - face_generation, image_classification, image_color_enhance, - image_colorization, image_denoise, image_instance_segmentation, + face_generation, human_wholebody_keypoint, image_classification, + image_color_enhance, image_colorization, image_denoise, + image_inpainting, image_instance_segmentation, image_panoptic_segmentation, image_portrait_enhancement, image_reid_person, image_semantic_segmentation, image_to_image_generation, image_to_image_translation, movie_scene_segmentation, object_detection, product_retrieval_embedding, realtime_object_detection, - salient_detection, shop_segmentation, super_resolution, + referring_video_object_segmentation, salient_detection, + shop_segmentation, super_resolution, video_single_object_tracking, video_summarization, virual_tryon) # yapf: enable diff --git a/modelscope/models/cv/action_detection/action_detection_onnx.py b/modelscope/models/cv/action_detection/action_detection_onnx.py index 1c8be354..223d77f7 100644 --- a/modelscope/models/cv/action_detection/action_detection_onnx.py +++ b/modelscope/models/cv/action_detection/action_detection_onnx.py @@ -4,6 +4,7 @@ import os import os.path as osp import shutil import subprocess +import uuid import cv2 import numpy as np @@ -84,7 +85,9 @@ class ActionDetONNX(Model): def forward_video(self, video_name, scale): min_size, max_size = self._get_sizes(scale) - tmp_dir = osp.join(self.tmp_dir, osp.basename(video_name)[:-4]) + tmp_dir = osp.join( + self.tmp_dir, + str(uuid.uuid1()) + '_' + osp.basename(video_name)[:-4]) if osp.exists(tmp_dir): shutil.rmtree(tmp_dir) os.makedirs(tmp_dir) @@ -110,6 +113,7 @@ class ActionDetONNX(Model): len(frame_names) * self.temporal_stride, self.temporal_stride)) batch_imgs = [self.parse_frames(names) for names in frame_names] + shutil.rmtree(tmp_dir) N, _, T, H, W = batch_imgs[0].shape scale_min = min_size / min(H, W) @@ -128,7 +132,6 @@ class ActionDetONNX(Model): 'timestamp': t, 'actions': res } for t, res in zip(timestamp, results)] - shutil.rmtree(tmp_dir) return results def forward(self, video_name): diff --git a/modelscope/models/cv/crowd_counting/cc_model.py b/modelscope/models/cv/crowd_counting/cc_model.py index 582b26f4..16fbc261 100644 --- a/modelscope/models/cv/crowd_counting/cc_model.py +++ b/modelscope/models/cv/crowd_counting/cc_model.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict, Optional, Union diff --git a/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py b/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py index 982ba939..0d1bd3ca 100644 --- a/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py +++ b/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py @@ -1,10 +1,10 @@ -# ------------------------------------------------------------------------------ -# Copyright (c) Microsoft -# Licensed under the MIT License. -# Written by Bin Xiao (Bin.Xiao@microsoft.com) -# Modified by Ke Sun (sunk@mail.ustc.edu.cn) -# https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py -# ------------------------------------------------------------------------------ +""" +Copyright (c) Microsoft +Licensed under the MIT License. 
+Written by Bin Xiao (Bin.Xiao@microsoft.com) +Modified by Ke Sun (sunk@mail.ustc.edu.cn) +https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py +""" import functools import logging diff --git a/modelscope/models/cv/face_detection/__init__.py b/modelscope/models/cv/face_detection/__init__.py index a2a845d2..27d1bd4c 100644 --- a/modelscope/models/cv/face_detection/__init__.py +++ b/modelscope/models/cv/face_detection/__init__.py @@ -8,12 +8,14 @@ if TYPE_CHECKING: from .mtcnn import MtcnnFaceDetector from .retinaface import RetinaFaceDetection from .ulfd_slim import UlfdFaceDetector + from .scrfd import ScrfdDetect else: _import_structure = { 'ulfd_slim': ['UlfdFaceDetector'], 'retinaface': ['RetinaFaceDetection'], 'mtcnn': ['MtcnnFaceDetector'], - 'mogface': ['MogFaceDetector'] + 'mogface': ['MogFaceDetector'], + 'scrfd': ['ScrfdDetect'] } import sys diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py b/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py deleted file mode 100755 index 241f2c0e..00000000 --- a/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py +++ /dev/null @@ -1,189 +0,0 @@ -""" -The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at -https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py -""" -import numpy as np -from mmdet.datasets.builder import PIPELINES -from numpy import random - - -@PIPELINES.register_module() -class RandomSquareCrop(object): - """Random crop the image & bboxes, the cropped patches have minimum IoU - requirement with original image & bboxes, the IoU threshold is randomly - selected from min_ious. - - Args: - min_ious (tuple): minimum IoU threshold for all intersections with - bounding boxes - min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, - where a >= min_crop_size). - - Note: - The keys for bboxes, labels and masks should be paired. That is, \ - `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ - `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. - """ - - def __init__(self, - crop_ratio_range=None, - crop_choice=None, - bbox_clip_border=True): - - self.crop_ratio_range = crop_ratio_range - self.crop_choice = crop_choice - self.bbox_clip_border = bbox_clip_border - - assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) - if self.crop_ratio_range is not None: - self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range - - self.bbox2label = { - 'gt_bboxes': 'gt_labels', - 'gt_bboxes_ignore': 'gt_labels_ignore' - } - self.bbox2mask = { - 'gt_bboxes': 'gt_masks', - 'gt_bboxes_ignore': 'gt_masks_ignore' - } - - def __call__(self, results): - """Call function to crop images and bounding boxes with minimum IoU - constraint. - - Args: - results (dict): Result dict from loading pipeline. - - Returns: - dict: Result dict with images and bounding boxes cropped, \ - 'img_shape' key is updated. 
- """ - - if 'img_fields' in results: - assert results['img_fields'] == ['img'], \ - 'Only single img_fields is allowed' - img = results['img'] - assert 'bbox_fields' in results - assert 'gt_bboxes' in results - boxes = results['gt_bboxes'] - h, w, c = img.shape - scale_retry = 0 - if self.crop_ratio_range is not None: - max_scale = self.crop_ratio_max - else: - max_scale = np.amax(self.crop_choice) - while True: - scale_retry += 1 - - if scale_retry == 1 or max_scale > 1.0: - if self.crop_ratio_range is not None: - scale = np.random.uniform(self.crop_ratio_min, - self.crop_ratio_max) - elif self.crop_choice is not None: - scale = np.random.choice(self.crop_choice) - else: - scale = scale * 1.2 - - for i in range(250): - short_side = min(w, h) - cw = int(scale * short_side) - ch = cw - - # TODO +1 - if w == cw: - left = 0 - elif w > cw: - left = random.randint(0, w - cw) - else: - left = random.randint(w - cw, 0) - if h == ch: - top = 0 - elif h > ch: - top = random.randint(0, h - ch) - else: - top = random.randint(h - ch, 0) - - patch = np.array( - (int(left), int(top), int(left + cw), int(top + ch)), - dtype=np.int) - - # center of boxes should inside the crop img - # only adjust boxes and instance masks when the gt is not empty - # adjust boxes - def is_center_of_bboxes_in_patch(boxes, patch): - # TODO >= - center = (boxes[:, :2] + boxes[:, 2:]) / 2 - mask = \ - ((center[:, 0] > patch[0]) - * (center[:, 1] > patch[1]) - * (center[:, 0] < patch[2]) - * (center[:, 1] < patch[3])) - return mask - - mask = is_center_of_bboxes_in_patch(boxes, patch) - if not mask.any(): - continue - for key in results.get('bbox_fields', []): - boxes = results[key].copy() - mask = is_center_of_bboxes_in_patch(boxes, patch) - boxes = boxes[mask] - if self.bbox_clip_border: - boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) - boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) - boxes -= np.tile(patch[:2], 2) - - results[key] = boxes - # labels - label_key = self.bbox2label.get(key) - if label_key in results: - results[label_key] = results[label_key][mask] - - # keypoints field - if key == 'gt_bboxes': - for kps_key in results.get('keypoints_fields', []): - keypointss = results[kps_key].copy() - keypointss = keypointss[mask, :, :] - if self.bbox_clip_border: - keypointss[:, :, : - 2] = keypointss[:, :, :2].clip( - max=patch[2:]) - keypointss[:, :, : - 2] = keypointss[:, :, :2].clip( - min=patch[:2]) - keypointss[:, :, 0] -= patch[0] - keypointss[:, :, 1] -= patch[1] - results[kps_key] = keypointss - - # mask fields - mask_key = self.bbox2mask.get(key) - if mask_key in results: - results[mask_key] = results[mask_key][mask.nonzero() - [0]].crop(patch) - - # adjust the img no matter whether the gt is empty before crop - rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 - patch_from = patch.copy() - patch_from[0] = max(0, patch_from[0]) - patch_from[1] = max(0, patch_from[1]) - patch_from[2] = min(img.shape[1], patch_from[2]) - patch_from[3] = min(img.shape[0], patch_from[3]) - patch_to = patch.copy() - patch_to[0] = max(0, patch_to[0] * -1) - patch_to[1] = max(0, patch_to[1] * -1) - patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) - patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) - rimg[patch_to[1]:patch_to[3], - patch_to[0]:patch_to[2], :] = img[ - patch_from[1]:patch_from[3], - patch_from[0]:patch_from[2], :] - img = rimg - results['img'] = img - results['img_shape'] = img.shape - - return results - - def __repr__(self): - repr_str = self.__class__.__name__ - repr_str += 
f'(min_ious={self.min_iou}, ' - repr_str += f'crop_size={self.crop_size})' - return repr_str diff --git a/modelscope/models/cv/face_detection/mogface/models/detectors.py b/modelscope/models/cv/face_detection/mogface/models/detectors.py index 5ae67104..8c1d9150 100644 --- a/modelscope/models/cv/face_detection/mogface/models/detectors.py +++ b/modelscope/models/cv/face_detection/mogface/models/detectors.py @@ -1,3 +1,5 @@ +# The implementation is based on MogFace, available at +# https://github.com/damo-cv/MogFace import os import cv2 diff --git a/modelscope/models/cv/face_detection/scrfd/__init__.py b/modelscope/models/cv/face_detection/scrfd/__init__.py new file mode 100644 index 00000000..92f81f7a --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .scrfd_detect import ScrfdDetect diff --git a/modelscope/models/cv/face_detection/mmdet_patch/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/core/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/core/bbox/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/transforms.py similarity index 94% rename from modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/transforms.py index d65480eb..75e32d85 100755 --- a/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/transforms.py @@ -6,7 +6,7 @@ import numpy as np import torch -def bbox2result(bboxes, labels, num_classes, kps=None): +def bbox2result(bboxes, labels, num_classes, kps=None, num_kps=5): """Convert detection results to a list of numpy arrays. 
Args: @@ -17,7 +17,7 @@ def bbox2result(bboxes, labels, num_classes, kps=None): Returns: list(ndarray): bbox results of each class """ - bbox_len = 5 if kps is None else 5 + 10 # if has kps, add 10 kps into bbox + bbox_len = 5 if kps is None else 5 + num_kps * 2 # if has kps, add num_kps*2 into bbox if bboxes.shape[0] == 0: return [ np.zeros((0, bbox_len), dtype=np.float32) diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py similarity index 89% rename from modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py index 7a4f5b3a..697b7338 100644 --- a/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py @@ -17,6 +17,7 @@ def multiclass_nms(multi_bboxes, Args: multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_kps (Tensor): shape (n, #class*num_kps*2) or (n, num_kps*2) multi_scores (Tensor): shape (n, #class), where the last column contains scores of the background class, but this will be ignored. score_thr (float): bbox threshold, bboxes with scores lower than it @@ -36,16 +37,18 @@ def multiclass_nms(multi_bboxes, num_classes = multi_scores.size(1) - 1 # exclude background category kps = None + if multi_kps is not None: + num_kps = int((multi_kps.shape[1] / num_classes) / 2) if multi_bboxes.shape[1] > 4: bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) if multi_kps is not None: - kps = multi_kps.view(multi_scores.size(0), -1, 10) + kps = multi_kps.view(multi_scores.size(0), -1, num_kps * 2) else: bboxes = multi_bboxes[:, None].expand( multi_scores.size(0), num_classes, 4) if multi_kps is not None: kps = multi_kps[:, None].expand( - multi_scores.size(0), num_classes, 10) + multi_scores.size(0), num_classes, num_kps * 2) scores = multi_scores[:, :-1] if score_factors is not None: @@ -56,7 +59,7 @@ def multiclass_nms(multi_bboxes, bboxes = bboxes.reshape(-1, 4) if kps is not None: - kps = kps.reshape(-1, 10) + kps = kps.reshape(-1, num_kps * 2) scores = scores.reshape(-1) labels = labels.reshape(-1) diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/datasets/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/__init__.py similarity index 53% rename from modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/__init__.py index 85288910..a2cafd1a 100755 --- 
a/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/__init__.py @@ -2,6 +2,12 @@ The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines """ +from .auto_augment import RotateV2 +from .formating import DefaultFormatBundleV2 +from .loading import LoadAnnotationsV2 from .transforms import RandomSquareCrop -__all__ = ['RandomSquareCrop'] +__all__ = [ + 'RandomSquareCrop', 'LoadAnnotationsV2', 'RotateV2', + 'DefaultFormatBundleV2' +] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/auto_augment.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/auto_augment.py new file mode 100644 index 00000000..ee60c2e0 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/auto_augment.py @@ -0,0 +1,271 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/auto_augment.py +""" +import copy + +import cv2 +import mmcv +import numpy as np +from mmdet.datasets.builder import PIPELINES + +_MAX_LEVEL = 10 + + +def level_to_value(level, max_value): + """Map from level to values based on max_value.""" + return (level / _MAX_LEVEL) * max_value + + +def random_negative(value, random_negative_prob): + """Randomly negate value based on random_negative_prob.""" + return -value if np.random.rand() < random_negative_prob else value + + +def bbox2fields(): + """The key correspondence from bboxes to labels, masks and + segmentations.""" + bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + bbox2mask = { + 'gt_bboxes': 'gt_masks', + 'gt_bboxes_ignore': 'gt_masks_ignore' + } + bbox2seg = { + 'gt_bboxes': 'gt_semantic_seg', + } + return bbox2label, bbox2mask, bbox2seg + + +@PIPELINES.register_module() +class RotateV2(object): + """Apply Rotate Transformation to image (and its corresponding bbox, mask, + segmentation). + + Args: + level (int | float): The level should be in range (0,_MAX_LEVEL]. + scale (int | float): Isotropic scale factor. Same in + ``mmcv.imrotate``. + center (int | float | tuple[float]): Center point (w, h) of the + rotation in the source image. If None, the center of the + image will be used. Same in ``mmcv.imrotate``. + img_fill_val (int | float | tuple): The fill value for image border. + If float, the same value will be used for all the three + channels of image. If tuple, the should be 3 elements (e.g. + equals the number of channels for image). + seg_ignore_label (int): The fill value used for segmentation map. + Note this value must equals ``ignore_label`` in ``semantic_head`` + of the corresponding config. Default 255. + prob (float): The probability for perform transformation and + should be in range 0 to 1. + max_rotate_angle (int | float): The maximum angles for rotate + transformation. + random_negative_prob (float): The probability that turns the + offset negative. + """ + + def __init__(self, + level, + scale=1, + center=None, + img_fill_val=128, + seg_ignore_label=255, + prob=0.5, + max_rotate_angle=30, + random_negative_prob=0.5): + assert isinstance(level, (int, float)), \ + f'The level must be type int or float. got {type(level)}.' 
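+        # The level is later converted to an angle: angle = level / _MAX_LEVEL * max_rotate_angle.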
+ assert 0 <= level <= _MAX_LEVEL, \ + f'The level should be in range (0,{_MAX_LEVEL}]. got {level}.' + assert isinstance(scale, (int, float)), \ + f'The scale must be type int or float. got type {type(scale)}.' + if isinstance(center, (int, float)): + center = (center, center) + elif isinstance(center, tuple): + assert len(center) == 2, 'center with type tuple must have '\ + f'2 elements. got {len(center)} elements.' + else: + assert center is None, 'center must be None or type int, '\ + f'float or tuple, got type {type(center)}.' + if isinstance(img_fill_val, (float, int)): + img_fill_val = tuple([float(img_fill_val)] * 3) + elif isinstance(img_fill_val, tuple): + assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\ + f'have 3 elements. got {len(img_fill_val)}.' + img_fill_val = tuple([float(val) for val in img_fill_val]) + else: + raise ValueError( + 'img_fill_val must be float or tuple with 3 elements.') + assert np.all([0 <= val <= 255 for val in img_fill_val]), \ + 'all elements of img_fill_val should between range [0,255]. '\ + f'got {img_fill_val}.' + assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\ + f'got {prob}.' + assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\ + f'should be type int or float. got type {type(max_rotate_angle)}.' + self.level = level + self.scale = scale + # Rotation angle in degrees. Positive values mean + # clockwise rotation. + self.angle = level_to_value(level, max_rotate_angle) + self.center = center + self.img_fill_val = img_fill_val + self.seg_ignore_label = seg_ignore_label + self.prob = prob + self.max_rotate_angle = max_rotate_angle + self.random_negative_prob = random_negative_prob + + def _rotate_img(self, results, angle, center=None, scale=1.0): + """Rotate the image. + + Args: + results (dict): Result dict from loading pipeline. + angle (float): Rotation angle in degrees, positive values + mean clockwise rotation. Same in ``mmcv.imrotate``. + center (tuple[float], optional): Center point (w, h) of the + rotation. Same in ``mmcv.imrotate``. + scale (int | float): Isotropic scale factor. Same in + ``mmcv.imrotate``. 
+ """ + for key in results.get('img_fields', ['img']): + img = results[key].copy() + img_rotated = mmcv.imrotate( + img, angle, center, scale, border_value=self.img_fill_val) + results[key] = img_rotated.astype(img.dtype) + results['img_shape'] = results[key].shape + + def _rotate_bboxes(self, results, rotate_matrix): + """Rotate the bboxes.""" + h, w, c = results['img_shape'] + for key in results.get('bbox_fields', []): + min_x, min_y, max_x, max_y = np.split( + results[key], results[key].shape[-1], axis=-1) + coordinates = np.stack([[min_x, min_y], [max_x, min_y], + [min_x, max_y], + [max_x, max_y]]) # [4, 2, nb_bbox, 1] + # pad 1 to convert from format [x, y] to homogeneous + # coordinates format [x, y, 1] + coordinates = np.concatenate( + (coordinates, + np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)), + axis=1) # [4, 3, nb_bbox, 1] + coordinates = coordinates.transpose( + (2, 0, 1, 3)) # [nb_bbox, 4, 3, 1] + rotated_coords = np.matmul(rotate_matrix, + coordinates) # [nb_bbox, 4, 2, 1] + rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2] + min_x, min_y = np.min( + rotated_coords[:, :, 0], axis=1), np.min( + rotated_coords[:, :, 1], axis=1) + max_x, max_y = np.max( + rotated_coords[:, :, 0], axis=1), np.max( + rotated_coords[:, :, 1], axis=1) + results[key] = np.stack([min_x, min_y, max_x, max_y], + axis=-1).astype(results[key].dtype) + + def _rotate_keypoints90(self, results, angle): + """Rotate the keypoints, only valid when angle in [-90,90,-180,180]""" + if angle not in [-90, 90, 180, -180 + ] or self.scale != 1 or self.center is not None: + return + for key in results.get('keypoints_fields', []): + k = results[key] + if angle == 90: + w, h, c = results['img'].shape + new = np.stack([h - k[..., 1], k[..., 0], k[..., 2]], axis=-1) + elif angle == -90: + w, h, c = results['img'].shape + new = np.stack([k[..., 1], w - k[..., 0], k[..., 2]], axis=-1) + else: + h, w, c = results['img'].shape + new = np.stack([w - k[..., 0], h - k[..., 1], k[..., 2]], + axis=-1) + # a kps is invalid if thrid value is -1 + kps_invalid = new[..., -1][:, -1] == -1 + new[kps_invalid] = np.zeros(new.shape[1:]) - 1 + results[key] = new + + def _rotate_masks(self, + results, + angle, + center=None, + scale=1.0, + fill_val=0): + """Rotate the masks.""" + h, w, c = results['img_shape'] + for key in results.get('mask_fields', []): + masks = results[key] + results[key] = masks.rotate((h, w), angle, center, scale, fill_val) + + def _rotate_seg(self, + results, + angle, + center=None, + scale=1.0, + fill_val=255): + """Rotate the segmentation map.""" + for key in results.get('seg_fields', []): + seg = results[key].copy() + results[key] = mmcv.imrotate( + seg, angle, center, scale, + border_value=fill_val).astype(seg.dtype) + + def _filter_invalid(self, results, min_bbox_size=0): + """Filter bboxes and corresponding masks too small after rotate + augmentation.""" + bbox2label, bbox2mask, _ = bbox2fields() + for key in results.get('bbox_fields', []): + bbox_w = results[key][:, 2] - results[key][:, 0] + bbox_h = results[key][:, 3] - results[key][:, 1] + valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size) + valid_inds = np.nonzero(valid_inds)[0] + results[key] = results[key][valid_inds] + # label fields. e.g. gt_labels and gt_labels_ignore + label_key = bbox2label.get(key) + if label_key in results: + results[label_key] = results[label_key][valid_inds] + # mask fields, e.g. 
gt_masks and gt_masks_ignore + mask_key = bbox2mask.get(key) + if mask_key in results: + results[mask_key] = results[mask_key][valid_inds] + + def __call__(self, results): + """Call function to rotate images, bounding boxes, masks and semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + if np.random.rand() > self.prob: + return results + h, w = results['img'].shape[:2] + center = self.center + if center is None: + center = ((w - 1) * 0.5, (h - 1) * 0.5) + angle = random_negative(self.angle, self.random_negative_prob) + self._rotate_img(results, angle, center, self.scale) + rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale) + self._rotate_bboxes(results, rotate_matrix) + self._rotate_keypoints90(results, angle) + self._rotate_masks(results, angle, center, self.scale, fill_val=0) + self._rotate_seg( + results, angle, center, self.scale, fill_val=self.seg_ignore_label) + self._filter_invalid(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(level={self.level}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'center={self.center}, ' + repr_str += f'img_fill_val={self.img_fill_val}, ' + repr_str += f'seg_ignore_label={self.seg_ignore_label}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'max_rotate_angle={self.max_rotate_angle}, ' + repr_str += f'random_negative_prob={self.random_negative_prob})' + return repr_str diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/formating.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/formating.py new file mode 100644 index 00000000..bd2394a8 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/formating.py @@ -0,0 +1,113 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/formating.py +""" +import numpy as np +import torch +from mmcv.parallel import DataContainer as DC +from mmdet.datasets.builder import PIPELINES + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@PIPELINES.register_module() +class DefaultFormatBundleV2(object): + """Default formatting bundle. + + It simplifies the pipeline of formatting common fields, including "img", + "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg". + These fields are formatted as follows. 
+ + - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) + - proposals: (1)to tensor, (2)to DataContainer + - gt_bboxes: (1)to tensor, (2)to DataContainer + - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer + - gt_labels: (1)to tensor, (2)to DataContainer + - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True) + - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \ + (3)to DataContainer (stack=True) + """ + + def __call__(self, results): + """Call function to transform and format common fields in results. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with \ + default bundle. + """ + + if 'img' in results: + img = results['img'] + # add default meta keys + results = self._add_default_meta_keys(results) + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + results['img'] = DC(to_tensor(img), stack=True) + for key in [ + 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_keypointss', + 'gt_labels' + ]: + if key not in results: + continue + results[key] = DC(to_tensor(results[key])) + if 'gt_masks' in results: + results['gt_masks'] = DC(results['gt_masks'], cpu_only=True) + if 'gt_semantic_seg' in results: + results['gt_semantic_seg'] = DC( + to_tensor(results['gt_semantic_seg'][None, ...]), stack=True) + return results + + def _add_default_meta_keys(self, results): + """Add default meta keys. + + We set default meta keys including `pad_shape`, `scale_factor` and + `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and + `Pad` are implemented during the whole pipeline. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + results (dict): Updated result dict contains the data to convert. + """ + img = results['img'] + results.setdefault('pad_shape', img.shape) + results.setdefault('scale_factor', 1.0) + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results.setdefault( + 'img_norm_cfg', + dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False)) + return results + + def __repr__(self): + return self.__class__.__name__ diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/loading.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/loading.py new file mode 100644 index 00000000..b4c2a385 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/loading.py @@ -0,0 +1,225 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/loading.py +""" +import os.path as osp + +import numpy as np +import pycocotools.mask as maskUtils +from mmdet.core import BitmapMasks, PolygonMasks +from mmdet.datasets.builder import PIPELINES + + +@PIPELINES.register_module() +class LoadAnnotationsV2(object): + """Load mutiple types of annotations. + + Args: + with_bbox (bool): Whether to parse and load the bbox annotation. + Default: True. + with_label (bool): Whether to parse and load the label annotation. + Default: True. + with_keypoints (bool): Whether to parse and load the keypoints annotation. + Default: False. + with_mask (bool): Whether to parse and load the mask annotation. + Default: False. + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. 
Default: False. + poly2mask (bool): Whether to convert the instance masks from polygons + to bitmaps. Default: True. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + def __init__(self, + with_bbox=True, + with_label=True, + with_keypoints=False, + with_mask=False, + with_seg=False, + poly2mask=True, + file_client_args=dict(backend='disk')): + self.with_bbox = with_bbox + self.with_label = with_label + self.with_keypoints = with_keypoints + self.with_mask = with_mask + self.with_seg = with_seg + self.poly2mask = poly2mask + self.file_client_args = file_client_args.copy() + self.file_client = None + + def _load_bboxes(self, results): + """Private function to load bounding box annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box annotations. + """ + + ann_info = results['ann_info'] + results['gt_bboxes'] = ann_info['bboxes'].copy() + + gt_bboxes_ignore = ann_info.get('bboxes_ignore', None) + if gt_bboxes_ignore is not None: + results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy() + results['bbox_fields'].append('gt_bboxes_ignore') + results['bbox_fields'].append('gt_bboxes') + return results + + def _load_keypoints(self, results): + """Private function to load bounding box annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box annotations. + """ + + ann_info = results['ann_info'] + results['gt_keypointss'] = ann_info['keypointss'].copy() + + results['keypoints_fields'] = ['gt_keypointss'] + return results + + def _load_labels(self, results): + """Private function to load label annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded label annotations. + """ + + results['gt_labels'] = results['ann_info']['labels'].copy() + return results + + def _poly2mask(self, mask_ann, img_h, img_w): + """Private function to convert masks represented with polygon to + bitmaps. + + Args: + mask_ann (list | dict): Polygon mask annotation input. + img_h (int): The height of output mask. + img_w (int): The width of output mask. + + Returns: + numpy.ndarray: The decode bitmap mask of shape (img_h, img_w). + """ + + if isinstance(mask_ann, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) + rle = maskUtils.merge(rles) + elif isinstance(mask_ann['counts'], list): + # uncompressed RLE + rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) + else: + # rle + rle = mask_ann + mask = maskUtils.decode(rle) + return mask + + def process_polygons(self, polygons): + """Convert polygons to list of ndarray and filter invalid polygons. + + Args: + polygons (list[list]): Polygons of one instance. + + Returns: + list[numpy.ndarray]: Processed polygons. + """ + + polygons = [np.array(p) for p in polygons] + valid_polygons = [] + for polygon in polygons: + if len(polygon) % 2 == 0 and len(polygon) >= 6: + valid_polygons.append(polygon) + return valid_polygons + + def _load_masks(self, results): + """Private function to load mask annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded mask annotations. 
+ If ``self.poly2mask`` is set ``True``, `gt_mask` will contain + :obj:`PolygonMasks`. Otherwise, :obj:`BitmapMasks` is used. + """ + + h, w = results['img_info']['height'], results['img_info']['width'] + gt_masks = results['ann_info']['masks'] + if self.poly2mask: + gt_masks = BitmapMasks( + [self._poly2mask(mask, h, w) for mask in gt_masks], h, w) + else: + gt_masks = PolygonMasks( + [self.process_polygons(polygons) for polygons in gt_masks], h, + w) + results['gt_masks'] = gt_masks + results['mask_fields'].append('gt_masks') + return results + + def _load_semantic_seg(self, results): + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + import mmcv + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + filename = osp.join(results['seg_prefix'], + results['ann_info']['seg_map']) + img_bytes = self.file_client.get(filename) + results['gt_semantic_seg'] = mmcv.imfrombytes( + img_bytes, flag='unchanged').squeeze() + results['seg_fields'].append('gt_semantic_seg') + return results + + def __call__(self, results): + """Call function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box, label, mask and + semantic segmentation annotations. + """ + + if self.with_bbox: + results = self._load_bboxes(results) + if results is None: + return None + if self.with_label: + results = self._load_labels(results) + if self.with_keypoints: + results = self._load_keypoints(results) + if self.with_mask: + results = self._load_masks(results) + if self.with_seg: + results = self._load_semantic_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(with_bbox={self.with_bbox}, ' + repr_str += f'with_label={self.with_label}, ' + repr_str += f'with_keypoints={self.with_keypoints}, ' + repr_str += f'with_mask={self.with_mask}, ' + repr_str += f'with_seg={self.with_seg})' + repr_str += f'poly2mask={self.poly2mask})' + repr_str += f'poly2mask={self.file_client_args})' + return repr_str diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/transforms.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/transforms.py new file mode 100755 index 00000000..270c34da --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/transforms.py @@ -0,0 +1,737 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py +""" +import mmcv +import numpy as np +from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps +from mmdet.datasets.builder import PIPELINES +from numpy import random + + +@PIPELINES.register_module() +class ResizeV2(object): + """Resize images & bbox & mask &kps. + + This transform resizes the input image to some scale. Bboxes and masks are + then resized with the same scale factor. If the input dict contains the key + "scale", then the scale in the input dict is used, otherwise the specified + scale in the init method is used. 
If the input dict contains the key + "scale_factor" (if MultiScaleFlipAug does not give img_scale but + scale_factor), the actual scale will be computed by image shape and + scale_factor. + + `img_scale` can either be a tuple (single-scale) or a list of tuple + (multi-scale). There are 3 multiscale modes: + + - ``ratio_range is not None``: randomly sample a ratio from the ratio \ + range and multiply it with the image scale. + - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \ + sample a scale from the multiscale range. + - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \ + sample a scale from multiple scales. + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + multiscale_mode (str): Either "range" or "value". + ratio_range (tuple[float]): (min_ratio, max_ratio) + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + override (bool, optional): Whether to override `scale` and + `scale_factor` so as to call resize twice. Default False. If True, + after the first resizing, the existed `scale` and `scale_factor` + will be ignored so the second resizing can be allowed. + This option is a work-around for multiple times of resize in DETR. + Defaults to False. + """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True, + bbox_clip_border=True, + backend='cv2', + override=False): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given a scale and a range of image ratio + assert len(self.img_scale) == 1 + else: + # mode 2: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.backend = backend + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + # TODO: refactor the override option in Resize + self.override = override + self.bbox_clip_border = bbox_clip_border + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, \ + where ``img_scale`` is the selected image scale and \ + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and uper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where \ + ``img_scale`` is sampled scale and None is just a placeholder \ + to be consistent with :func:`random_select`. 
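+
+        Example (illustrative; the two bounding scales are assumed values,
+        not taken from any shipped config)::
+
+            # long edge is sampled in [1333, 1333], short edge in [640, 800]
+            scale, _ = ResizeV2.random_sample([(1333, 640), (1333, 800)])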
+ """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where \ + ``scale`` is sampled ratio multiplied with ``img_scale`` and \ + None is just a placeholder to be consistent with \ + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into \ + ``results``, which would be used by subsequent pipelines. 
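+
+        Example (illustrative sketch; the 640x640 base scale and ratio range
+        are assumed, not copied from a shipped config)::
+
+            resize = ResizeV2(img_scale=(640, 640), ratio_range=(0.5, 1.5))
+            results = {}
+            resize._random_scale(results)
+            # results['scale'] is now e.g. (512, 512), results['scale_idx'] is None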
+ """ + + if self.ratio_range is not None: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + for key in results.get('img_fields', ['img']): + if self.keep_ratio: + img, scale_factor = mmcv.imrescale( + results[key], + results['scale'], + return_scale=True, + backend=self.backend) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results[key].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + results[key], + results['scale'], + return_scale=True, + backend=self.backend) + results[key] = img + + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img_shape'] = img.shape + # in case that there is no padding + results['pad_shape'] = img.shape + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + + def _resize_bboxes(self, results): + """Resize bounding boxes with ``results['scale_factor']``.""" + for key in results.get('bbox_fields', []): + bboxes = results[key] * results['scale_factor'] + if self.bbox_clip_border: + img_shape = results['img_shape'] + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) + results[key] = bboxes + + def _resize_keypoints(self, results): + """Resize keypoints with ``results['scale_factor']``.""" + for key in results.get('keypoints_fields', []): + keypointss = results[key].copy() + factors = results['scale_factor'] + assert factors[0] == factors[2] + assert factors[1] == factors[3] + keypointss[:, :, 0] *= factors[0] + keypointss[:, :, 1] *= factors[1] + if self.bbox_clip_border: + img_shape = results['img_shape'] + keypointss[:, :, 0] = np.clip(keypointss[:, :, 0], 0, + img_shape[1]) + keypointss[:, :, 1] = np.clip(keypointss[:, :, 1], 0, + img_shape[0]) + results[key] = keypointss + + def _resize_masks(self, results): + """Resize masks with ``results['scale']``""" + for key in results.get('mask_fields', []): + if results[key] is None: + continue + if self.keep_ratio: + results[key] = results[key].rescale(results['scale']) + else: + results[key] = results[key].resize(results['img_shape'][:2]) + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + else: + gt_seg = mmcv.imresize( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + results['gt_semantic_seg'] = gt_seg + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \ + 'keep_ratio' keys are added into result dict. 
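+
+        Example (illustrative; the dummy image and empty field lists are
+        assumed inputs)::
+
+            import numpy as np
+            resize = ResizeV2(img_scale=(640, 640), keep_ratio=False)
+            results = dict(
+                img=np.zeros((480, 640, 3), dtype=np.uint8),
+                img_fields=['img'], bbox_fields=[], keypoints_fields=[],
+                mask_fields=[], seg_fields=[])
+            results = resize(results)  # adds 'img_shape', 'scale_factor', ...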
+ """ + + if 'scale' not in results: + if 'scale_factor' in results: + img_shape = results['img'].shape[:2] + scale_factor = results['scale_factor'] + assert isinstance(scale_factor, float) + results['scale'] = tuple( + [int(x * scale_factor) for x in img_shape][::-1]) + else: + self._random_scale(results) + else: + if not self.override: + assert 'scale_factor' not in results, ( + 'scale and scale_factor cannot be both set.') + else: + results.pop('scale') + if 'scale_factor' in results: + results.pop('scale_factor') + self._random_scale(results) + + self._resize_img(results) + self._resize_bboxes(results) + self._resize_keypoints(results) + self._resize_masks(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(img_scale={self.img_scale}, ' + repr_str += f'multiscale_mode={self.multiscale_mode}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'keep_ratio={self.keep_ratio})' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +@PIPELINES.register_module() +class RandomFlipV2(object): + """Flip the image & bbox & mask & kps. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + When random flip is enabled, ``flip_ratio``/``direction`` can either be a + float/string or tuple of float/string. There are 3 flip modes: + + - ``flip_ratio`` is float, ``direction`` is string: the image will be + ``direction``ly flipped with probability of ``flip_ratio`` . + E.g., ``flip_ratio=0.5``, ``direction='horizontal'``, + then image will be horizontally flipped with probability of 0.5. + - ``flip_ratio`` is float, ``direction`` is list of string: the image wil + be ``direction[i]``ly flipped with probability of + ``flip_ratio/len(direction)``. + E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``, + then image will be horizontally flipped with probability of 0.25, + vertically with probability of 0.25. + - ``flip_ratio`` is list of float, ``direction`` is list of string: + given ``len(flip_ratio) == len(direction)``, the image wil + be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``. + E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal', + 'vertical']``, then image will be horizontally flipped with probability + of 0.3, vertically with probability of 0.5 + + Args: + flip_ratio (float | list[float], optional): The flipping probability. + Default: None. + direction(str | list[str], optional): The flipping direction. Options + are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'. + If input is a list, the length must equal ``flip_ratio``. Each + element in ``flip_ratio`` indicates the flip probability of + corresponding direction. 
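+
+    Example (illustrative; the 0.5 probability is an assumed value)::
+
+        # horizontal flip with probability 0.5; when the image is flipped,
+        # keypoints_flip() also swaps the left/right landmark indices
+        flip = RandomFlipV2(flip_ratio=0.5)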
+ """ + + def __init__(self, flip_ratio=None, direction='horizontal'): + if isinstance(flip_ratio, list): + assert mmcv.is_list_of(flip_ratio, float) + assert 0 <= sum(flip_ratio) <= 1 + elif isinstance(flip_ratio, float): + assert 0 <= flip_ratio <= 1 + elif flip_ratio is None: + pass + else: + raise ValueError('flip_ratios must be None, float, ' + 'or list of float') + self.flip_ratio = flip_ratio + + valid_directions = ['horizontal', 'vertical', 'diagonal'] + if isinstance(direction, str): + assert direction in valid_directions + elif isinstance(direction, list): + assert mmcv.is_list_of(direction, str) + assert set(direction).issubset(set(valid_directions)) + else: + raise ValueError('direction must be either str or list of str') + self.direction = direction + + if isinstance(flip_ratio, list): + assert len(self.flip_ratio) == len(self.direction) + self.count = 0 + + def bbox_flip(self, bboxes, img_shape, direction): + """Flip bboxes horizontally. + + Args: + bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k) + img_shape (tuple[int]): Image shape (height, width) + direction (str): Flip direction. Options are 'horizontal', + 'vertical'. + + Returns: + numpy.ndarray: Flipped bounding boxes. + """ + + assert bboxes.shape[-1] % 4 == 0 + flipped = bboxes.copy() + if direction == 'horizontal': + w = img_shape[1] + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] + elif direction == 'vertical': + h = img_shape[0] + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] + elif direction == 'diagonal': + w = img_shape[1] + h = img_shape[0] + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] + else: + raise ValueError(f"Invalid flipping direction '{direction}'") + return flipped + + def keypoints_flip(self, keypointss, img_shape, direction): + """Flip keypoints horizontally.""" + + assert direction == 'horizontal' + assert keypointss.shape[-1] == 3 + num_kps = keypointss.shape[1] + assert num_kps in [4, 5], f'Only Support num_kps=4 or 5, got:{num_kps}' + assert keypointss.ndim == 3 + flipped = keypointss.copy() + if num_kps == 5: + flip_order = [1, 0, 2, 4, 3] + elif num_kps == 4: + flip_order = [3, 2, 1, 0] + for idx, a in enumerate(flip_order): + flipped[:, idx, :] = keypointss[:, a, :] + w = img_shape[1] + flipped[..., 0] = w - flipped[..., 0] + return flipped + + def __call__(self, results): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'flip', 'flip_direction' keys are added \ + into result dict. 
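+
+        Example (illustrative sketch with assumed dummy inputs)::
+
+            import numpy as np
+            flip = RandomFlipV2(flip_ratio=1.0)  # always flip, for clarity
+            results = dict(
+                img=np.zeros((8, 8, 3), dtype=np.uint8), img_shape=(8, 8, 3),
+                img_fields=['img'], bbox_fields=[], keypoints_fields=[],
+                mask_fields=[], seg_fields=[])
+            results = flip(results)  # results['flip'] is True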
+ """ + if 'flip' not in results: + if isinstance(self.direction, list): + # None means non-flip + direction_list = self.direction + [None] + else: + # None means non-flip + direction_list = [self.direction, None] + + if isinstance(self.flip_ratio, list): + non_flip_ratio = 1 - sum(self.flip_ratio) + flip_ratio_list = self.flip_ratio + [non_flip_ratio] + else: + non_flip_ratio = 1 - self.flip_ratio + # exclude non-flip + single_ratio = self.flip_ratio / (len(direction_list) - 1) + flip_ratio_list = [single_ratio] * (len(direction_list) + - 1) + [non_flip_ratio] + + cur_dir = np.random.choice(direction_list, p=flip_ratio_list) + + results['flip'] = cur_dir is not None + if 'flip_direction' not in results: + results['flip_direction'] = cur_dir + if results['flip']: + # flip image + for key in results.get('img_fields', ['img']): + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']) + # flip bboxes + for key in results.get('bbox_fields', []): + results[key] = self.bbox_flip(results[key], + results['img_shape'], + results['flip_direction']) + # flip kps + for key in results.get('keypoints_fields', []): + results[key] = self.keypoints_flip(results[key], + results['img_shape'], + results['flip_direction']) + # flip masks + for key in results.get('mask_fields', []): + results[key] = results[key].flip(results['flip_direction']) + + # flip segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' + + +@PIPELINES.register_module() +class RandomSquareCrop(object): + """Random crop the image & bboxes, the cropped patches have minimum IoU + requirement with original image & bboxes, the IoU threshold is randomly + selected from min_ious. + + Args: + min_ious (tuple): minimum IoU threshold for all intersections with + bounding boxes + min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, + where a >= min_crop_size). + + Note: + The keys for bboxes, labels and masks should be paired. That is, \ + `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ + `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. + """ + + def __init__(self, + crop_ratio_range=None, + crop_choice=None, + bbox_clip_border=True, + big_face_ratio=0, + big_face_crop_choice=None): + + self.crop_ratio_range = crop_ratio_range + self.crop_choice = crop_choice + self.big_face_crop_choice = big_face_crop_choice + self.bbox_clip_border = bbox_clip_border + + assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) + if self.crop_ratio_range is not None: + self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range + + self.bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + self.bbox2mask = { + 'gt_bboxes': 'gt_masks', + 'gt_bboxes_ignore': 'gt_masks_ignore' + } + assert big_face_ratio >= 0 and big_face_ratio <= 1.0 + self.big_face_ratio = big_face_ratio + + def __call__(self, results): + """Call function to crop images and bounding boxes with minimum IoU + constraint. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images and bounding boxes cropped, \ + 'img_shape' key is updated. 
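+
+        Example (illustrative; the crop choices below are assumed values in
+        the spirit of the SCRFD data pipeline, not copied from it)::
+
+            crop = RandomSquareCrop(
+                crop_choice=[0.3, 0.45, 0.6, 0.8, 1.0],
+                bbox_clip_border=True)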
+ """ + + if 'img_fields' in results: + assert results['img_fields'] == ['img'], \ + 'Only single img_fields is allowed' + img = results['img'] + assert 'bbox_fields' in results + assert 'gt_bboxes' in results + # try augment big face images + find_bigface = False + if np.random.random() < self.big_face_ratio: + min_size = 100 # h and w + expand_ratio = 0.3 # expand ratio of croped face alongwith both w and h + bbox = results['gt_bboxes'].copy() + lmks = results['gt_keypointss'].copy() + label = results['gt_labels'].copy() + # filter small faces + size_mask = ((bbox[:, 2] - bbox[:, 0]) > min_size) * ( + (bbox[:, 3] - bbox[:, 1]) > min_size) + bbox = bbox[size_mask] + lmks = lmks[size_mask] + label = label[size_mask] + # randomly choose a face that has no overlap with others + if len(bbox) > 0: + overlaps = bbox_overlaps(bbox, bbox) + overlaps -= np.eye(overlaps.shape[0]) + iou_mask = np.sum(overlaps, axis=1) == 0 + bbox = bbox[iou_mask] + lmks = lmks[iou_mask] + label = label[iou_mask] + if len(bbox) > 0: + choice = np.random.randint(len(bbox)) + bbox = bbox[choice] + lmks = lmks[choice] + label = [label[choice]] + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + x1 = bbox[0] - w * expand_ratio + x2 = bbox[2] + w * expand_ratio + y1 = bbox[1] - h * expand_ratio + y2 = bbox[3] + h * expand_ratio + x1, x2 = np.clip([x1, x2], 0, img.shape[1]) + y1, y2 = np.clip([y1, y2], 0, img.shape[0]) + bbox -= np.tile([x1, y1], 2) + lmks -= (x1, y1, 0) + + find_bigface = True + img = img[int(y1):int(y2), int(x1):int(x2), :] + results['gt_bboxes'] = np.expand_dims(bbox, axis=0) + results['gt_keypointss'] = np.expand_dims(lmks, axis=0) + results['gt_labels'] = np.array(label) + results['img'] = img + + boxes = results['gt_bboxes'] + h, w, c = img.shape + + if self.crop_ratio_range is not None: + max_scale = self.crop_ratio_max + else: + max_scale = np.amax(self.crop_choice) + scale_retry = 0 + while True: + scale_retry += 1 + if scale_retry == 1 or max_scale > 1.0: + if self.crop_ratio_range is not None: + scale = np.random.uniform(self.crop_ratio_min, + self.crop_ratio_max) + elif self.crop_choice is not None: + scale = np.random.choice(self.crop_choice) + else: + scale = scale * 1.2 + + if find_bigface: + # select a scale from big_face_crop_choice if in big_face mode + scale = np.random.choice(self.big_face_crop_choice) + + for i in range(250): + long_side = max(w, h) + cw = int(scale * long_side) + ch = cw + + # TODO +1 + if w == cw: + left = 0 + elif w > cw: + left = random.randint(0, w - cw) + else: + left = random.randint(w - cw, 0) + if h == ch: + top = 0 + elif h > ch: + top = random.randint(0, h - ch) + else: + top = random.randint(h - ch, 0) + + patch = np.array( + (int(left), int(top), int(left + cw), int(top + ch)), + dtype=np.int32) + + # center of boxes should inside the crop img + # only adjust boxes and instance masks when the gt is not empty + # adjust boxes + def is_center_of_bboxes_in_patch(boxes, patch): + # TODO >= + center = (boxes[:, :2] + boxes[:, 2:]) / 2 + mask = \ + ((center[:, 0] > patch[0]) + * (center[:, 1] > patch[1]) + * (center[:, 0] < patch[2]) + * (center[:, 1] < patch[3])) + return mask + + mask = is_center_of_bboxes_in_patch(boxes, patch) + if not mask.any(): + continue + for key in results.get('bbox_fields', []): + boxes = results[key].copy() + mask = is_center_of_bboxes_in_patch(boxes, patch) + boxes = boxes[mask] + if self.bbox_clip_border: + boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) + boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) + boxes -= np.tile(patch[:2], 
2) + + results[key] = boxes + # labels + label_key = self.bbox2label.get(key) + if label_key in results: + results[label_key] = results[label_key][mask] + + # keypoints field + if key == 'gt_bboxes': + for kps_key in results.get('keypoints_fields', []): + keypointss = results[kps_key].copy() + keypointss = keypointss[mask, :, :] + if self.bbox_clip_border: + keypointss[:, :, : + 2] = keypointss[:, :, :2].clip( + max=patch[2:]) + keypointss[:, :, : + 2] = keypointss[:, :, :2].clip( + min=patch[:2]) + keypointss[:, :, 0] -= patch[0] + keypointss[:, :, 1] -= patch[1] + results[kps_key] = keypointss + + # mask fields + mask_key = self.bbox2mask.get(key) + if mask_key in results: + results[mask_key] = results[mask_key][mask.nonzero() + [0]].crop(patch) + + # adjust the img no matter whether the gt is empty before crop + rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 + patch_from = patch.copy() + patch_from[0] = max(0, patch_from[0]) + patch_from[1] = max(0, patch_from[1]) + patch_from[2] = min(img.shape[1], patch_from[2]) + patch_from[3] = min(img.shape[0], patch_from[3]) + patch_to = patch.copy() + patch_to[0] = max(0, patch_to[0] * -1) + patch_to[1] = max(0, patch_to[1] * -1) + patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) + patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) + rimg[patch_to[1]:patch_to[3], + patch_to[0]:patch_to[2], :] = img[ + patch_from[1]:patch_from[3], + patch_from[0]:patch_from[2], :] + img = rimg + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_ious={self.min_iou}, ' + repr_str += f'crop_size={self.crop_size})' + return repr_str diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/retinaface.py similarity index 97% rename from modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/retinaface.py index bbacd9be..40c440b9 100755 --- a/modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/retinaface.py @@ -13,7 +13,7 @@ class RetinaFaceDataset(CustomDataset): CLASSES = ('FG', ) def __init__(self, min_size=None, **kwargs): - self.NK = 5 + self.NK = kwargs.pop('num_kps', 5) self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} self.min_size = min_size self.gt_path = kwargs.get('gt_path') @@ -33,7 +33,8 @@ class RetinaFaceDataset(CustomDataset): if len(values) > 4: if len(values) > 5: kps = np.array( - values[4:19], dtype=np.float32).reshape((self.NK, 3)) + values[4:4 + self.NK * 3], dtype=np.float32).reshape( + (self.NK, 3)) for li in range(kps.shape[0]): if (kps[li, :] == -1).all(): kps[li][2] = 0.0 # weight = 0, ignore diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/models/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/backbones/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/models/backbones/__init__.py rename to 
modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/backbones/resnet.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/resnet.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/models/backbones/resnet.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/resnet.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/scrfd_head.py similarity index 99% rename from modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/scrfd_head.py index acc45670..77ec99cf 100755 --- a/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/scrfd_head.py @@ -103,6 +103,7 @@ class SCRFDHead(AnchorHead): scale_mode=1, dw_conv=False, use_kps=False, + num_kps=5, loss_kps=dict( type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), **kwargs): @@ -116,7 +117,7 @@ class SCRFDHead(AnchorHead): self.scale_mode = scale_mode self.use_dfl = True self.dw_conv = dw_conv - self.NK = 5 + self.NK = num_kps self.extra_flops = 0.0 if loss_dfl is None or not loss_dfl: self.use_dfl = False @@ -323,8 +324,8 @@ class SCRFDHead(AnchorHead): batch_size, -1, self.cls_out_channels).sigmoid() bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) - kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 10) - + kps_pred = kps_pred.permute(0, 2, 3, + 1).reshape(batch_size, -1, self.NK * 2) return cls_score, bbox_pred, kps_pred def forward_train(self, @@ -788,7 +789,7 @@ class SCRFDHead(AnchorHead): if self.use_dfl: kps_pred = self.integral(kps_pred) * stride[0] else: - kps_pred = kps_pred.reshape((-1, 10)) * stride[0] + kps_pred = kps_pred.reshape((-1, self.NK * 2)) * stride[0] nms_pre = cfg.get('nms_pre', -1) if nms_pre > 0 and scores.shape[0] > nms_pre: @@ -815,7 +816,7 @@ class SCRFDHead(AnchorHead): mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) if mlvl_kps is not None: scale_factor2 = torch.tensor( - [scale_factor[0], scale_factor[1]] * 5) + [scale_factor[0], scale_factor[1]] * self.NK) mlvl_kps /= scale_factor2.to(mlvl_kps.device) mlvl_scores = torch.cat(mlvl_scores) diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/models/detectors/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/scrfd.py similarity index 50% rename from modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py rename 
to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/scrfd.py index a5f5cac2..18b46be1 100755 --- a/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/scrfd.py @@ -54,7 +54,13 @@ class SCRFD(SingleStageDetector): gt_bboxes_ignore) return losses - def simple_test(self, img, img_metas, rescale=False): + def simple_test(self, + img, + img_metas, + rescale=False, + repeat_head=1, + output_kps_var=0, + output_results=1): """Test function without test time augmentation. Args: @@ -62,6 +68,9 @@ class SCRFD(SingleStageDetector): img_metas (list[dict]): List of image information. rescale (bool, optional): Whether to rescale the results. Defaults to False. + repeat_head (int): repeat inference times in head + output_kps_var (int): whether output kps var to calculate quality + output_results (int): 0: nothing 1: bbox 2: both bbox and kps Returns: list[list[np.ndarray]]: BBox results of each image and classes. @@ -69,40 +78,71 @@ class SCRFD(SingleStageDetector): corresponds to each class. """ x = self.extract_feat(img) - outs = self.bbox_head(x) - if torch.onnx.is_in_onnx_export(): - print('single_stage.py in-onnx-export') - print(outs.__class__) - cls_score, bbox_pred, kps_pred = outs - for c in cls_score: - print(c.shape) - for c in bbox_pred: - print(c.shape) - if self.bbox_head.use_kps: - for c in kps_pred: + assert repeat_head >= 1 + kps_out0 = [] + kps_out1 = [] + kps_out2 = [] + for i in range(repeat_head): + outs = self.bbox_head(x) + kps_out0 += [outs[2][0].detach().cpu().numpy()] + kps_out1 += [outs[2][1].detach().cpu().numpy()] + kps_out2 += [outs[2][2].detach().cpu().numpy()] + if output_kps_var: + var0 = np.var(np.vstack(kps_out0), axis=0).mean() + var1 = np.var(np.vstack(kps_out1), axis=0).mean() + var2 = np.var(np.vstack(kps_out2), axis=0).mean() + var = np.mean([var0, var1, var2]) + else: + var = None + + if output_results > 0: + if torch.onnx.is_in_onnx_export(): + print('single_stage.py in-onnx-export') + print(outs.__class__) + cls_score, bbox_pred, kps_pred = outs + for c in cls_score: + print(c.shape) + for c in bbox_pred: print(c.shape) - return (cls_score, bbox_pred, kps_pred) - else: - return (cls_score, bbox_pred) - bbox_list = self.bbox_head.get_bboxes( - *outs, img_metas, rescale=rescale) + if self.bbox_head.use_kps: + for c in kps_pred: + print(c.shape) + return (cls_score, bbox_pred, kps_pred) + else: + return (cls_score, bbox_pred) + bbox_list = self.bbox_head.get_bboxes( + *outs, img_metas, rescale=rescale) - # return kps if use_kps - if len(bbox_list[0]) == 2: - bbox_results = [ - bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) - for det_bboxes, det_labels in bbox_list - ] - elif len(bbox_list[0]) == 3: - bbox_results = [ - bbox2result( - det_bboxes, - det_labels, - self.bbox_head.num_classes, - kps=det_kps) - for det_bboxes, det_labels, det_kps in bbox_list - ] - return bbox_results + # return kps if use_kps + if len(bbox_list[0]) == 2: + bbox_results = [ + bbox2result(det_bboxes, det_labels, + self.bbox_head.num_classes) + for det_bboxes, det_labels in bbox_list + ] + elif len(bbox_list[0]) == 3: + if output_results == 2: + bbox_results = [ + bbox2result( + det_bboxes, + det_labels, + self.bbox_head.num_classes, + kps=det_kps, + num_kps=self.bbox_head.NK) + for det_bboxes, det_labels, det_kps in bbox_list + ] + elif output_results == 1: + bbox_results = [ + bbox2result(det_bboxes, det_labels, + self.bbox_head.num_classes) + for 
det_bboxes, det_labels, _ in bbox_list + ] + else: + bbox_results = None + if var is not None: + return bbox_results, var + else: + return bbox_results def feature_test(self, img): x = self.extract_feat(img) diff --git a/modelscope/models/cv/face_detection/scrfd/scrfd_detect.py b/modelscope/models/cv/face_detection/scrfd/scrfd_detect.py new file mode 100644 index 00000000..59611604 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/scrfd_detect.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from copy import deepcopy +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ScrfdDetect'] + + +@MODELS.register_module(Tasks.face_detection, module_name=Models.scrfd) +class ScrfdDetect(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the face detection model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + from mmcv import Config + from mmcv.parallel import MMDataParallel + from mmcv.runner import load_checkpoint + from mmdet.models import build_detector + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD + cfg = Config.fromfile(osp.join(model_dir, 'mmcv_scrfd.py')) + ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3) + detector = build_detector(cfg.model) + logger.info(f'loading model from {ckpt_path}') + device = torch.device( + f'cuda:{0}' if torch.cuda.is_available() else 'cpu') + load_checkpoint(detector, ckpt_path, map_location=device) + detector = MMDataParallel(detector, device_ids=[0]) + detector.eval() + self.detector = detector + logger.info('load model done') + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.detector( + return_loss=False, + rescale=True, + img=[input['img'][0].unsqueeze(0)], + img_metas=[[dict(input['img_metas'][0].data)]], + output_results=2) + assert result is not None + result = result[0][0] + bboxes = result[:, :4].tolist() + kpss = result[:, 5:].tolist() + scores = result[:, 4].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: kpss + } + + def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + return input diff --git a/modelscope/models/cv/face_emotion/emotion_infer.py b/modelscope/models/cv/face_emotion/emotion_infer.py index e3398592..618822ff 100644 --- a/modelscope/models/cv/face_emotion/emotion_infer.py +++ b/modelscope/models/cv/face_emotion/emotion_infer.py @@ -25,9 +25,9 @@ emotion_list = [ ] -def inference(image_path, model, face_model, score_thre=0.5, GPU=0): - image = Image.open(image_path).convert('RGB') - +def inference(image, model, face_model, score_thre=0.5, GPU=0): + image = 
image.cpu().numpy() + image = Image.fromarray(image) face, bbox = face_detection_PIL_v2(image, face_model) if bbox is None: logger.warn('no face detected!') diff --git a/modelscope/models/cv/face_generation/op/conv2d_gradfix.py b/modelscope/models/cv/face_generation/op/conv2d_gradfix.py index 661f4fc7..a3aba91f 100755 --- a/modelscope/models/cv/face_generation/op/conv2d_gradfix.py +++ b/modelscope/models/cv/face_generation/op/conv2d_gradfix.py @@ -1,3 +1,5 @@ +# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License +# at https://github.com/rosinality/stylegan2-pytorch/blob/master/op/conv2d_gradfix.py import contextlib import warnings diff --git a/modelscope/models/cv/face_generation/op/fused_act.py b/modelscope/models/cv/face_generation/op/fused_act.py index d6e0c10f..a24f5972 100755 --- a/modelscope/models/cv/face_generation/op/fused_act.py +++ b/modelscope/models/cv/face_generation/op/fused_act.py @@ -1,3 +1,5 @@ +# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License +# t https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py import os import torch diff --git a/modelscope/models/cv/face_generation/op/upfirdn2d.py b/modelscope/models/cv/face_generation/op/upfirdn2d.py index 5a44421d..95c987af 100755 --- a/modelscope/models/cv/face_generation/op/upfirdn2d.py +++ b/modelscope/models/cv/face_generation/op/upfirdn2d.py @@ -1,3 +1,5 @@ +# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License +# at https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py import os from collections import abc diff --git a/modelscope/models/cv/face_generation/stylegan2.py b/modelscope/models/cv/face_generation/stylegan2.py index ff9c83ee..4c650f54 100755 --- a/modelscope/models/cv/face_generation/stylegan2.py +++ b/modelscope/models/cv/face_generation/stylegan2.py @@ -1,3 +1,5 @@ +# The implementation is adopted from stylegan2-pytorch, +# made public available under the MIT License at https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py import functools import math import operator diff --git a/modelscope/models/cv/face_human_hand_detection/det_infer.py b/modelscope/models/cv/face_human_hand_detection/det_infer.py index 7a7225ee..6822bd9f 100644 --- a/modelscope/models/cv/face_human_hand_detection/det_infer.py +++ b/modelscope/models/cv/face_human_hand_detection/det_infer.py @@ -115,9 +115,9 @@ std = [57.375, 57.12, 58.395] class_names = ['person', 'face', 'hand'] -def inference(model, device, img_path): +def inference(model, device, img): + img = img.cpu().numpy() img_info = {'id': 0} - img = cv2.imread(img_path) height, width = img.shape[:2] img_info['height'] = height img_info['width'] = width @@ -130,4 +130,9 @@ def inference(model, device, img_path): with torch.no_grad(): res = model(meta) result = overlay_bbox_cv(res[0], class_names, score_thresh=0.35) - return result + cls_list, bbox_list, score_list = [], [], [] + for pred in result: + cls_list.append(pred[0]) + bbox_list.append([pred[1], pred[2], pred[3], pred[4]]) + score_list.append(pred[5]) + return cls_list, bbox_list, score_list diff --git a/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py b/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py index c1d0a52f..0d1cd15d 100644 --- a/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py +++ b/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py @@ 
-56,9 +56,6 @@ class OneStageDetector(nn.Module): def inference(self, meta): with torch.no_grad(): - torch.cuda.synchronize() preds = self(meta['img']) - torch.cuda.synchronize() results = self.head.post_process(preds, meta) - torch.cuda.synchronize() return results diff --git a/modelscope/models/cv/hand_2d_keypoints/__init__.py b/modelscope/models/cv/hand_2d_keypoints/__init__.py new file mode 100644 index 00000000..2b06f19a --- /dev/null +++ b/modelscope/models/cv/hand_2d_keypoints/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .hand_2d_keypoints import Hand2dKeyPoints + +else: + _import_structure = {'hand_2d_keypoints': ['Hand2dKeyPoints']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/hand_2d_keypoints/hand_2d_keypoints.py b/modelscope/models/cv/hand_2d_keypoints/hand_2d_keypoints.py new file mode 100644 index 00000000..15a97c30 --- /dev/null +++ b/modelscope/models/cv/hand_2d_keypoints/hand_2d_keypoints.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.models.pose import TopDown + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.models.cv.easycv_base import EasyCVBaseModel +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + group_key=Tasks.hand_2d_keypoints, module_name=Models.hand_2d_keypoints) +class Hand2dKeyPoints(EasyCVBaseModel, TopDown): + + def __init__(self, model_dir=None, *args, **kwargs): + EasyCVBaseModel.__init__(self, model_dir, args, kwargs) + TopDown.__init__(self, *args, **kwargs) diff --git a/modelscope/models/cv/hand_static/hand_model.py b/modelscope/models/cv/hand_static/hand_model.py index 38517307..7a8a323e 100644 --- a/modelscope/models/cv/hand_static/hand_model.py +++ b/modelscope/models/cv/hand_static/hand_model.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F from PIL import Image from torch import nn -from torchvision.transforms import transforms +from torchvision import transforms from modelscope.metainfo import Models from modelscope.models.base import TorchModel @@ -80,9 +80,9 @@ class HandStatic(TorchModel): return pred_result -def infer(img_path, model, device): - - img = Image.open(img_path) +def infer(img, model, device): + img = img.cpu().numpy() + img = Image.fromarray(img) clip = spatial_transform(img) clip = clip.unsqueeze(0).to(device).float() outputs = model(clip) diff --git a/modelscope/models/cv/human_wholebody_keypoint/__init__.py b/modelscope/models/cv/human_wholebody_keypoint/__init__.py new file mode 100644 index 00000000..30e23457 --- /dev/null +++ b/modelscope/models/cv/human_wholebody_keypoint/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
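+# This package follows the LazyImportModule pattern used across modelscope:
+# `HumanWholeBodyKeypoint` is only imported when it is first accessed, e.g.
+# (illustrative):
+#     from modelscope.models.cv.human_wholebody_keypoint import HumanWholeBodyKeypoint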
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .human_wholebody_keypoint import HumanWholeBodyKeypoint + +else: + _import_structure = { + 'human_wholebody_keypoint': ['HumanWholeBodyKeypoint'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/human_wholebody_keypoint/human_wholebody_keypoint.py b/modelscope/models/cv/human_wholebody_keypoint/human_wholebody_keypoint.py new file mode 100644 index 00000000..dd3c0290 --- /dev/null +++ b/modelscope/models/cv/human_wholebody_keypoint/human_wholebody_keypoint.py @@ -0,0 +1,17 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.models.pose.top_down import TopDown + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.models.cv.easycv_base import EasyCVBaseModel +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + group_key=Tasks.human_wholebody_keypoint, + module_name=Models.human_wholebody_keypoint) +class HumanWholeBodyKeypoint(EasyCVBaseModel, TopDown): + + def __init__(self, model_dir=None, *args, **kwargs): + EasyCVBaseModel.__init__(self, model_dir, args, kwargs) + TopDown.__init__(self, *args, **kwargs) diff --git a/modelscope/models/cv/image_body_reshaping/__init__.py b/modelscope/models/cv/image_body_reshaping/__init__.py new file mode 100644 index 00000000..a04f110d --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_body_reshaping import ImageBodyReshaping + +else: + _import_structure = {'image_body_reshaping': ['ImageBodyReshaping']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_body_reshaping/image_body_reshaping.py b/modelscope/models/cv/image_body_reshaping/image_body_reshaping.py new file mode 100644 index 00000000..4aed8d98 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/image_body_reshaping.py @@ -0,0 +1,128 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .model import FlowGenerator +from .person_info import PersonInfo +from .pose_estimator.body import Body +from .slim_utils import image_warp_grid1, resize_on_long_side + +logger = get_logger() + +__all__ = ['ImageBodyReshaping'] + + +@MODELS.register_module( + Tasks.image_body_reshaping, module_name=Models.image_body_reshaping) +class ImageBodyReshaping(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the image body reshaping model from the `model_dir` path. + + Args: + model_dir (str): the model path. 
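+
+        Example (illustrative sketch; the local model directory and image
+        file are assumed)::
+
+            import cv2, torch
+            model = ImageBodyReshaping('/path/to/model_dir')
+            img = torch.from_numpy(cv2.imread('person.jpg'))  # BGR, HxWx3
+            out = model.inference(img)  # warped image, uint8 HxWx3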
+ """ + super().__init__(model_dir, *args, **kwargs) + + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + + self.degree = 1.0 + self.reshape_model = FlowGenerator(n_channels=16).to(self.device) + model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + checkpoints = torch.load(model_path, map_location=torch.device('cpu')) + self.reshape_model.load_state_dict( + checkpoints['state_dict'], strict=True) + self.reshape_model.eval() + logger.info('load body reshaping model done') + + pose_model_ckpt = os.path.join(model_dir, 'body_pose_model.pth') + self.pose_esti = Body(pose_model_ckpt, self.device) + logger.info('load pose model done') + + def pred_joints(self, img): + if img is None: + return None + small_src, resize_scale = resize_on_long_side(img, 300) + body_joints = self.pose_esti(small_src) + + if body_joints.shape[0] >= 1: + body_joints[:, :, :2] = body_joints[:, :, :2] / resize_scale + + return body_joints + + def pred_flow(self, img): + + body_joints = self.pred_joints(img) + small_size = 1200 + + if img.shape[0] > small_size or img.shape[1] > small_size: + _img, _scale = resize_on_long_side(img, small_size) + body_joints[:, :, :2] = body_joints[:, :, :2] * _scale + else: + _img = img + + # We only reshape one person + if body_joints.shape[0] < 1 or body_joints.shape[0] > 1: + return None + + person = PersonInfo(body_joints[0]) + + with torch.no_grad(): + person_pred = person.pred_flow(_img, self.reshape_model, + self.device) + + flow = np.dstack((person_pred['rDx'], person_pred['rDy'])) + + scale = img.shape[0] * 1.0 / flow.shape[0] + + flow = cv2.resize(flow, (img.shape[1], img.shape[0])) + flow *= scale + + return flow + + def warp(self, src_img, flow): + + X_flow = flow[..., 0] + Y_flow = flow[..., 1] + + X_flow = np.ascontiguousarray(X_flow) + Y_flow = np.ascontiguousarray(Y_flow) + + pred = image_warp_grid1(X_flow, Y_flow, src_img, 1.0, 0, 0) + return pred + + def inference(self, img): + img = img.cpu().numpy() + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + flow = self.pred_flow(img) + + if flow is None: + return img + + assert flow.shape[:2] == img.shape[:2] + + mag, ang = cv2.cartToPolar(flow[..., 0] + 1e-8, flow[..., 1] + 1e-8) + mag -= 3 + mag[mag <= 0] = 0 + + x, y = cv2.polarToCart(mag, ang, angleInDegrees=False) + flow = np.dstack((x, y)) + + flow *= self.degree + pred = self.warp(img, flow) + out_img = np.clip(pred, 0, 255) + logger.info('model inference done') + + return out_img.astype(np.uint8) diff --git a/modelscope/models/cv/image_body_reshaping/model.py b/modelscope/models/cv/image_body_reshaping/model.py new file mode 100644 index 00000000..174428a1 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/model.py @@ -0,0 +1,189 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
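+# Flow-based reshaping generator: an encoder-decoder (FlowGenerator) with a
+# structure-aware self-attention block (SASA) at the bottleneck. The decoder
+# predicts a dense 2-channel flow field that is applied to the input image
+# with F.grid_sample (see FlowGenerator.warp).
+# Shapes (illustrative): a 3-channel image and a 13-channel skeleton map are
+# concatenated into the n_channels=16 input; the flow is [N, H, W, 2].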
+import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvLayer(nn.Module): + + def __init__(self, in_ch, out_ch): + super(ConvLayer, self).__init__() + + self.conv = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=0), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)) + + def forward(self, x): + x = self.conv(x) + return x + + +class SASA(nn.Module): + + def __init__(self, in_dim): + super(SASA, self).__init__() + self.chanel_in = in_dim + + self.query_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.key_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.value_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.mag_conv = nn.Conv2d( + in_channels=5, out_channels=in_dim // 32, kernel_size=1) + + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) # + self.sigmoid = nn.Sigmoid() + + def structure_encoder(self, paf_mag, target_height, target_width): + torso_mask = torch.sum(paf_mag[:, 1:3, :, :], dim=1, keepdim=True) + torso_mask = torch.clamp(torso_mask, 0, 1) + + arms_mask = torch.sum(paf_mag[:, 4:8, :, :], dim=1, keepdim=True) + arms_mask = torch.clamp(arms_mask, 0, 1) + + legs_mask = torch.sum(paf_mag[:, 8:12, :, :], dim=1, keepdim=True) + legs_mask = torch.clamp(legs_mask, 0, 1) + + fg_mask = paf_mag[:, 12, :, :].unsqueeze(1) + bg_mask = 1 - fg_mask + Y = torch.cat((arms_mask, torso_mask, legs_mask, fg_mask, bg_mask), + dim=1) + Y = F.interpolate(Y, size=(target_height, target_width), mode='area') + return Y + + def forward(self, X, PAF_mag): + """extract self-attention features. + Args: + X : input feature maps( B x C x H x W) + PAF_mag : ( B x C x H x W), 1 denotes connectivity, 0 denotes non-connectivity + + Returns: + out : self attention value + input feature + Y: B X N X N (N is Width*Height) + """ + + m_batchsize, C, height, width = X.size() + + Y = self.structure_encoder(PAF_mag, height, width) + + connectivity_mask_vec = self.mag_conv(Y).view(m_batchsize, -1, + width * height) + affinity = torch.bmm( + connectivity_mask_vec.permute(0, 2, 1), connectivity_mask_vec) + affinity_centered = affinity - torch.mean(affinity) + affinity_sigmoid = self.sigmoid(affinity_centered) + + proj_query = self.query_conv(X).view(m_batchsize, -1, + width * height).permute(0, 2, 1) + proj_key = self.key_conv(X).view(m_batchsize, -1, width * height) + selfatten_map = torch.bmm(proj_query, proj_key) + selfatten_centered = selfatten_map - torch.mean( + selfatten_map) # centering + selfatten_sigmoid = self.sigmoid(selfatten_centered) + + SASA_map = selfatten_sigmoid * affinity_sigmoid + + proj_value = self.value_conv(X).view(m_batchsize, -1, width * height) + + out = torch.bmm(proj_value, SASA_map.permute(0, 2, 1)) + out = out.view(m_batchsize, C, height, width) + + out = self.gamma * out + X + return out, Y + + +class FlowGenerator(nn.Module): + + def __init__(self, n_channels, deep_supervision=False): + super(FlowGenerator, self).__init__() + self.deep_supervision = deep_supervision + + self.Encoder = nn.Sequential( + ConvLayer(n_channels, 64), + ConvLayer(64, 64), + nn.MaxPool2d(2), + ConvLayer(64, 128), + ConvLayer(128, 128), + nn.MaxPool2d(2), + ConvLayer(128, 256), + ConvLayer(256, 256), + nn.MaxPool2d(2), + ConvLayer(256, 512), + ConvLayer(512, 512), + nn.MaxPool2d(2), + ConvLayer(512, 1024), + ConvLayer(1024, 1024), + ConvLayer(1024, 1024), + ConvLayer(1024, 1024), + ConvLayer(1024, 1024), + ) + + 
self.SASA = SASA(in_dim=1024) + + self.Decoder = nn.Sequential( + ConvLayer(1024, 1024), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ConvLayer(1024, 512), + ConvLayer(512, 512), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ConvLayer(512, 256), + ConvLayer(256, 256), + ConvLayer(256, 128), + ConvLayer(128, 64), + ConvLayer(64, 32), + nn.Conv2d(32, 2, kernel_size=1, padding=0), + nn.Tanh(), + nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True), + ) + + dilation_ksize = 17 + self.dilation = torch.nn.MaxPool2d( + kernel_size=dilation_ksize, + stride=1, + padding=int((dilation_ksize - 1) / 2)) + + def warp(self, x, flow, mode='bilinear', padding_mode='zeros', coff=0.2): + n, c, h, w = x.size() + yv, xv = torch.meshgrid([torch.arange(h), torch.arange(w)]) + xv = xv.float() / (w - 1) * 2.0 - 1 + yv = yv.float() / (h - 1) * 2.0 - 1 + grid = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), -1).unsqueeze(0) + grid = grid.to(flow.device) + grid_x = grid + 2 * flow * coff + warp_x = F.grid_sample(x, grid_x, mode=mode, padding_mode=padding_mode) + return warp_x + + def forward(self, img, skeleton_map, coef=0.2): + """extract self-attention features. + Args: + img : input numpy image + skeleton_map : skeleton map of input image + coef: warp degree + + Returns: + warp_x : warped image + flow: predicted flow + """ + + img_concat = torch.cat((img, skeleton_map), dim=1) + X = self.Encoder(img_concat) + + _, _, height, width = X.size() + + # directly get PAF magnitude from skeleton maps via dilation + PAF_mag = self.dilation((skeleton_map + 1.0) * 0.5) + + out, Y = self.SASA(X, PAF_mag) + flow = self.Decoder(out) + + flow = flow.permute(0, 2, 3, 1) # [n, 2, h, w] ==> [n, h, w, 2] + + warp_x = self.warp(img, flow, coff=coef) + warp_x = torch.clamp(warp_x, min=-1.0, max=1.0) + + return warp_x, flow diff --git a/modelscope/models/cv/image_body_reshaping/person_info.py b/modelscope/models/cv/image_body_reshaping/person_info.py new file mode 100644 index 00000000..509a2ce3 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/person_info.py @@ -0,0 +1,339 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
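+# PersonInfo holds the pose joints of a single person and runs the flow
+# generator on body-part ROIs (arm / torso / leg boxes derived from the
+# joints). Each ROI flow is padded back to full-image size and the per-ROI
+# flows are merged in blend_multiscale_flow().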
+import copy + +import cv2 +import numpy as np +import torch + +from .slim_utils import (enlarge_box_tblr, gen_skeleton_map, + get_map_fusion_map_cuda, get_mask_bbox, + resize_on_long_side) + + +class PersonInfo(object): + + def __init__(self, joints): + self.joints = joints + self.flow = None + self.pad_boder = False + self.height_expand = 0 + self.width_expand = 0 + self.coeff = 0.2 + self.network_input_W = 256 + self.network_input_H = 256 + self.divider = 20 + self.flow_scales = ['upper_2'] + + def update_attribute(self, pad_boder, height_expand, width_expand): + self.pad_boder = pad_boder + self.height_expand = height_expand + self.width_expand = width_expand + if pad_boder: + self.joints[:, 0] += width_expand + self.joints[:, 1] += height_expand + + def pred_flow(self, img, flow_net, device): + with torch.no_grad(): + if img is None: + print('image is none') + self.flow = None + + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if self.pad_boder: + height_expand = self.height_expand + width_expand = self.width_expand + pad_img = cv2.copyMakeBorder( + img, + height_expand, + height_expand, + width_expand, + width_expand, + cv2.BORDER_CONSTANT, + value=(127, 127, 127)) + + else: + height_expand = 0 + width_expand = 0 + pad_img = img.copy() + + canvas = np.zeros( + shape=(pad_img.shape[0], pad_img.shape[1]), dtype=np.float32) + + self.human_joint_box = self.__joint_to_body_box() + + self.human_box = enlarge_box_tblr( + self.human_joint_box, pad_img, ratio=0.25) + human_box_height = self.human_box[1] - self.human_box[0] + human_box_width = self.human_box[3] - self.human_box[2] + + self.leg_joint_box = self.__joint_to_leg_box() + self.leg_box = enlarge_box_tblr( + self.leg_joint_box, pad_img, ratio=0.25) + + self.arm_joint_box = self.__joint_to_arm_box() + self.arm_box = enlarge_box_tblr( + self.arm_joint_box, pad_img, ratio=0.1) + + x_flows = [] + y_flows = [] + multi_bbox = [] + + for scale in self.flow_scales: # better for metric + scale_value = float(scale.split('_')[-1]) + + arm_box = copy.deepcopy(self.arm_box) + + if arm_box[0] is None: + arm_box = self.human_box + + arm_box_height = arm_box[1] - arm_box[0] + arm_box_width = arm_box[3] - arm_box[2] + + roi_bbox = None + + if arm_box_width < human_box_width * 0.1 or arm_box_height < human_box_height * 0.1: + roi_bbox = self.human_box + else: + arm_box = enlarge_box_tblr( + arm_box, pad_img, ratio=scale_value) + if scale == 'upper_0.2': + arm_box[0] = min(arm_box[0], int(self.joints[0][1])) + if scale.startswith('upper'): + roi_bbox = [ + max(self.human_box[0], arm_box[0]), + min(self.human_box[1], arm_box[1]), + max(self.human_box[2], arm_box[2]), + min(self.human_box[3], arm_box[3]) + ] + if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[ + 3] - roi_bbox[2] < 1: + continue + + elif scale.startswith('lower'): + roi_bbox = [ + max(self.human_box[0], self.leg_box[0]), + min(self.human_box[1], self.leg_box[1]), + max(self.human_box[2], self.leg_box[2]), + min(self.human_box[3], self.leg_box[3]) + ] + + if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[ + 3] - roi_bbox[2] < 1: + continue + + skel_map, roi_bbox = gen_skeleton_map( + self.joints, 'depth', input_roi_box=roi_bbox) + + if roi_bbox is None: + continue + + if skel_map.dtype != np.float32: + skel_map = skel_map.astype(np.float32) + + skel_map -= 1.0 # [0,2] ->[-1,1] + + multi_bbox.append(roi_bbox) + + roi_bbox_height = roi_bbox[1] - roi_bbox[0] + roi_bbox_width = roi_bbox[3] - roi_bbox[2] + + assert skel_map.shape[0] == roi_bbox_height + assert skel_map.shape[1] 
== roi_bbox_width + roi_height_pad = roi_bbox_height // self.divider + roi_width_pad = roi_bbox_width // self.divider + paded_roi_h = roi_bbox_height + 2 * roi_height_pad + paded_roi_w = roi_bbox_width + 2 * roi_width_pad + + roi_height_pad_joint = skel_map.shape[0] // self.divider + roi_width_pad_joint = skel_map.shape[1] // self.divider + skel_map = np.pad( + skel_map, + ((roi_height_pad_joint, roi_height_pad_joint), + (roi_width_pad_joint, roi_width_pad_joint), (0, 0)), + 'constant', + constant_values=-1) + + skel_map_resized = cv2.resize( + skel_map, (self.network_input_W, self.network_input_H)) + + skel_map_resized[skel_map_resized < 0] = -1.0 + skel_map_resized[skel_map_resized > -0.5] = 1.0 + skel_map_transformed = torch.from_numpy( + skel_map_resized.transpose((2, 0, 1))) + + roi_npy = pad_img[roi_bbox[0]:roi_bbox[1], + roi_bbox[2]:roi_bbox[3], :].copy() + if roi_npy.dtype != np.float32: + roi_npy = roi_npy.astype(np.float32) + + roi_npy = np.pad(roi_npy, + ((roi_height_pad, roi_height_pad), + (roi_width_pad, roi_width_pad), (0, 0)), + 'edge') + + roi_npy = roi_npy[:, :, ::-1] + + roi_npy = cv2.resize( + roi_npy, (self.network_input_W, self.network_input_H)) + + roi_npy *= 1.0 / 255 + roi_npy -= 0.5 + roi_npy *= 2 + + rgb_tensor = torch.from_numpy(roi_npy.transpose((2, 0, 1))) + + rgb_tensor = rgb_tensor.unsqueeze(0).to(device) + skel_map_tensor = skel_map_transformed.unsqueeze(0).to(device) + warped_img_val, flow_field_val = flow_net( + rgb_tensor, skel_map_tensor + ) # inference, connectivity_mask [1,12,16,16] + flow_field_val = flow_field_val.detach().squeeze().cpu().numpy( + ) + + flow_field_val = cv2.resize( + flow_field_val, (paded_roi_w, paded_roi_h), + interpolation=cv2.INTER_LINEAR) + flow_field_val[..., 0] = flow_field_val[ + ..., 0] * paded_roi_w * 0.5 * 2 * self.coeff + flow_field_val[..., 1] = flow_field_val[ + ..., 1] * paded_roi_h * 0.5 * 2 * self.coeff + + # remove pad areas + flow_field_val = flow_field_val[ + roi_height_pad:flow_field_val.shape[0] - roi_height_pad, + roi_width_pad:flow_field_val.shape[1] - roi_width_pad, :] + + diffuse_width = max(roi_bbox_width // 3, 1) + diffuse_height = max(roi_bbox_height // 3, 1) + assert roi_bbox_width == flow_field_val.shape[1] + assert roi_bbox_height == flow_field_val.shape[0] + + origin_flow = np.zeros( + (pad_img.shape[0] + 2 * diffuse_height, + pad_img.shape[1] + 2 * diffuse_width, 2), + dtype=np.float32) + + flow_field_val = np.pad(flow_field_val, + ((diffuse_height, diffuse_height), + (diffuse_width, diffuse_width), + (0, 0)), 'linear_ramp') + + origin_flow[roi_bbox[0]:roi_bbox[1] + 2 * diffuse_height, + roi_bbox[2]:roi_bbox[3] + + 2 * diffuse_width] = flow_field_val + + origin_flow = origin_flow[diffuse_height:-diffuse_height, + diffuse_width:-diffuse_width, :] + + x_flows.append(origin_flow[..., 0]) + y_flows.append(origin_flow[..., 1]) + + if len(x_flows) == 0: + return { + 'rDx': np.zeros(canvas.shape[:2], dtype=np.float32), + 'rDy': np.zeros(canvas.shape[:2], dtype=np.float32), + 'multi_bbox': multi_bbox, + 'x_fusion_map': + np.ones(canvas.shape[:2], dtype=np.float32), + 'y_fusion_map': + np.ones(canvas.shape[:2], dtype=np.float32) + } + else: + origin_rDx, origin_rDy, x_fusion_map, y_fusion_map = self.blend_multiscale_flow( + x_flows, y_flows, device=device) + + return { + 'rDx': origin_rDx, + 'rDy': origin_rDy, + 'multi_bbox': multi_bbox, + 'x_fusion_map': x_fusion_map, + 'y_fusion_map': y_fusion_map + } + + @staticmethod + def blend_multiscale_flow(x_flows, y_flows, device=None): + scale_num = len(x_flows) + 
if scale_num == 1: + return x_flows[0], y_flows[0], np.ones_like( + x_flows[0]), np.ones_like(x_flows[0]) + + origin_rDx = np.zeros((x_flows[0].shape[0], x_flows[0].shape[1]), + dtype=np.float32) + origin_rDy = np.zeros((y_flows[0].shape[0], y_flows[0].shape[1]), + dtype=np.float32) + + x_fusion_map, x_acc_map = get_map_fusion_map_cuda( + x_flows, 1, device=device) + y_fusion_map, y_acc_map = get_map_fusion_map_cuda( + y_flows, 1, device=device) + + x_flow_map = 1.0 / x_fusion_map + y_flow_map = 1.0 / y_fusion_map + + all_acc_map = x_acc_map + y_acc_map + all_acc_map = all_acc_map.astype(np.uint8) + roi_box = get_mask_bbox(all_acc_map, threshold=1) + + if roi_box[0] is None or roi_box[1] - roi_box[0] <= 0 or roi_box[ + 3] - roi_box[2] <= 0: + roi_box = [0, x_flow_map.shape[0], 0, x_flow_map.shape[1]] + + roi_x_flow_map = x_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] + roi_y_flow_map = y_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] + + roi_width = roi_x_flow_map.shape[1] + roi_height = roi_x_flow_map.shape[0] + + roi_x_flow_map, scale = resize_on_long_side(roi_x_flow_map, 320) + roi_y_flow_map, scale = resize_on_long_side(roi_y_flow_map, 320) + + roi_x_flow_map = cv2.blur(roi_x_flow_map, (55, 55)) + roi_y_flow_map = cv2.blur(roi_y_flow_map, (55, 55)) + + roi_x_flow_map = cv2.resize(roi_x_flow_map, (roi_width, roi_height)) + roi_y_flow_map = cv2.resize(roi_y_flow_map, (roi_width, roi_height)) + + x_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] = roi_x_flow_map + y_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] = roi_y_flow_map + + for i in range(scale_num): + origin_rDx += x_flows[i] + origin_rDy += y_flows[i] + + origin_rDx *= x_flow_map + origin_rDy *= y_flow_map + + return origin_rDx, origin_rDy, x_flow_map, y_flow_map + + def __joint_to_body_box(self): + joint_left = int(np.min(self.joints, axis=0)[0]) + joint_right = int(np.max(self.joints, axis=0)[0]) + joint_top = int(np.min(self.joints, axis=0)[1]) + joint_bottom = int(np.max(self.joints, axis=0)[1]) + return [joint_top, joint_bottom, joint_left, joint_right] + + def __joint_to_leg_box(self): + leg_joints = self.joints[8:, :] + if np.max(leg_joints, axis=0)[2] < 0.05: + return [0, 0, 0, 0] + joint_left = int(np.min(leg_joints, axis=0)[0]) + joint_right = int(np.max(leg_joints, axis=0)[0]) + joint_top = int(np.min(leg_joints, axis=0)[1]) + joint_bottom = int(np.max(leg_joints, axis=0)[1]) + return [joint_top, joint_bottom, joint_left, joint_right] + + def __joint_to_arm_box(self): + arm_joints = self.joints[2:8, :] + if np.max(arm_joints, axis=0)[2] < 0.05: + return [0, 0, 0, 0] + joint_left = int(np.min(arm_joints, axis=0)[0]) + joint_right = int(np.max(arm_joints, axis=0)[0]) + joint_top = int(np.min(arm_joints, axis=0)[1]) + joint_bottom = int(np.max(arm_joints, axis=0)[1]) + return [joint_top, joint_bottom, joint_left, joint_right] diff --git a/modelscope/models/nlp/star3/__init__.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/__init__.py similarity index 100% rename from modelscope/models/nlp/star3/__init__.py rename to modelscope/models/cv/image_body_reshaping/pose_estimator/__init__.py diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/body.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/body.py new file mode 100644 index 00000000..45b02724 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/pose_estimator/body.py @@ -0,0 +1,272 @@ +# The implementation is based on openpose, available at 
https://github.com/Hzzone/pytorch-openpose. + +import math + +import cv2 +import numpy as np +import torch +from scipy.ndimage.filters import gaussian_filter + +from .model import BodyposeModel +from .util import pad_rightdown_corner, transfer + + +class Body(object): + + def __init__(self, model_path, device): + self.model = BodyposeModel().to(device) + model_dict = transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def __call__(self, oriImg): + scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre1 = 0.1 + thre2 = 0.05 + bodyparts = 18 + multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] + heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) + paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = cv2.resize( + oriImg, (0, 0), + fx=scale, + fy=scale, + interpolation=cv2.INTER_CUBIC) + imageToTest_padded, pad = pad_rightdown_corner( + imageToTest, stride, padValue) + im = np.transpose( + np.float32(imageToTest_padded[:, :, :, np.newaxis]), + (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + if torch.cuda.is_available(): + data = data.cuda() + with torch.no_grad(): + Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) + Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() + Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() + + # extract outputs, resize, and remove padding + heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), + (1, 2, 0)) # output 1 is heatmaps + heatmap = cv2.resize( + heatmap, (0, 0), + fx=stride, + fy=stride, + interpolation=cv2.INTER_CUBIC) + heatmap = heatmap[:imageToTest_padded.shape[0] + - pad[2], :imageToTest_padded.shape[1] + - pad[3], :] + heatmap = cv2.resize( + heatmap, (oriImg.shape[1], oriImg.shape[0]), + interpolation=cv2.INTER_CUBIC) + + paf = np.transpose(np.squeeze(Mconv7_stage6_L1), + (1, 2, 0)) # output 0 is PAFs + paf = cv2.resize( + paf, (0, 0), + fx=stride, + fy=stride, + interpolation=cv2.INTER_CUBIC) + paf = paf[:imageToTest_padded.shape[0] + - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + paf = cv2.resize( + paf, (oriImg.shape[1], oriImg.shape[0]), + interpolation=cv2.INTER_CUBIC) + + heatmap_avg += heatmap_avg + heatmap / len(multiplier) + paf_avg += +paf / len(multiplier) + + all_peaks = [] + peak_counter = 0 + + for part in range(bodyparts): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + + map_left = np.zeros(one_heatmap.shape) + map_left[1:, :] = one_heatmap[:-1, :] + map_right = np.zeros(one_heatmap.shape) + map_right[:-1, :] = one_heatmap[1:, :] + map_up = np.zeros(one_heatmap.shape) + map_up[:, 1:] = one_heatmap[:, :-1] + map_down = np.zeros(one_heatmap.shape) + map_down[:, :-1] = one_heatmap[:, 1:] + + peaks_binary = np.logical_and.reduce( + (one_heatmap >= map_left, one_heatmap >= map_right, + one_heatmap >= map_up, one_heatmap >= map_down, + one_heatmap > thre1)) + peaks = list( + zip(np.nonzero(peaks_binary)[1], + np.nonzero(peaks_binary)[0])) # note reverse + peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks] + peak_id = range(peak_counter, peak_counter + len(peaks)) + peaks_with_score_and_id = [ + peaks_with_score[i] + (peak_id[i], ) + for i in range(len(peak_id)) + ] + + all_peaks.append(peaks_with_score_and_id) + peak_counter += len(peaks) + + # find connection in the specified sequence, center 29 is in the position 15 + limbSeq = [[2, 3], [2, 
6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], + [9, 10], [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], + [1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]] + # the middle joints heatmap correpondence + mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], + [19, 20], [21, 22], [23, 24], [25, 26], [27, 28], [29, 30], + [47, 48], [49, 50], [53, 54], [51, 52], [55, 56], [37, 38], + [45, 46]] + + connection_all = [] + special_k = [] + mid_num = 10 + + for k in range(len(mapIdx)): + score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] + candA = all_peaks[limbSeq[k][0] - 1] + candB = all_peaks[limbSeq[k][1] - 1] + nA = len(candA) + nB = len(candB) + if (nA != 0 and nB != 0): + connection_candidate = [] + for i in range(nA): + for j in range(nB): + vec = np.subtract(candB[j][:2], candA[i][:2]) + norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + norm = max(0.001, norm) + vec = np.divide(vec, norm) + + startend = list( + zip( + np.linspace( + candA[i][0], candB[j][0], num=mid_num), + np.linspace( + candA[i][1], candB[j][1], num=mid_num))) + + vec_x = np.array([ + score_mid[int(round(startend[item][1])), + int(round(startend[item][0])), 0] + for item in range(len(startend)) + ]) + vec_y = np.array([ + score_mid[int(round(startend[item][1])), + int(round(startend[item][0])), 1] + for item in range(len(startend)) + ]) + + score_midpts = np.multiply( + vec_x, vec[0]) + np.multiply(vec_y, vec[1]) + temp1 = sum(score_midpts) / len(score_midpts) + temp2 = min(0.5 * oriImg.shape[0] / norm - 1, 0) + score_with_dist_prior = temp1 + temp2 + criterion1 = len(np.nonzero( + score_midpts > thre2)[0]) > 0.8 * len(score_midpts) + criterion2 = score_with_dist_prior > 0 + if criterion1 and criterion2: + connection_candidate.append([ + i, j, score_with_dist_prior, + score_with_dist_prior + candA[i][2] + + candB[j][2] + ]) + + connection_candidate = sorted( + connection_candidate, key=lambda x: x[2], reverse=True) + connection = np.zeros((0, 5)) + for c in range(len(connection_candidate)): + i, j, s = connection_candidate[c][0:3] + if (i not in connection[:, 3] + and j not in connection[:, 4]): + connection = np.vstack( + [connection, [candA[i][3], candB[j][3], s, i, j]]) + if (len(connection) >= min(nA, nB)): + break + + connection_all.append(connection) + else: + special_k.append(k) + connection_all.append([]) + + # last number in each row is the total parts number of that person + # the second last number in each row is the score of the overall configuration + subset = -1 * np.ones((0, 20)) + candidate = np.array( + [item for sublist in all_peaks for item in sublist]) + + for k in range(len(mapIdx)): + if k not in special_k: + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] + indexA, indexB = np.array(limbSeq[k]) - 1 + + for i in range(len(connection_all[k])): # = 1:size(temp,1) + found = 0 + subset_idx = [-1, -1] + for j in range(len(subset)): # 1:size(subset,1): + if subset[j][indexA] == partAs[i] or subset[j][ + indexB] == partBs[i]: + subset_idx[found] = j + found += 1 + + if found == 1: + j = subset_idx[0] + if subset[j][indexB] != partBs[i]: + subset[j][indexB] = partBs[i] + subset[j][-1] += 1 + subset[j][-2] += candidate[ + partBs[i].astype(int), + 2] + connection_all[k][i][2] + elif found == 2: # if found 2 and disjoint, merge them + j1, j2 = subset_idx + tmp1 = (subset[j1] >= 0).astype(int) + tmp2 = (subset[j2] >= 0).astype(int) + membership = (tmp1 + tmp2)[:-2] + if len(np.nonzero(membership == 2)[0]) == 0: # merge + subset[j1][:-2] += (subset[j2][:-2] + 1) + 
subset[j1][-2:] += subset[j2][-2:] + subset[j1][-2] += connection_all[k][i][2] + subset = np.delete(subset, j2, 0) + else: # as like found == 1 + subset[j1][indexB] = partBs[i] + subset[j1][-1] += 1 + subset[j1][-2] += candidate[ + partBs[i].astype(int), + 2] + connection_all[k][i][2] + + # if find no partA in the subset, create a new subset + elif not found and k < 17: + row = -1 * np.ones(20) + row[indexA] = partAs[i] + row[indexB] = partBs[i] + row[-1] = 2 + row[-2] = sum( + candidate[connection_all[k][i, :2].astype(int), + 2]) + connection_all[k][i][2] + subset = np.vstack([subset, row]) + # delete some rows of subset which has few parts occur + deleteIdx = [] + for i in range(len(subset)): + if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: + deleteIdx.append(i) + subset = np.delete(subset, deleteIdx, axis=0) + + # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts + # candidate: x, y, score, id + count = subset.shape[0] + joints = np.zeros(shape=(count, bodyparts, 3)) + + for i in range(count): + for j in range(bodyparts): + joints[i, j, :3] = candidate[int(subset[i, j]), :3] + confidence = 1.0 if subset[i, j] >= 0 else 0.0 + joints[i, j, 2] *= confidence + return joints diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/model.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/model.py new file mode 100644 index 00000000..12f6e84d --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/pose_estimator/model.py @@ -0,0 +1,141 @@ +# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. + +from collections import OrderedDict + +import torch +import torch.nn as nn + + +def make_layers(block, no_relu_layers): + layers = [] + for layer_name, v in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d( + in_channels=v[0], + out_channels=v[1], + kernel_size=v[2], + stride=v[3], + padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_' + layer_name, nn.ReLU(inplace=True))) + + return nn.Sequential(OrderedDict(layers)) + + +class BodyposeModel(nn.Module): + + def __init__(self): + super(BodyposeModel, self).__init__() + + # these layers have no relu layer + no_relu_layers = [ + 'conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1', + 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2', + 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1', + 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1' + ] + blocks = {} + block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3_CPM', [512, 256, 3, 1, 1]), + ('conv4_4_CPM', [256, 128, 3, 1, 1])]) + + # Stage 1 + block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])]) + + block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 
1]), + ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])]) + blocks['block1_1'] = block1_1 + blocks['block1_2'] = block1_2 + + self.model0 = make_layers(block0, no_relu_layers) + + # Stages 2 - 6 + for i in range(2, 7): + blocks['block%d_1' % i] = OrderedDict([ + ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) + ]) + + blocks['block%d_2' % i] = OrderedDict([ + ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_1 = blocks['block1_1'] + self.model2_1 = blocks['block2_1'] + self.model3_1 = blocks['block3_1'] + self.model4_1 = blocks['block4_1'] + self.model5_1 = blocks['block5_1'] + self.model6_1 = blocks['block6_1'] + + self.model1_2 = blocks['block1_2'] + self.model2_2 = blocks['block2_2'] + self.model3_2 = blocks['block3_2'] + self.model4_2 = blocks['block4_2'] + self.model5_2 = blocks['block5_2'] + self.model6_2 = blocks['block6_2'] + + def forward(self, x): + + out1 = self.model0(x) + + out1_1 = self.model1_1(out1) + out1_2 = self.model1_2(out1) + out2 = torch.cat([out1_1, out1_2, out1], 1) + + out2_1 = self.model2_1(out2) + out2_2 = self.model2_2(out2) + out3 = torch.cat([out2_1, out2_2, out1], 1) + + out3_1 = self.model3_1(out3) + out3_2 = self.model3_2(out3) + out4 = torch.cat([out3_1, out3_2, out1], 1) + + out4_1 = self.model4_1(out4) + out4_2 = self.model4_2(out4) + out5 = torch.cat([out4_1, out4_2, out1], 1) + + out5_1 = self.model5_1(out5) + out5_2 = self.model5_2(out5) + out6 = torch.cat([out5_1, out5_2, out1], 1) + + out6_1 = self.model6_1(out6) + out6_2 = self.model6_2(out6) + + return out6_1, out6_2 diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/util.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/util.py new file mode 100644 index 00000000..13a42074 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/pose_estimator/util.py @@ -0,0 +1,33 @@ +# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. 
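+# Helper utilities for the OpenPose body estimator (descriptive note, inferred from the code below):
+#   - pad_rightdown_corner pads an image on the bottom/right with `padValue` so its height and
+#     width become multiples of `stride`, returning the padded image and the pad amounts.
+#   - transfer builds a state_dict keyed by this model's parameter names from a checkpoint whose
+#     keys omit the leading sub-module prefix.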
+import numpy as np + + +def pad_rightdown_corner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join( + weights_name.split('.')[1:])] + return transfered_model_weights diff --git a/modelscope/models/cv/image_body_reshaping/slim_utils.py b/modelscope/models/cv/image_body_reshaping/slim_utils.py new file mode 100644 index 00000000..23d5a741 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/slim_utils.py @@ -0,0 +1,507 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os +import random + +import cv2 +import numba +import numpy as np +import torch + + +def resize_on_long_side(img, long_side=800): + src_height = img.shape[0] + src_width = img.shape[1] + + if src_height > src_width: + scale = long_side * 1.0 / src_height + _img = cv2.resize( + img, (int(src_width * scale), long_side), + interpolation=cv2.INTER_LINEAR) + else: + scale = long_side * 1.0 / src_width + _img = cv2.resize( + img, (long_side, int(src_height * scale)), + interpolation=cv2.INTER_LINEAR) + + return _img, scale + + +def point_in_box(pt, box): + pt_x = pt[0] + pt_y = pt[1] + + if pt_x >= box[0] and pt_x <= box[0] + box[2] and pt_y >= box[ + 1] and pt_y <= box[1] + box[3]: + return True + else: + return False + + +def enlarge_box_tblr(roi_bbox, mask, ratio=0.4, use_long_side=True): + if roi_bbox is None or None in roi_bbox: + return [None, None, None, None] + + top = roi_bbox[0] + bottom = roi_bbox[1] + left = roi_bbox[2] + right = roi_bbox[3] + + roi_width = roi_bbox[3] - roi_bbox[2] + roi_height = roi_bbox[1] - roi_bbox[0] + right = left + roi_width + bottom = top + roi_height + + long_side = roi_width if roi_width > roi_height else roi_height + + if use_long_side: + new_left = left - int(long_side * ratio) + else: + new_left = left - int(roi_width * ratio) + new_left = 1 if new_left < 0 else new_left + + if use_long_side: + new_top = top - int(long_side * ratio) + else: + new_top = top - int(roi_height * ratio) + new_top = 1 if new_top < 0 else new_top + + if use_long_side: + new_right = right + int(long_side * ratio) + else: + new_right = right + int(roi_width * ratio) + new_right = mask.shape[1] - 2 if new_right > mask.shape[1] else new_right + + if use_long_side: + new_bottom = bottom + int(long_side * ratio) + else: + new_bottom = bottom + int(roi_height * ratio) + new_bottom = mask.shape[0] - 2 if new_bottom > mask.shape[0] else new_bottom + + bbox = [new_top, new_bottom, new_left, new_right] + return bbox + + +def gen_PAF(image, joints): + + assert joints.shape[0] == 18 + assert joints.shape[1] == 3 + + 
org_h = image.shape[0] + org_w = image.shape[1] + small_image, resize_scale = resize_on_long_side(image, 120) + + joints[:, :2] = joints[:, :2] * resize_scale + + joint_left = int(np.min(joints, axis=0)[0]) + joint_right = int(np.max(joints, axis=0)[0]) + joint_top = int(np.min(joints, axis=0)[1]) + joint_bottom = int(np.max(joints, axis=0)[1]) + + limb_width = min( + abs(joint_right - joint_left), abs(joint_bottom - joint_top)) // 6 + + if limb_width % 2 == 0: + limb_width += 1 + kernel_size = limb_width + + part_orders = [(5, 11), (2, 8), (5, 6), (6, 7), (2, 3), (3, 4), (11, 12), + (12, 13), (8, 9), (9, 10)] + + map_list = [] + mask_list = [] + PAF_all = np.zeros( + shape=(small_image.shape[0], small_image.shape[1], 2), + dtype=np.float32) + for c, pair in enumerate(part_orders): + idx_a_name = pair[0] + idx_b_name = pair[1] + + jointa = joints[idx_a_name] + jointb = joints[idx_b_name] + + confidence_threshold = 0.05 + if jointa[2] > confidence_threshold and jointb[ + 2] > confidence_threshold: + canvas = np.zeros( + shape=(small_image.shape[0], small_image.shape[1]), + dtype=np.uint8) + + canvas = cv2.line(canvas, (int(jointa[0]), int(jointa[1])), + (int(jointb[0]), int(jointb[1])), + (255, 255, 255), 5) + + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (kernel_size, kernel_size)) + + canvas = cv2.dilate(canvas, kernel, 1) + canvas = cv2.GaussianBlur(canvas, (kernel_size, kernel_size), 0) + canvas = canvas.astype(np.float32) / 255 + PAF = np.zeros( + shape=(small_image.shape[0], small_image.shape[1], 2), + dtype=np.float32) + PAF[..., 0] = jointb[0] - jointa[0] + PAF[..., 1] = jointb[1] - jointa[1] + mag, ang = cv2.cartToPolar(PAF[..., 0], PAF[..., 1]) + PAF /= (np.dstack((mag, mag)) + 1e-5) + + single_PAF = PAF * np.dstack((canvas, canvas)) + map_list.append( + cv2.GaussianBlur(single_PAF, + (kernel_size * 3, kernel_size * 3), 0)) + + mask_list.append( + cv2.GaussianBlur(canvas.copy(), + (kernel_size * 3, kernel_size * 3), 0)) + PAF_all = PAF_all * (1.0 - np.dstack( + (canvas, canvas))) + single_PAF + + PAF_all = cv2.GaussianBlur(PAF_all, (kernel_size * 3, kernel_size * 3), 0) + PAF_all = cv2.resize( + PAF_all, (org_w, org_h), interpolation=cv2.INTER_LINEAR) + map_list.append(PAF_all) + return PAF_all, map_list, mask_list + + +def gen_skeleton_map(joints, stack_mode='column', input_roi_box=None): + if type(joints) == list: + joints = np.array(joints) + assert stack_mode == 'column' or stack_mode == 'depth' + + part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3), + (3, 4), (11, 12), (12, 13), (8, 9), (9, 10)] + + def link(img, a, b, color, line_width, scale=1.0, x_offset=0, y_offset=0): + jointa = joints[a] + jointb = joints[b] + + temp1 = int((jointa[0] - x_offset) * scale) + temp2 = int((jointa[1] - y_offset) * scale) + temp3 = int((jointb[0] - x_offset) * scale) + temp4 = int((jointb[1] - y_offset) * scale) + + cv2.line(img, (temp1, temp2), (temp3, temp4), color, line_width) + + roi_box = input_roi_box + + roi_box_width = roi_box[3] - roi_box[2] + roi_box_height = roi_box[1] - roi_box[0] + short_side_length = min(roi_box_width, roi_box_height) + line_width = short_side_length // 30 + + line_width = max(line_width, 2) + + map_cube = np.zeros( + shape=(roi_box_height, roi_box_width, len(part_orders) + 1), + dtype=np.float32) + + use_line_width = min(5, line_width) + fx = use_line_width * 1.0 / line_width # fx 最大值为1 + + if fx < 0.99: + map_cube = cv2.resize(map_cube, (0, 0), fx=fx, fy=fx) + + for c, pair in enumerate(part_orders): + tmp = map_cube[..., 
c].copy() + link( + tmp, + pair[0], + pair[1], (2.0, 2.0, 2.0), + use_line_width, + scale=fx, + x_offset=roi_box[2], + y_offset=roi_box[0]) + map_cube[..., c] = tmp + + tmp = map_cube[..., -1].copy() + link( + tmp, + pair[0], + pair[1], (2.0, 2.0, 2.0), + use_line_width, + scale=fx, + x_offset=roi_box[2], + y_offset=roi_box[0]) + map_cube[..., -1] = tmp + + map_cube = cv2.resize(map_cube, (roi_box_width, roi_box_height)) + + if stack_mode == 'depth': + return map_cube, roi_box + elif stack_mode == 'column': + joint_maps = [] + for c in range(len(part_orders) + 1): + joint_maps.append(map_cube[..., c]) + joint_map = np.column_stack(joint_maps) + + return joint_map, roi_box + + +def plot_one_box(x, img, color=None, label=None, line_thickness=None): + tl = line_thickness or round( + 0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness + color = color or [random.randint(0, 255) for _ in range(3)] + c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) + cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) + if label: + tf = max(tl - 1, 1) # font thickness + t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] + c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 + cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + img, + label, (c1[0], c1[1] - 2), + 0, + tl / 3, [225, 255, 255], + thickness=tf, + lineType=cv2.LINE_AA) + + +def draw_line(im, points, color, stroke_size=2, closed=False): + points = points.astype(np.int32) + for i in range(len(points) - 1): + cv2.line(im, tuple(points[i]), tuple(points[i + 1]), color, + stroke_size) + if closed: + cv2.line(im, tuple(points[0]), tuple(points[-1]), color, stroke_size) + + +def enlarged_bbox(bbox, img_width, img_height, enlarge_ratio=0.2): + left = bbox[0] + top = bbox[1] + + right = bbox[2] + bottom = bbox[3] + + roi_width = right - left + roi_height = bottom - top + + new_left = left - int(roi_width * enlarge_ratio) + new_left = 0 if new_left < 0 else new_left + + new_top = top - int(roi_height * enlarge_ratio) + new_top = 0 if new_top < 0 else new_top + + new_right = right + int(roi_width * enlarge_ratio) + new_right = img_width if new_right > img_width else new_right + + new_bottom = bottom + int(roi_height * enlarge_ratio) + new_bottom = img_height if new_bottom > img_height else new_bottom + + bbox = [new_left, new_top, new_right, new_bottom] + + bbox = [int(x) for x in bbox] + + return bbox + + +def get_map_fusion_map_cuda(map_list, threshold=1, device=torch.device('cpu')): + map_list_cuda = [torch.from_numpy(x).to(device) for x in map_list] + map_concat = torch.stack(tuple(map_list_cuda), dim=-1) + + map_concat = torch.abs(map_concat) + + map_concat[map_concat < threshold] = 0 + map_concat[map_concat > 1e-5] = 1.0 + + sum_map = torch.sum(map_concat, dim=2) + a = torch.ones_like(sum_map) + acc_map = torch.where(sum_map > 0, a * 2.0, torch.zeros_like(sum_map)) + + fusion_map = torch.where(sum_map < 0.5, a * 1.5, sum_map) + + fusion_map = fusion_map.float() + acc_map = acc_map.float() + + fusion_map = fusion_map.cpu().numpy().astype(np.float32) + acc_map = acc_map.cpu().numpy().astype(np.float32) + + return fusion_map, acc_map + + +def gen_border_shade(height, width, height_band, width_band): + height_ratio = height_band * 1.0 / height + width_ratio = width_band * 1.0 / width + + _height_band = int(256 * height_ratio) + _width_band = int(256 * width_ratio) + + canvas = np.zeros((256, 256), dtype=np.float32) + + canvas[_height_band // 2:-_height_band // 2, + _width_band // 
2:-_width_band // 2] = 1.0 + + canvas = cv2.blur(canvas, (_height_band, _width_band)) + + canvas = cv2.resize(canvas, (width, height)) + + return canvas + + +def get_mask_bbox(mask, threshold=127): + ret, mask = cv2.threshold(mask, threshold, 1, 0) + + if cv2.countNonZero(mask) == 0: + return [None, None, None, None] + + col_acc = np.sum(mask, 0) + row_acc = np.sum(mask, 1) + + col_acc = col_acc.tolist() + row_acc = row_acc.tolist() + + for x in range(len(col_acc)): + if col_acc[x] > 0: + left = x + break + + for x in range(1, len(col_acc)): + if col_acc[-x] > 0: + right = len(col_acc) - x + break + + for x in range(len(row_acc)): + if row_acc[x] > 0: + top = x + break + + for x in range(1, len(row_acc)): + if row_acc[-x] > 0: + bottom = len(row_acc[::-1]) - x + break + return [top, bottom, left, right] + + +def visualize_flow(flow): + h, w = flow.shape[:2] + hsv = np.zeros((h, w, 3), np.uint8) + mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + + hsv[..., 0] = ang * 180 / np.pi / 2 + hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) + hsv[..., 2] = 255 + bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + bgr = bgr * 1.0 / 255 + return bgr.astype(np.float32) + + +def vis_joints(image, joints, color, show_text=True, confidence_threshold=0.1): + + part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3), + (3, 4), (11, 12), (12, 13), (8, 9), (9, 10)] + + abandon_idxs = [0, 1, 14, 15, 16, 17] + # draw joints + for i, joint in enumerate(joints): + if i in abandon_idxs: + continue + if joint[-1] > confidence_threshold: + + cv2.circle(image, (int(joint[0]), int(joint[1])), 1, color, 2) + if show_text: + cv2.putText(image, + str(i) + '[{:.2f}]'.format(joint[-1]), + (int(joint[0]), int(joint[1])), + cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) + # draw link + for pair in part_orders: + if joints[pair[0]][-1] > confidence_threshold and joints[ + pair[1]][-1] > confidence_threshold: + cv2.line(image, (int(joints[pair[0]][0]), int(joints[pair[0]][1])), + (int(joints[pair[1]][0]), int(joints[pair[1]][1])), color, + 2) + return image + + +def get_heatmap_cv(img, magn, max_flow_mag): + min_flow_mag = .5 + cv_magn = np.clip( + 255 * (magn - min_flow_mag) / (max_flow_mag - min_flow_mag + 1e-7), + a_min=0, + a_max=255).astype(np.uint8) + if img.dtype != np.uint8: + img = (255 * img).astype(np.uint8) + + heatmap_img = cv2.applyColorMap(cv_magn, cv2.COLORMAP_JET) + heatmap_img = heatmap_img[..., ::-1] + + h, w = magn.shape + img_alpha = np.ones((h, w), dtype=np.double)[:, :, None] + heatmap_alpha = np.clip( + magn / (max_flow_mag + 1e-7), a_min=1e-7, a_max=1)[:, :, None]**.7 + heatmap_alpha[heatmap_alpha < .2]**.5 + pm_hm = heatmap_img * heatmap_alpha + pm_img = img * img_alpha + cv_out = pm_hm + pm_img * (1 - heatmap_alpha) + cv_out = np.clip(cv_out, a_min=0, a_max=255).astype(np.uint8) + + return cv_out + + +def save_heatmap_cv(img, flow, supression=2): + + flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2) + flow_magn -= supression + flow_magn[flow_magn <= 0] = 0 + cv_out = get_heatmap_cv(img, flow_magn, np.max(flow_magn) * 1.3) + return cv_out + + +@numba.jit(nopython=True, parallel=False) +def bilinear_interp(x, y, v11, v12, v21, v22): + temp1 = (v11 * (1 - y) + v12 * y) * (1 - x) + temp2 = (v21 * (1 - y) + v22 * y) * x + result = temp1 + temp2 + return result + + +@numba.jit(nopython=True, parallel=False) +def image_warp_grid1(rDx, rDy, oriImg, transRatio, width_expand, + height_expand): + srcW = oriImg.shape[1] + srcH = oriImg.shape[0] + + newImg = oriImg.copy() + + 
for i in range(srcH): + for j in range(srcW): + _i = i + _j = j + + deltaX = rDx[_i, _j] + deltaY = rDy[_i, _j] + + nx = _j + deltaX * transRatio + ny = _i + deltaY * transRatio + + if nx >= srcW - width_expand - 1: + if nx > srcW - 1: + nx = srcW - 1 + + if ny >= srcH - height_expand - 1: + if ny > srcH - 1: + ny = srcH - 1 + + if nx < width_expand: + if nx < 0: + nx = 0 + + if ny < height_expand: + if ny < 0: + ny = 0 + + nxi = int(math.floor(nx)) + nyi = int(math.floor(ny)) + nxi1 = int(math.ceil(nx)) + nyi1 = int(math.ceil(ny)) + + for ll in range(3): + newImg[_i, _j, + ll] = bilinear_interp(ny - nyi, nx - nxi, + oriImg[nyi, nxi, + ll], oriImg[nyi, nxi1, ll], + oriImg[nyi1, nxi, + ll], oriImg[nyi1, nxi1, + ll]) + return newImg diff --git a/modelscope/models/cv/image_color_enhance/csrnet.py b/modelscope/models/cv/image_color_enhance/csrnet.py index 782cd528..502abf88 100644 --- a/modelscope/models/cv/image_color_enhance/csrnet.py +++ b/modelscope/models/cv/image_color_enhance/csrnet.py @@ -1,3 +1,6 @@ +# The implementation is adopted from Jingwen He, +# made publicly available at https://github.com/hejingwenhejingwen/CSRNet + import functools import math diff --git a/modelscope/models/cv/image_color_enhance/image_color_enhance.py b/modelscope/models/cv/image_color_enhance/image_color_enhance.py index 382cc152..0bd74197 100644 --- a/modelscope/models/cv/image_color_enhance/image_color_enhance.py +++ b/modelscope/models/cv/image_color_enhance/image_color_enhance.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp from copy import deepcopy from typing import Dict, Union diff --git a/modelscope/models/cv/image_colorization/unet.py b/modelscope/models/cv/image_colorization/unet.py index 8123651e..19f6ab62 100644 --- a/modelscope/models/cv/image_colorization/unet.py +++ b/modelscope/models/cv/image_colorization/unet.py @@ -1,3 +1,5 @@ +# The implementation here is modified based on DeOldify, originally MIT License +# and publicly available at https://github.com/jantic/DeOldify/blob/master/deoldify/unet.py import numpy as np import torch import torch.nn as nn diff --git a/modelscope/models/cv/image_colorization/utils.py b/modelscope/models/cv/image_colorization/utils.py index 03473f90..b8968aa0 100644 --- a/modelscope/models/cv/image_colorization/utils.py +++ b/modelscope/models/cv/image_colorization/utils.py @@ -1,3 +1,5 @@ +# The implementation here is modified based on DeOldify, originally MIT License and +# publicly available at https://github.com/jantic/DeOldify/blob/master/fastai/callbacks/hooks.py import functools from enum import Enum diff --git a/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py b/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py index 5b4e8ce1..c4de0729 100644 --- a/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py +++ b/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------ +# Modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/NAFNet_arch.py +# Copyright (c) 2022 megvii-model. All Rights Reserved. 
+# ------------------------------------------------------------------------ + import numpy as np import torch import torch.nn as nn diff --git a/modelscope/models/cv/image_denoise/nafnet/arch_util.py b/modelscope/models/cv/image_denoise/nafnet/arch_util.py index df394dd5..2d406141 100644 --- a/modelscope/models/cv/image_denoise/nafnet/arch_util.py +++ b/modelscope/models/cv/image_denoise/nafnet/arch_util.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ + import torch import torch.nn as nn diff --git a/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py b/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py index c484b37b..4e8fc0ed 100644 --- a/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py +++ b/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py @@ -1,8 +1,8 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os from copy import deepcopy from typing import Any, Dict, Union -import numpy as np import torch.cuda from torch.nn.parallel import DataParallel, DistributedDataParallel @@ -77,13 +77,8 @@ class NAFNetForImageDenoise(TorchModel): def _evaluate_postprocess(self, input: Tensor, target: Tensor) -> Dict[str, list]: preds = self.model(input) - preds = list(torch.split(preds, 1, 0)) - targets = list(torch.split(target, 1, 0)) - - preds = [(pred.data * 255.).squeeze(0).permute( - 1, 2, 0).cpu().numpy().astype(np.uint8) for pred in preds] - targets = [(target.data * 255.).squeeze(0).permute( - 1, 2, 0).cpu().numpy().astype(np.uint8) for target in targets] + preds = list(torch.split(preds.clamp(0, 1), 1, 0)) + targets = list(torch.split(target.clamp(0, 1), 1, 0)) return {'pred': preds, 'target': targets} diff --git a/modelscope/models/nlp/backbones/__init__.py b/modelscope/models/cv/image_inpainting/__init__.py similarity index 83% rename from modelscope/models/nlp/backbones/__init__.py rename to modelscope/models/cv/image_inpainting/__init__.py index 749cf995..e7c63cd4 100644 --- a/modelscope/models/nlp/backbones/__init__.py +++ b/modelscope/models/cv/image_inpainting/__init__.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .structbert import SbertModel + from .model import FFTInpainting + else: _import_structure = { - 'structbert': ['SbertModel'], + 'model': ['FFTInpainting'], } import sys diff --git a/modelscope/models/cv/image_inpainting/base.py b/modelscope/models/cv/image_inpainting/base.py new file mode 100644 index 00000000..04e73630 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/base.py @@ -0,0 +1,75 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import Dict, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.utils.logger import get_logger +from .modules.adversarial import NonSaturatingWithR1 +from .modules.ffc import FFCResNetGenerator +from .modules.perceptual import ResNetPL +from .modules.pix2pixhd import NLayerDiscriminator + +LOGGER = get_logger() + + +class BaseInpaintingTrainingModule(nn.Module): + + def __init__(self, + model_dir='', + use_ddp=True, + predict_only=False, + visualize_each_iters=100, + average_generator=False, + generator_avg_beta=0.999, + 
average_generator_start_step=30000, + average_generator_period=10, + store_discr_outputs_for_vis=False, + **kwargs): + super().__init__() + LOGGER.info( + f'BaseInpaintingTrainingModule init called, predict_only is {predict_only}' + ) + + self.generator = FFCResNetGenerator() + self.use_ddp = use_ddp + + if not predict_only: + self.discriminator = NLayerDiscriminator() + self.adversarial_loss = NonSaturatingWithR1( + weight=10, + gp_coef=0.001, + mask_as_fake_target=True, + allow_scale_mask=True) + + self.average_generator = average_generator + self.generator_avg_beta = generator_avg_beta + self.average_generator_start_step = average_generator_start_step + self.average_generator_period = average_generator_period + self.generator_average = None + self.last_generator_averaging_step = -1 + self.store_discr_outputs_for_vis = store_discr_outputs_for_vis + + self.loss_l1 = nn.L1Loss(reduction='none') + + self.loss_resnet_pl = ResNetPL(weight=30, weights_path=model_dir) + + self.visualize_each_iters = visualize_each_iters + LOGGER.info('BaseInpaintingTrainingModule init done') + + def forward(self, batch: Dict[str, + torch.Tensor]) -> Dict[str, torch.Tensor]: + """Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys""" + raise NotImplementedError() + + def generator_loss(self, + batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError() + + def discriminator_loss( + self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError() diff --git a/modelscope/models/cv/image_inpainting/default.py b/modelscope/models/cv/image_inpainting/default.py new file mode 100644 index 00000000..5f57d63f --- /dev/null +++ b/modelscope/models/cv/image_inpainting/default.py @@ -0,0 +1,210 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import bisect + +import torch +import torch.nn.functional as F + +from modelscope.utils.logger import get_logger +from .base import BaseInpaintingTrainingModule +from .modules.feature_matching import feature_matching_loss, masked_l1_loss + +LOGGER = get_logger() + + +def set_requires_grad(module, value): + for param in module.parameters(): + param.requires_grad = value + + +def add_prefix_to_keys(dct, prefix): + return {prefix + k: v for k, v in dct.items()} + + +class LinearRamp: + + def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): + self.start_value = start_value + self.end_value = end_value + self.start_iter = start_iter + self.end_iter = end_iter + + def __call__(self, i): + if i < self.start_iter: + return self.start_value + if i >= self.end_iter: + return self.end_value + part = (i - self.start_iter) / (self.end_iter - self.start_iter) + return self.start_value * (1 - part) + self.end_value * part + + +class LadderRamp: + + def __init__(self, start_iters, values): + self.start_iters = start_iters + self.values = values + assert len(values) == len(start_iters) + 1, (len(values), + len(start_iters)) + + def __call__(self, i): + segment_i = bisect.bisect_right(self.start_iters, i) + return self.values[segment_i] + + +def get_ramp(kind='ladder', **kwargs): + if kind == 'linear': + return LinearRamp(**kwargs) + if kind == 'ladder': + return LadderRamp(**kwargs) + raise ValueError(f'Unexpected ramp kind: {kind}') + + +class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule): + + def __init__(self, + model_dir='', + predict_only=False, + concat_mask=True, + 
rescale_scheduler_kwargs=None, + image_to_discriminator='predicted_image', + add_noise_kwargs=None, + noise_fill_hole=False, + const_area_crop_kwargs=None, + distance_weighter_kwargs=None, + distance_weighted_mask_for_discr=False, + fake_fakes_proba=0, + fake_fakes_generator_kwargs=None, + **kwargs): + super().__init__(model_dir=model_dir, predict_only=predict_only) + self.concat_mask = concat_mask + self.rescale_size_getter = get_ramp( + **rescale_scheduler_kwargs + ) if rescale_scheduler_kwargs is not None else None + self.image_to_discriminator = image_to_discriminator + self.add_noise_kwargs = add_noise_kwargs + self.noise_fill_hole = noise_fill_hole + self.const_area_crop_kwargs = const_area_crop_kwargs + self.refine_mask_for_losses = None + self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr + + self.feature_matching_weight = 100 + self.losses_l1_weight_known = 10 + self.losses_l1_weight_missing = 0 + self.fake_fakes_proba = fake_fakes_proba + + def forward(self, batch): + img = batch['image'] + mask = batch['mask'] + + masked_img = img * (1 - mask) + + if self.concat_mask: + masked_img = torch.cat([masked_img, mask], dim=1) + + batch['predicted_image'] = self.generator(masked_img) + batch['inpainted'] = mask * batch['predicted_image'] + ( + 1 - mask) * batch['image'] + + batch['mask_for_losses'] = mask + + return batch + + def generator_loss(self, batch): + img = batch['image'] + predicted_img = batch[self.image_to_discriminator] + original_mask = batch['mask'] + supervised_mask = batch['mask_for_losses'] + + # L1 + l1_value = masked_l1_loss(predicted_img, img, supervised_mask, + self.losses_l1_weight_known, + self.losses_l1_weight_missing) + + total_loss = l1_value + metrics = dict(gen_l1=l1_value) + + # discriminator + # adversarial_loss calls backward by itself + mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask + self.adversarial_loss.pre_generator_step( + real_batch=img, + fake_batch=predicted_img, + generator=self.generator, + discriminator=self.discriminator) + discr_real_pred, discr_real_features = self.discriminator(img) + discr_fake_pred, discr_fake_features = self.discriminator( + predicted_img) + adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss( + real_batch=img, + fake_batch=predicted_img, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_pred, + mask=mask_for_discr) + total_loss = total_loss + adv_gen_loss + metrics['gen_adv'] = adv_gen_loss + metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) + + # feature matching + if self.feature_matching_weight > 0: + need_mask_in_fm = False + mask_for_fm = supervised_mask if need_mask_in_fm else None + fm_value = feature_matching_loss( + discr_fake_features, discr_real_features, + mask=mask_for_fm) * self.feature_matching_weight + total_loss = total_loss + fm_value + metrics['gen_fm'] = fm_value + + if self.loss_resnet_pl is not None: + resnet_pl_value = self.loss_resnet_pl(predicted_img, img) + total_loss = total_loss + resnet_pl_value + metrics['gen_resnet_pl'] = resnet_pl_value + + return total_loss, metrics + + def discriminator_loss(self, batch): + total_loss = 0 + metrics = {} + + predicted_img = batch[self.image_to_discriminator].detach() + self.adversarial_loss.pre_discriminator_step( + real_batch=batch['image'], + fake_batch=predicted_img, + generator=self.generator, + discriminator=self.discriminator) + discr_real_pred, discr_real_features = self.discriminator( + batch['image']) + discr_fake_pred, discr_fake_features = 
self.discriminator( + predicted_img) + adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss( + real_batch=batch['image'], + fake_batch=predicted_img, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_pred, + mask=batch['mask']) + + total_loss = (total_loss + adv_discr_loss) * 0.1 + metrics['discr_adv'] = adv_discr_loss + metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) + + return total_loss, metrics + + def _do_step(self, batch, optimizer_idx=None): + if optimizer_idx == 0: # step for generator + set_requires_grad(self.generator, True) + set_requires_grad(self.discriminator, False) + elif optimizer_idx == 1: # step for discriminator + set_requires_grad(self.generator, False) + set_requires_grad(self.discriminator, True) + + batch = self(batch) + total_loss = 0 + if optimizer_idx is None or optimizer_idx == 0: # step for generator + total_loss, metrics = self.generator_loss(batch) + + elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator + total_loss, metrics = self.discriminator_loss(batch) + + result = dict(loss=total_loss) + return result diff --git a/modelscope/models/cv/image_inpainting/model.py b/modelscope/models/cv/image_inpainting/model.py new file mode 100644 index 00000000..b12f6edd --- /dev/null +++ b/modelscope/models/cv/image_inpainting/model.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +LOGGER = get_logger() + + +@MODELS.register_module( + Tasks.image_inpainting, module_name=Models.image_inpainting) +class FFTInpainting(TorchModel): + + def __init__(self, model_dir: str, **kwargs): + super().__init__(model_dir, **kwargs) + + from .default import DefaultInpaintingTrainingModule + pretrained = kwargs.get('pretrained', True) + predict_only = kwargs.get('predict_only', False) + net = DefaultInpaintingTrainingModule( + model_dir=model_dir, predict_only=predict_only) + if pretrained: + path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + LOGGER.info(f'loading pretrained model from {path}') + state = torch.load(path, map_location='cpu') + net.load_state_dict(state, strict=False) + self.model = net + + def forward(self, inputs): + return self.model(inputs) diff --git a/modelscope/preprocessors/star3/fields/__init__.py b/modelscope/models/cv/image_inpainting/modules/__init__.py similarity index 100% rename from modelscope/preprocessors/star3/fields/__init__.py rename to modelscope/models/cv/image_inpainting/modules/__init__.py diff --git a/modelscope/models/cv/image_inpainting/modules/ade20k/__init__.py b/modelscope/models/cv/image_inpainting/modules/ade20k/__init__.py new file mode 100644 index 00000000..89c3e293 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ade20k/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
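+# Expose ModelBuilder, which assembles the ADE20K (150-class) segmentation
+# encoder/decoder networks defined in base.py.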
+from .base import ModelBuilder diff --git a/modelscope/models/cv/image_inpainting/modules/ade20k/base.py b/modelscope/models/cv/image_inpainting/modules/ade20k/base.py new file mode 100644 index 00000000..02bd3cc4 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ade20k/base.py @@ -0,0 +1,380 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules import BatchNorm2d + +from . import resnet + +NUM_CLASS = 150 + + +# Model Builder +class ModelBuilder: + # custom weights initialization + @staticmethod + def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) + m.bias.data.fill_(1e-4) + + @staticmethod + def build_encoder(arch='resnet50dilated', + fc_dim=512, + weights='', + model_dir=''): + pretrained = True if len(weights) == 0 else False + arch = arch.lower() + if arch == 'resnet50dilated': + orig_resnet = resnet.__dict__['resnet50']( + pretrained=pretrained, model_dir=model_dir) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet50': + orig_resnet = resnet.__dict__['resnet50']( + pretrained=pretrained, model_dir=model_dir) + net_encoder = Resnet(orig_resnet) + else: + raise Exception('Architecture undefined!') + + # encoders are usually pretrained + # net_encoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + print('Loading weights for net_encoder') + net_encoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), + strict=False) + return net_encoder + + @staticmethod + def build_decoder(arch='ppm_deepsup', + fc_dim=512, + num_class=NUM_CLASS, + weights='', + use_softmax=False, + drop_last_conv=False): + arch = arch.lower() + if arch == 'ppm_deepsup': + net_decoder = PPMDeepsup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + drop_last_conv=drop_last_conv) + elif arch == 'c1_deepsup': + net_decoder = C1DeepSup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + drop_last_conv=drop_last_conv) + else: + raise Exception('Architecture undefined!') + + net_decoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + print('Loading weights for net_decoder') + net_decoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), + strict=False) + return net_decoder + + @staticmethod + def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, + drop_last_conv, *arts, **kwargs): + path = os.path.join( + weights_path, 'ade20k', + f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth') + return ModelBuilder.build_decoder( + arch=arch_decoder, + fc_dim=fc_dim, + weights=path, + use_softmax=True, + drop_last_conv=drop_last_conv) + + @staticmethod + def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, + segmentation, *arts, **kwargs): + if segmentation: + path = os.path.join( + weights_path, 'ade20k', + f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth') + else: + path = '' + return ModelBuilder.build_encoder( + arch=arch_encoder, + fc_dim=fc_dim, + weights=path, + model_dir=weights_path) + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False), + 
BatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + + +# pyramid pooling, deep supervision +class PPMDeepsup(nn.Module): + + def __init__(self, + num_class=NUM_CLASS, + fc_dim=4096, + use_softmax=False, + pool_scales=(1, 2, 3, 6), + drop_last_conv=False): + super().__init__() + self.use_softmax = use_softmax + self.drop_last_conv = drop_last_conv + + self.ppm = [] + for scale in pool_scales: + self.ppm.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), nn.ReLU(inplace=True))) + self.ppm = nn.ModuleList(self.ppm) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + self.conv_last = nn.Sequential( + nn.Conv2d( + fc_dim + len(pool_scales) * 512, + 512, + kernel_size=3, + padding=1, + bias=False), BatchNorm2d(512), nn.ReLU(inplace=True), + nn.Dropout2d(0.1), nn.Conv2d(512, num_class, kernel_size=1)) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.dropout_deepsup = nn.Dropout2d(0.1) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append( + nn.functional.interpolate( + pool_scale(conv5), (input_size[2], input_size[3]), + mode='bilinear', + align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + if self.drop_last_conv: + return ppm_out + else: + x = self.conv_last(ppm_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.dropout_deepsup(_) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +class Resnet(nn.Module): + + def __init__(self, orig_resnet): + super(Resnet, self).__init__() + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + conv_out.append(x) + x = self.layer2(x) + conv_out.append(x) + x = self.layer3(x) + conv_out.append(x) + x = self.layer4(x) + conv_out.append(x) + + if return_feature_maps: + return conv_out + return [x] + + +# Resnet Dilated +class ResnetDilated(nn.Module): + + def __init__(self, orig_resnet, dilate_scale=8): + super().__init__() + from functools import partial + + if dilate_scale == 8: + orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) + elif dilate_scale == 16: + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + 
self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate // 2, dilate // 2) + m.padding = (dilate // 2, dilate // 2) + # other convoluions + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + conv_out.append(x) + x = self.layer2(x) + conv_out.append(x) + x = self.layer3(x) + conv_out.append(x) + x = self.layer4(x) + conv_out.append(x) + + if return_feature_maps: + return conv_out + return [x] + + +# last conv, deep supervision +class C1DeepSup(nn.Module): + + def __init__(self, + num_class=150, + fc_dim=2048, + use_softmax=False, + drop_last_conv=False): + super(C1DeepSup, self).__init__() + self.use_softmax = use_softmax + self.drop_last_conv = drop_last_conv + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + x = self.cbr(conv5) + + if self.drop_last_conv: + return x + else: + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +# last conv +class C1(nn.Module): + + def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): + super(C1, self).__init__() + self.use_softmax = use_softmax + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + x = self.cbr(conv5) + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + + return x diff --git a/modelscope/models/cv/image_inpainting/modules/ade20k/resnet.py b/modelscope/models/cv/image_inpainting/modules/ade20k/resnet.py new file mode 100644 index 00000000..7da9ff07 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ade20k/resnet.py @@ -0,0 +1,183 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import math +import os + +import torch +import torch.nn as nn +from torch.nn import BatchNorm2d + +__all__ = ['ResNet', 'resnet50'] + + +def conv3x3(in_planes, out_planes, stride=1): + '3x3 
convolution with padding' + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 128 + super(ResNet, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def resnet50(pretrained=False, model_dir='', **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + cached_file = os.path.join(model_dir, 'resnet50-imagenet.pth') + model.load_state_dict( + torch.load(cached_file, map_location='cpu'), strict=False) + return model diff --git a/modelscope/models/cv/image_inpainting/modules/adversarial.py b/modelscope/models/cv/image_inpainting/modules/adversarial.py new file mode 100644 index 00000000..b183876b --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/adversarial.py @@ -0,0 +1,167 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BaseAdversarialLoss: + + def pre_generator_step(self, real_batch: torch.Tensor, + fake_batch: torch.Tensor, generator: nn.Module, + discriminator: nn.Module): + """ + Prepare for generator step + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param generator: + :param discriminator: + :return: None + """ + + def pre_discriminator_step(self, real_batch: torch.Tensor, + fake_batch: torch.Tensor, generator: nn.Module, + discriminator: nn.Module): + """ + Prepare for discriminator step + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param generator: + :param discriminator: + :return: None + """ + + def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Calculate generator loss + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param discr_real_pred: Tensor, discriminator output for real_batch + :param discr_fake_pred: Tensor, discriminator output for fake_batch + :param mask: Tensor, actual mask, which was at input of generator when making fake_batch + :return: total generator loss along with some values that might be interesting to log + """ + raise NotImplementedError + + def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: 
torch.Tensor, discr_fake_pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Calculate discriminator loss and call .backward() on it + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param discr_real_pred: Tensor, discriminator output for real_batch + :param discr_fake_pred: Tensor, discriminator output for fake_batch + :param mask: Tensor, actual mask, which was at input of generator when making fake_batch + :return: total discriminator loss along with some values that might be interesting to log + """ + raise NotImplementedError + + def interpolate_mask(self, mask, shape): + assert mask is not None + assert self.allow_scale_mask or shape == mask.shape[-2:] + if shape != mask.shape[-2:] and self.allow_scale_mask: + if self.mask_scale_mode == 'maxpool': + mask = F.adaptive_max_pool2d(mask, shape) + else: + mask = F.interpolate( + mask, size=shape, mode=self.mask_scale_mode) + return mask + + +def make_r1_gp(discr_real_pred, real_batch): + if torch.is_grad_enabled(): + grad_real = torch.autograd.grad( + outputs=discr_real_pred.sum(), + inputs=real_batch, + create_graph=True)[0] + grad_penalty = (grad_real.view(grad_real.shape[0], + -1).norm(2, dim=1)**2).mean() + else: + grad_penalty = 0 + real_batch.requires_grad = False + + return grad_penalty + + +class NonSaturatingWithR1(BaseAdversarialLoss): + + def __init__(self, + gp_coef=5, + weight=1, + mask_as_fake_target=False, + allow_scale_mask=False, + mask_scale_mode='nearest', + extra_mask_weight_for_gen=0, + use_unmasked_for_gen=True, + use_unmasked_for_discr=True): + self.gp_coef = gp_coef + self.weight = weight + # use for discr => use for gen; + # otherwise we teach only the discr to pay attention to very small difference + assert use_unmasked_for_gen or (not use_unmasked_for_discr) + # mask as target => use unmasked for discr: + # if we don't care about unmasked regions at all + # then it doesn't matter if the value of mask_as_fake_target is true or false + assert use_unmasked_for_discr or (not mask_as_fake_target) + self.use_unmasked_for_gen = use_unmasked_for_gen + self.use_unmasked_for_discr = use_unmasked_for_discr + self.mask_as_fake_target = mask_as_fake_target + self.allow_scale_mask = allow_scale_mask + self.mask_scale_mode = mask_scale_mode + self.extra_mask_weight_for_gen = extra_mask_weight_for_gen + + def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask=None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + fake_loss = F.softplus(-discr_fake_pred) + if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \ + not self.use_unmasked_for_gen: # == if masked region should be treated differently + mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) + if not self.use_unmasked_for_gen: + fake_loss = fake_loss * mask + else: + pixel_weights = 1 + mask * self.extra_mask_weight_for_gen + fake_loss = fake_loss * pixel_weights + + return fake_loss.mean() * self.weight, dict() + + def pre_discriminator_step(self, real_batch: torch.Tensor, + fake_batch: torch.Tensor, generator: nn.Module, + discriminator: nn.Module): + real_batch.requires_grad = True + + def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask=None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + + 
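+        # Non-saturating discriminator objective: softplus(-D(real)) penalizes
+        # low scores on real samples, softplus(D(fake)) penalizes high scores on
+        # generated ones, and an R1 gradient penalty on the real batch
+        # (make_r1_gp, weighted by gp_coef) regularizes the discriminator.
+        # The mask handling below can restrict the fake term to masked regions.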
real_loss = F.softplus(-discr_real_pred) + grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef + fake_loss = F.softplus(discr_fake_pred) + + if not self.use_unmasked_for_discr or self.mask_as_fake_target: + # == if masked region should be treated differently + mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) + # use_unmasked_for_discr=False only makes sense for fakes; + # for reals there is no difference beetween two regions + fake_loss = fake_loss * mask + if self.mask_as_fake_target: + fake_loss = fake_loss + (1 + - mask) * F.softplus(-discr_fake_pred) + + sum_discr_loss = real_loss + grad_penalty + fake_loss + metrics = dict( + discr_real_out=discr_real_pred.mean(), + discr_fake_out=discr_fake_pred.mean(), + discr_real_gp=grad_penalty) + return sum_discr_loss.mean(), metrics diff --git a/modelscope/models/cv/image_inpainting/modules/feature_matching.py b/modelscope/models/cv/image_inpainting/modules/feature_matching.py new file mode 100644 index 00000000..c2effb20 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/feature_matching.py @@ -0,0 +1,45 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import List + +import torch +import torch.nn.functional as F + + +def masked_l2_loss(pred, target, mask, weight_known, weight_missing): + per_pixel_l2 = F.mse_loss(pred, target, reduction='none') + pixel_weights = mask * weight_missing + (1 - mask) * weight_known + return (pixel_weights * per_pixel_l2).mean() + + +def masked_l1_loss(pred, target, mask, weight_known, weight_missing): + per_pixel_l1 = F.l1_loss(pred, target, reduction='none') + pixel_weights = mask * weight_missing + (1 - mask) * weight_known + return (pixel_weights * per_pixel_l1).mean() + + +def feature_matching_loss(fake_features: List[torch.Tensor], + target_features: List[torch.Tensor], + mask=None): + if mask is None: + res = torch.stack([ + F.mse_loss(fake_feat, target_feat) + for fake_feat, target_feat in zip(fake_features, target_features) + ]).mean() + else: + res = 0 + norm = 0 + for fake_feat, target_feat in zip(fake_features, target_features): + cur_mask = F.interpolate( + mask, + size=fake_feat.shape[-2:], + mode='bilinear', + align_corners=False) + error_weights = 1 - cur_mask + cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() + res = res + cur_val + norm += 1 + res = res / norm + return res diff --git a/modelscope/models/cv/image_inpainting/modules/ffc.py b/modelscope/models/cv/image_inpainting/modules/ffc.py new file mode 100644 index 00000000..c74425e3 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ffc.py @@ -0,0 +1,588 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from kornia.geometry.transform import rotate + + +def get_activation(kind='tanh'): + if kind == 'tanh': + return nn.Tanh() + if kind == 'sigmoid': + return nn.Sigmoid() + if kind is False: + return nn.Identity() + raise ValueError(f'Unknown activation kind {kind}') + + +class SELayer(nn.Module): + + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid()) + + 
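+    # Squeeze-and-excitation: global average pooling yields a per-channel
+    # descriptor, the bottleneck MLP with a sigmoid turns it into channel
+    # gates, and the input feature map is rescaled channel-wise by those gates.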
def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + res = x * y.expand_as(x) + return res + + +class FourierUnit(nn.Module): + + def __init__(self, + in_channels, + out_channels, + groups=1, + spatial_scale_factor=None, + spatial_scale_mode='bilinear', + spectral_pos_encoding=False, + use_se=False, + se_kwargs=None, + ffc3d=False, + fft_norm='ortho'): + # bn_layer not used + super(FourierUnit, self).__init__() + self.groups = groups + + self.conv_layer = torch.nn.Conv2d( + in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, + stride=1, + padding=0, + groups=self.groups, + bias=False) + self.bn = torch.nn.BatchNorm2d(out_channels * 2) + self.relu = torch.nn.ReLU(inplace=True) + + # squeeze and excitation block + self.use_se = use_se + if use_se: + if se_kwargs is None: + se_kwargs = {} + self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) + + self.spatial_scale_factor = spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate( + x, + scale_factor=self.spatial_scale_factor, + mode=self.spatial_scale_mode, + align_corners=False) + + # (batch, c, h, w/2+1, 2) + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, + 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view(( + batch, + -1, + ) + ffted.size()[3:]) + + if self.spectral_pos_encoding: + height, width = ffted.shape[-2:] + coords_vert = torch.linspace(0, 1, + height)[None, None, :, None].expand( + batch, 1, height, width).to(ffted) + coords_hor = torch.linspace(0, 1, + width)[None, None, None, :].expand( + batch, 1, height, width).to(ffted) + ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) + + if self.use_se: + ffted = self.se(ffted) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(self.bn(ffted)) + + ffted = ffted.view(( + batch, + -1, + 2, + ) + ffted.size()[2:]).permute( + 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn( + ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + if self.spatial_scale_factor is not None: + output = F.interpolate( + output, + size=orig_size, + mode=self.spatial_scale_mode, + align_corners=False) + + return output + + +class SpectralTransform(nn.Module): + + def __init__(self, + in_channels, + out_channels, + stride=1, + groups=1, + enable_lfu=True, + **fu_kwargs): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.enable_lfu = enable_lfu + if stride == 2: + self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) + else: + self.downsample = nn.Identity() + + self.stride = stride + self.conv1 = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels // 2, + kernel_size=1, + groups=groups, + bias=False), nn.BatchNorm2d(out_channels // 2), + nn.ReLU(inplace=True)) + self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, + **fu_kwargs) + if self.enable_lfu: + self.lfu = FourierUnit(out_channels // 2, out_channels // 2, + 
groups) + self.conv2 = torch.nn.Conv2d( + out_channels // 2, + out_channels, + kernel_size=1, + groups=groups, + bias=False) + + def forward(self, x): + + x = self.downsample(x) + x = self.conv1(x) + output = self.fu(x) + + if self.enable_lfu: + n, c, h, w = x.shape + split_no = 2 + split_s = h // split_no + xs = torch.cat( + torch.split(x[:, :c // 4], split_s, dim=-2), + dim=1).contiguous() + xs = torch.cat( + torch.split(xs, split_s, dim=-1), dim=1).contiguous() + xs = self.lfu(xs) + xs = xs.repeat(1, 1, split_no, split_no).contiguous() + else: + xs = 0 + + output = self.conv2(x + output + xs) + + return output + + +class LearnableSpatialTransformWrapper(nn.Module): + + def __init__(self, + impl, + pad_coef=0.5, + angle_init_range=80, + train_angle=True): + super().__init__() + self.impl = impl + self.angle = torch.rand(1) * angle_init_range + if train_angle: + self.angle = nn.Parameter(self.angle, requires_grad=True) + self.pad_coef = pad_coef + + def forward(self, x): + if torch.is_tensor(x): + return self.inverse_transform(self.impl(self.transform(x)), x) + elif isinstance(x, tuple): + x_trans = tuple(self.transform(elem) for elem in x) + y_trans = self.impl(x_trans) + return tuple( + self.inverse_transform(elem, orig_x) + for elem, orig_x in zip(y_trans, x)) + else: + raise ValueError(f'Unexpected input type {type(x)}') + + def transform(self, x): + height, width = x.shape[2:] + pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) + x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') + x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) + return x_padded_rotated + + def inverse_transform(self, y_padded_rotated, orig_x): + height, width = orig_x.shape[2:] + pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) + + y_padded = rotate( + y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) + y_height, y_width = y_padded.shape[2:] + y = y_padded[:, :, pad_h:y_height - pad_h, pad_w:y_width - pad_w] + return y + + +class FFC(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + enable_lfu=True, + padding_type='reflect', + gated=False, + **spectral_kwargs): + super(FFC, self).__init__() + + assert stride == 1 or stride == 2, 'Stride should be 1 or 2.' 
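+        # Channels are split into a local part and a global part according to
+        # ratio_gin / ratio_gout. Four paths connect them: local->local,
+        # local->global and global->local use plain Conv2d, while
+        # global->global goes through a SpectralTransform in the Fourier
+        # domain; a path degenerates to nn.Identity when its input or output
+        # channel count is zero.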
+ self.stride = stride + + in_cg = int(in_channels * ratio_gin) + in_cl = in_channels - in_cg + out_cg = int(out_channels * ratio_gout) + out_cl = out_channels - out_cg + + self.ratio_gin = ratio_gin + self.ratio_gout = ratio_gout + self.global_in_num = in_cg + + module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d + self.convl2l = module( + in_cl, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type) + module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d + self.convl2g = module( + in_cl, + out_cg, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d + self.convg2l = module( + in_cg, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform + self.convg2g = module(in_cg, out_cg, stride, + 1 if groups == 1 else groups // 2, enable_lfu, + **spectral_kwargs) + + self.gated = gated + module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + self.gate = module(in_channels, 2, 1) + + def forward(self, x): + x_l, x_g = x if type(x) is tuple else (x, 0) + out_xl, out_xg = 0, 0 + + if self.gated: + total_input_parts = [x_l] + if torch.is_tensor(x_g): + total_input_parts.append(x_g) + total_input = torch.cat(total_input_parts, dim=1) + + gates = torch.sigmoid(self.gate(total_input)) + g2l_gate, l2g_gate = gates.chunk(2, dim=1) + else: + g2l_gate, l2g_gate = 1, 1 + + if self.ratio_gout != 1: + out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate + if self.ratio_gout != 0: + out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g) + + return out_xl, out_xg + + +class FFC_BN_ACT(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + norm_layer=nn.BatchNorm2d, + activation_layer=nn.Identity, + padding_type='reflect', + enable_lfu=True, + **kwargs): + super(FFC_BN_ACT, self).__init__() + self.ffc = FFC( + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride, + padding, + dilation, + groups, + bias, + enable_lfu, + padding_type=padding_type, + **kwargs) + lnorm = nn.Identity if ratio_gout == 1 else norm_layer + gnorm = nn.Identity if ratio_gout == 0 else norm_layer + global_channels = int(out_channels * ratio_gout) + self.bn_l = lnorm(out_channels - global_channels) + self.bn_g = gnorm(global_channels) + + lact = nn.Identity if ratio_gout == 1 else activation_layer + gact = nn.Identity if ratio_gout == 0 else activation_layer + self.act_l = lact(inplace=True) + self.act_g = gact(inplace=True) + + def forward(self, x): + x_l, x_g = self.ffc(x) + x_l = self.act_l(self.bn_l(x_l)) + x_g = self.act_g(self.bn_g(x_g)) + return x_l, x_g + + +class FFCResnetBlock(nn.Module): + + def __init__(self, + dim, + padding_type, + norm_layer, + activation_layer=nn.ReLU, + dilation=1, + spatial_transform_kwargs=None, + inline=False, + **conv_kwargs): + super().__init__() + self.conv1 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + self.conv2 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + 
padding_type=padding_type, + **conv_kwargs) + if spatial_transform_kwargs is not None: + self.conv1 = LearnableSpatialTransformWrapper( + self.conv1, **spatial_transform_kwargs) + self.conv2 = LearnableSpatialTransformWrapper( + self.conv2, **spatial_transform_kwargs) + self.inline = inline + + def forward(self, x): + if self.inline: + x_l, x_g = x[:, :-self.conv1.ffc. + global_in_num], x[:, -self.conv1.ffc.global_in_num:] + else: + x_l, x_g = x if type(x) is tuple else (x, 0) + + id_l, id_g = x_l, x_g + + x_l, x_g = self.conv1((x_l, x_g)) + x_l, x_g = self.conv2((x_l, x_g)) + + x_l, x_g = id_l + x_l, id_g + x_g + out = x_l, x_g + if self.inline: + out = torch.cat(out, dim=1) + return out + + +class ConcatTupleLayer(nn.Module): + + def forward(self, x): + assert isinstance(x, tuple) + x_l, x_g = x + assert torch.is_tensor(x_l) or torch.is_tensor(x_g) + if not torch.is_tensor(x_g): + return x_l + return torch.cat(x, dim=1) + + +class FFCResNetGenerator(nn.Module): + + def __init__(self, + input_nc=4, + output_nc=3, + ngf=64, + n_downsampling=3, + n_blocks=18, + norm_layer=nn.BatchNorm2d, + padding_type='reflect', + activation_layer=nn.ReLU, + up_norm_layer=nn.BatchNorm2d, + up_activation=nn.ReLU(True), + init_conv_kwargs={ + 'ratio_gin': 0, + 'ratio_gout': 0, + 'enable_lfu': False + }, + downsample_conv_kwargs={ + 'ratio_gin': 0, + 'ratio_gout': 0, + 'enable_lfu': False + }, + resnet_conv_kwargs={ + 'ratio_gin': 0.75, + 'ratio_gout': 0.75, + 'enable_lfu': False + }, + spatial_transform_layers=None, + spatial_transform_kwargs={}, + add_out_act='sigmoid', + max_features=1024, + out_ffc=False, + out_ffc_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + model = [ + nn.ReflectionPad2d(3), + FFC_BN_ACT( + input_nc, + ngf, + kernel_size=7, + padding=0, + norm_layer=norm_layer, + activation_layer=activation_layer, + **init_conv_kwargs) + ] + + # downsample + for i in range(n_downsampling): + mult = 2**i + if i == n_downsampling - 1: + cur_conv_kwargs = dict(downsample_conv_kwargs) + cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get( + 'ratio_gin', 0) + else: + cur_conv_kwargs = downsample_conv_kwargs + model += [ + FFC_BN_ACT( + min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, + stride=2, + padding=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + **cur_conv_kwargs) + ] + + mult = 2**n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + # resnet blocks + for i in range(n_blocks): + cur_resblock = FFCResnetBlock( + feats_num_bottleneck, + padding_type=padding_type, + activation_layer=activation_layer, + norm_layer=norm_layer, + **resnet_conv_kwargs) + if spatial_transform_layers is not None and i in spatial_transform_layers: + cur_resblock = LearnableSpatialTransformWrapper( + cur_resblock, **spatial_transform_kwargs) + model += [cur_resblock] + + model += [ConcatTupleLayer()] + + # upsample + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [ + nn.ConvTranspose2d( + min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + up_norm_layer(min(max_features, int(ngf * mult / 2))), + up_activation + ] + + if out_ffc: + model += [ + FFCResnetBlock( + ngf, + padding_type=padding_type, + activation_layer=activation_layer, + norm_layer=norm_layer, + inline=True, + **out_ffc_kwargs) + ] + + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0) + ] + if add_out_act: + 
model.append( + get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) diff --git a/modelscope/models/cv/image_inpainting/modules/inception.py b/modelscope/models/cv/image_inpainting/modules/inception.py new file mode 100644 index 00000000..5070533d --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/inception.py @@ -0,0 +1,324 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models + +from modelscope.utils.logger import get_logger + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/' \ + 'fid_weights/pt_inception-2015-12-05-6726825d.pth' + +LOGGER = get_logger() + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=[DEFAULT_BLOCK_INDEX], + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. 
+ """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate( + x, size=(299, 299), mode='bilinear', align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. 
+ """ + LOGGER.info('fid_inception_v3 called') + inception = models.inception_v3( + num_classes=1008, aux_logits=False, pretrained=False) + LOGGER.info('models.inception_v3 done') + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + LOGGER.info('fid_inception_v3 patching done') + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + LOGGER.info('fid_inception_v3 weights downloaded') + + inception.load_state_dict(state_dict) + LOGGER.info('fid_inception_v3 weights loaded into model') + + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # 
its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/modelscope/models/cv/image_inpainting/modules/perceptual.py b/modelscope/models/cv/image_inpainting/modules/perceptual.py new file mode 100644 index 00000000..80fe2b96 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/perceptual.py @@ -0,0 +1,47 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from .ade20k import ModelBuilder + +IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] +IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] + + +class ResNetPL(nn.Module): + + def __init__(self, + weight=1, + weights_path=None, + arch_encoder='resnet50dilated', + segmentation=True): + super().__init__() + self.impl = ModelBuilder.get_encoder( + weights_path=weights_path, + arch_encoder=arch_encoder, + arch_decoder='ppm_deepsup', + fc_dim=2048, + segmentation=segmentation) + self.impl.eval() + for w in self.impl.parameters(): + w.requires_grad_(False) + + self.weight = weight + + def forward(self, pred, target): + pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) + target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) + + pred_feats = self.impl(pred, return_feature_maps=True) + target_feats = self.impl(target, return_feature_maps=True) + + result = torch.stack([ + F.mse_loss(cur_pred, cur_target) + for cur_pred, cur_target in zip(pred_feats, target_feats) + ]).sum() * self.weight + return result diff --git a/modelscope/models/cv/image_inpainting/modules/pix2pixhd.py b/modelscope/models/cv/image_inpainting/modules/pix2pixhd.py new file mode 100644 index 00000000..32e18f3e --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/pix2pixhd.py @@ -0,0 +1,75 @@ +""" +The implementation is adopted from +https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py +""" +import collections +import functools +import logging +from collections import defaultdict +from functools import partial + +import numpy as np +import torch.nn as nn + + +# Defines 
the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(nn.Module): + + def __init__( + self, + input_nc=3, + ndf=64, + n_layers=4, + norm_layer=nn.BatchNorm2d, + ): + super().__init__() + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw - 1.0) / 2)) + sequence = [[ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + sequence += [[ + nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw) + ]] + + for n in range(len(sequence)): + setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) + + def get_all_activations(self, x): + res = [x] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + + def forward(self, x): + act = self.get_all_activations(x) + return act[-1], act[:-1] diff --git a/modelscope/models/cv/image_inpainting/refinement.py b/modelscope/models/cv/image_inpainting/refinement.py new file mode 100644 index 00000000..662d8a05 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/refinement.py @@ -0,0 +1,393 @@ +''' +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +''' +import cv2 +import numpy as np +import torch +import torch.nn as nn +from kornia.filters import gaussian_blur2d +from kornia.geometry.transform import resize +from kornia.morphology import erosion +from torch.nn import functional as F +from torch.optim import SGD, Adam +from tqdm import tqdm + +from .modules.ffc import FFCResnetBlock + + +def move_to_device(obj, device): + if isinstance(obj, nn.Module): + return obj.to(device) + if torch.is_tensor(obj): + return obj.to(device) + if isinstance(obj, (tuple, list)): + return [move_to_device(el, device) for el in obj] + if isinstance(obj, dict): + return {name: move_to_device(val, device) for name, val in obj.items()} + raise ValueError(f'Unexpected type {type(obj)}') + + +def ceil_modulo(x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + +def pad_tensor_to_modulo(img, mod): + batch_size, channels, height, width = img.shape + out_height = ceil_modulo(height, mod) + out_width = ceil_modulo(width, mod) + return F.pad( + img, + pad=(0, out_width - width, 0, out_height - height), + mode='reflect') + + +def _pyrdown(im: torch.Tensor, downsize: tuple = None): + """downscale the image""" + if downsize is None: + downsize = (im.shape[2] // 2, im.shape[3] // 2) + assert im.shape[ + 1] == 3, 'Expected shape for the input to be (n,3,height,width)' + im = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0)) + im = F.interpolate(im, size=downsize, mode='bilinear', align_corners=False) + return im + + +def _pyrdown_mask(mask: torch.Tensor, + downsize: tuple = None, + eps: float = 1e-8, + blur_mask: bool = True, + round_up: bool = True): + """downscale the mask tensor + + Parameters + ---------- + mask : torch.Tensor + mask of size (B, 1, H, W) + downsize : tuple, optional + size to downscale to. 
If None, image is downscaled to half, by default None + eps : float, optional + threshold value for binarizing the mask, by default 1e-8 + blur_mask : bool, optional + if True, apply gaussian filter before downscaling, by default True + round_up : bool, optional + if True, values above eps are marked 1, else, values below 1-eps are marked 0, by default True + + Returns + ------- + torch.Tensor + downscaled mask + """ + + if downsize is None: + downsize = (mask.shape[2] // 2, mask.shape[3] // 2) + assert mask.shape[ + 1] == 1, 'Expected shape for the input to be (n,1,height,width)' + if blur_mask is True: + mask = gaussian_blur2d(mask, kernel_size=(5, 5), sigma=(1.0, 1.0)) + mask = F.interpolate( + mask, size=downsize, mode='bilinear', align_corners=False) + else: + mask = F.interpolate( + mask, size=downsize, mode='bilinear', align_corners=False) + if round_up: + mask[mask >= eps] = 1 + mask[mask < eps] = 0 + else: + mask[mask >= 1.0 - eps] = 1 + mask[mask < 1.0 - eps] = 0 + return mask + + +def _erode_mask(mask: torch.Tensor, + ekernel: torch.Tensor = None, + eps: float = 1e-8): + """erode the mask, and set gray pixels to 0""" + if ekernel is not None: + mask = erosion(mask, ekernel) + mask[mask >= 1.0 - eps] = 1 + mask[mask < 1.0 - eps] = 0 + return mask + + +def _l1_loss(pred: torch.Tensor, + pred_downscaled: torch.Tensor, + ref: torch.Tensor, + mask: torch.Tensor, + mask_downscaled: torch.Tensor, + image: torch.Tensor, + on_pred: bool = True): + """l1 loss on src pixels, and downscaled predictions if on_pred=True""" + loss = torch.mean(torch.abs(pred[mask < 1e-8] - image[mask < 1e-8])) + if on_pred: + loss += torch.mean( + torch.abs(pred_downscaled[mask_downscaled >= 1e-8] + - ref[mask_downscaled >= 1e-8])) + return loss + + +def _infer(image: torch.Tensor, + mask: torch.Tensor, + forward_front: nn.Module, + forward_rears: nn.Module, + ref_lower_res: torch.Tensor, + orig_shape: tuple, + devices: list, + scale_ind: int, + n_iters: int = 15, + lr: float = 0.002): + """Performs inference with refinement at a given scale. 
+ + Parameters + ---------- + image : torch.Tensor + input image to be inpainted, of size (1,3,H,W) + mask : torch.Tensor + input inpainting mask, of size (1,1,H,W) + forward_front : nn.Module + the front part of the inpainting network + forward_rears : nn.Module + the rear part of the inpainting network + ref_lower_res : torch.Tensor + the inpainting at previous scale, used as reference image + orig_shape : tuple + shape of the original input image before padding + devices : list + list of available devices + scale_ind : int + the scale index + n_iters : int, optional + number of iterations of refinement, by default 15 + lr : float, optional + learning rate, by default 0.002 + + Returns + ------- + torch.Tensor + inpainted image + """ + masked_image = image * (1 - mask) + masked_image = torch.cat([masked_image, mask], dim=1) + + mask = mask.repeat(1, 3, 1, 1) + if ref_lower_res is not None: + ref_lower_res = ref_lower_res.detach() + with torch.no_grad(): + z1, z2 = forward_front(masked_image) + # Inference + mask = mask.to(devices[-1]) + ekernel = torch.from_numpy( + cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (15, 15)).astype(bool)).float() + ekernel = ekernel.to(devices[-1]) + image = image.to(devices[-1]) + z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0]) + z1.requires_grad, z2.requires_grad = True, True + + optimizer = Adam([z1, z2], lr=lr) + + pbar = tqdm(range(n_iters), leave=False) + for idi in pbar: + optimizer.zero_grad() + input_feat = (z1, z2) + for idd, forward_rear in enumerate(forward_rears): + output_feat = forward_rear(input_feat) + if idd < len(devices) - 1: + midz1, midz2 = output_feat + midz1, midz2 = midz1.to(devices[idd + 1]), midz2.to( + devices[idd + 1]) + input_feat = (midz1, midz2) + else: + pred = output_feat + + if ref_lower_res is None: + break + losses = {} + # scaled loss with downsampler + pred_downscaled = _pyrdown(pred[:, :, :orig_shape[0], :orig_shape[1]]) + mask_downscaled = _pyrdown_mask( + mask[:, :1, :orig_shape[0], :orig_shape[1]], + blur_mask=False, + round_up=False) + mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel) + mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1) + losses['ms_l1'] = _l1_loss( + pred, + pred_downscaled, + ref_lower_res, + mask, + mask_downscaled, + image, + on_pred=True) + + loss = sum(losses.values()) + pbar.set_description( + 'Refining scale {} using scale {} ...current loss: {:.4f}'.format( + scale_ind + 1, scale_ind, loss.item())) + if idi < n_iters - 1: + loss.backward() + optimizer.step() + del pred_downscaled + del loss + del pred + # "pred" is the prediction after Plug-n-Play module + inpainted = mask * pred + (1 - mask) * image + inpainted = inpainted.detach().cpu() + return inpainted + + +def _get_image_mask_pyramid(batch: dict, min_side: int, max_scales: int, + px_budget: int): + """Build the image mask pyramid + + Parameters + ---------- + batch : dict + batch containing image, mask, etc + min_side : int + minimum side length to limit the number of scales of the pyramid + max_scales : int + maximum number of scales allowed + px_budget : int + the product H*W cannot exceed this budget, because of resource constraints + + Returns + ------- + tuple + image-mask pyramid in the form of list of images and list of masks + """ + + assert batch['image'].shape[ + 0] == 1, 'refiner works on only batches of size 1!' 
+ + h, w = batch['unpad_to_size'] + h, w = h[0].item(), w[0].item() + + image = batch['image'][..., :h, :w] + mask = batch['mask'][..., :h, :w] + if h * w > px_budget: + # resize + ratio = np.sqrt(px_budget / float(h * w)) + h_orig, w_orig = h, w + h, w = int(h * ratio), int(w * ratio) + print( + f'Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...' + ) + image = resize( + image, (h, w), interpolation='bilinear', align_corners=False) + mask = resize( + mask, (h, w), interpolation='bilinear', align_corners=False) + mask[mask > 1e-8] = 1 + breadth = min(h, w) + n_scales = min(1 + int(round(max(0, np.log2(breadth / min_side)))), + max_scales) + ls_images = [] + ls_masks = [] + + ls_images.append(image) + ls_masks.append(mask) + + for _ in range(n_scales - 1): + image_p = _pyrdown(ls_images[-1]) + mask_p = _pyrdown_mask(ls_masks[-1]) + ls_images.append(image_p) + ls_masks.append(mask_p) + # reverse the lists because we want the lowest resolution image as index 0 + return ls_images[::-1], ls_masks[::-1] + + +def refine_predict(batch: dict, inpainter: nn.Module, gpu_ids: str, + modulo: int, n_iters: int, lr: float, min_side: int, + max_scales: int, px_budget: int): + """Refines the inpainting of the network + + Parameters + ---------- + batch : dict + image-mask batch, currently we assume the batchsize to be 1 + inpainter : nn.Module + the inpainting neural network + gpu_ids : str + the GPU ids of the machine to use. If only single GPU, use: "0," + modulo : int + pad the image to ensure dimension % modulo == 0 + n_iters : int + number of iterations of refinement for each scale + lr : float + learning rate + min_side : int + all sides of image on all scales should be >= min_side / sqrt(2) + max_scales : int + max number of downscaling scales for the image-mask pyramid + px_budget : int + pixels budget. 
Any image will be resized to satisfy height*width <= px_budget + + Returns + ------- + torch.Tensor + inpainted image of size (1,3,H,W) + """ + inpainter = inpainter.model + assert not inpainter.training + assert not inpainter.add_noise_kwargs + assert inpainter.concat_mask + + gpu_ids = [ + f'cuda:{gpuid}' for gpuid in gpu_ids.replace(' ', '').split(',') + if gpuid.isdigit() + ] + n_resnet_blocks = 0 + first_resblock_ind = 0 + found_first_resblock = False + for idl in range(len(inpainter.generator.model)): + if isinstance(inpainter.generator.model[idl], FFCResnetBlock): + n_resnet_blocks += 1 + found_first_resblock = True + elif not found_first_resblock: + first_resblock_ind += 1 + resblocks_per_gpu = n_resnet_blocks // len(gpu_ids) + + devices = [torch.device(gpu_id) for gpu_id in gpu_ids] + + # split the model into front, and rear parts + forward_front = inpainter.generator.model[0:first_resblock_ind] + forward_front.to(devices[0]) + forward_rears = [] + for idd in range(len(gpu_ids)): + if idd < len(gpu_ids) - 1: + forward_rears.append( + inpainter.generator.model[first_resblock_ind + + resblocks_per_gpu + * (idd):first_resblock_ind + + resblocks_per_gpu * (idd + 1)]) + else: + forward_rears.append( + inpainter.generator.model[first_resblock_ind + + resblocks_per_gpu * (idd):]) + forward_rears[idd].to(devices[idd]) + + ls_images, ls_masks = _get_image_mask_pyramid(batch, min_side, max_scales, + px_budget) + image_inpainted = None + + for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)): + orig_shape = image.shape[2:] + image = pad_tensor_to_modulo(image, modulo) + mask = pad_tensor_to_modulo(mask, modulo) + mask[mask >= 1e-8] = 1.0 + mask[mask < 1e-8] = 0.0 + image, mask = move_to_device(image, devices[0]), move_to_device( + mask, devices[0]) + if image_inpainted is not None: + image_inpainted = move_to_device(image_inpainted, devices[-1]) + image_inpainted = _infer(image, mask, forward_front, forward_rears, + image_inpainted, orig_shape, devices, ids, + n_iters, lr) + image_inpainted = image_inpainted[:, :, :orig_shape[0], :orig_shape[1]] + # detach everything to save resources + image = image.detach().cpu() + mask = mask.detach().cpu() + + return image_inpainted diff --git a/modelscope/models/cv/image_portrait_enhancement/align_faces.py b/modelscope/models/cv/image_portrait_enhancement/align_faces.py index 776b06d8..e6852f8c 100755 --- a/modelscope/models/cv/image_portrait_enhancement/align_faces.py +++ b/modelscope/models/cv/image_portrait_enhancement/align_faces.py @@ -1,3 +1,5 @@ +# Part of the implementation is borrowed and modified from Face-Alignment, +# publicly available at https://github.com/foamliu/Face-Alignment/blob/master/align_faces.py import cv2 import numpy as np from skimage import transform as trans diff --git a/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py b/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py index fe4081a4..51f2206e 100755 --- a/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py +++ b/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
import os import cv2 diff --git a/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py b/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py index ea3c4f2a..e0e8e9d5 100644 --- a/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py +++ b/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py @@ -1,3 +1,5 @@ +# The implementation is adopted from FaceQuality, made publicly available under the MIT License +# at https://github.com/deepcam-cn/FaceQuality/blob/master/models/model_resnet.py import torch from torch import nn diff --git a/modelscope/models/cv/image_portrait_enhancement/gpen.py b/modelscope/models/cv/image_portrait_enhancement/gpen.py index 2e21dbc0..86009a41 100755 --- a/modelscope/models/cv/image_portrait_enhancement/gpen.py +++ b/modelscope/models/cv/image_portrait_enhancement/gpen.py @@ -1,3 +1,5 @@ +# The GPEN implementation is also open-sourced by the authors, +# and available at https://github.com/yangxy/GPEN/blob/main/face_model/gpen_model.py import functools import itertools import math diff --git a/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py b/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py index 3250d393..26e9e532 100644 --- a/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py +++ b/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import math import os.path as osp from copy import deepcopy @@ -34,7 +35,7 @@ class ImagePortraitEnhancement(TorchModel): """ super().__init__(model_dir, *args, **kwargs) - self.size = 512 + self.size = 256 self.style_dim = 512 self.n_mlp = 8 self.mean_path_length = 0 @@ -130,9 +131,9 @@ class ImagePortraitEnhancement(TorchModel): return path_penalty, path_mean.detach(), path_lengths @torch.no_grad() - def _evaluate_postprocess(self, src: Tensor, + def _evaluate_postprocess(self, input: Tensor, target: Tensor) -> Dict[str, list]: - preds, _ = self.generator(src) + preds, _ = self.generator(input) preds = list(torch.split(preds, 1, 0)) targets = list(torch.split(target, 1, 0)) @@ -143,11 +144,11 @@ class ImagePortraitEnhancement(TorchModel): return {'pred': preds, 'target': targets} - def _train_forward_d(self, src: Tensor, target: Tensor) -> Tensor: + def _train_forward_d(self, input: Tensor, target: Tensor) -> Tensor: self.requires_grad(self.generator, False) self.requires_grad(self.discriminator, True) - preds, _ = self.generator(src) + preds, _ = self.generator(input) fake_pred = self.discriminator(preds) real_pred = self.discriminator(target) @@ -155,27 +156,27 @@ class ImagePortraitEnhancement(TorchModel): return d_loss - def _train_forward_d_r1(self, src: Tensor, target: Tensor) -> Tensor: - src.requires_grad = True + def _train_forward_d_r1(self, input: Tensor, target: Tensor) -> Tensor: + input.requires_grad = True target.requires_grad = True real_pred = self.discriminator(target) r1_loss = self.d_r1_loss(real_pred, target) return r1_loss - def _train_forward_g(self, src: Tensor, target: Tensor) -> Tensor: + def _train_forward_g(self, input: Tensor, target: Tensor) -> Tensor: self.requires_grad(self.generator, True) self.requires_grad(self.discriminator, False) - preds, _ = self.generator(src) + preds, _ = self.generator(input) fake_pred = self.discriminator(preds) - g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, src) + g_loss = self.g_nonsaturating_loss(fake_pred, 
preds, target, input) return g_loss - def _train_forward_g_path(self, src: Tensor, target: Tensor) -> Tensor: - fake_img, latents = self.generator(src, return_latents=True) + def _train_forward_g_path(self, input: Tensor, target: Tensor) -> Tensor: + fake_img, latents = self.generator(input, return_latents=True) path_loss, self.mean_path_length, path_lengths = self.g_path_regularize( fake_img, latents, self.mean_path_length) @@ -183,8 +184,8 @@ class ImagePortraitEnhancement(TorchModel): return path_loss @torch.no_grad() - def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: - return {'outputs': (self.generator(src)[0] * 0.5 + 0.5).clamp(0, 1)} + def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]: + return {'outputs': (self.generator(input)[0] * 0.5 + 0.5).clamp(0, 1)} def forward(self, input: Dict[str, Tensor]) -> Dict[str, Union[list, Tensor]]: diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py b/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py index 35ca202f..86f6f227 100644 --- a/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py +++ b/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py @@ -1,3 +1,5 @@ +# The implementation is adopted from InsightFace_Pytorch, +# made publicly available under the MIT License at https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py from collections import namedtuple import torch diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/losses.py b/modelscope/models/cv/image_portrait_enhancement/losses/losses.py index 8934eee7..0f5198c3 100644 --- a/modelscope/models/cv/image_portrait_enhancement/losses/losses.py +++ b/modelscope/models/cv/image_portrait_enhancement/losses/losses.py @@ -1,3 +1,5 @@ +# The GPEN implementation is also open-sourced by the authors, +# and available at https://github.com/yangxy/GPEN/tree/main/training/loss/id_loss.py import torch import torch.nn as nn import torch.nn.functional as F diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py b/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py index 3b87d7fd..00dc7c52 100644 --- a/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py +++ b/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py @@ -1,3 +1,5 @@ +# The implementation is adopted from InsightFace_Pytorch, +# made publicly available under the MIT License at https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, Module, PReLU, Sequential) diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py index c294438a..7ad780a8 100755 --- a/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py @@ -1,3 +1,5 @@ +# The GPEN implementation is also open-sourced by the authors, +# and available at https://github.com/yangxy/GPEN/blob/main/face_detect/retinaface_detection.py import os import cv2 diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py index 0546e0bb..24451e96 100755 --- a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py @@ -1,3 +1,5 @@ +# The 
implementation is adopted from Pytorch_Retinaface, made pubicly available under the MIT License +# at https://github.com/biubug6/Pytorch_Retinaface/tree/master/models/net.py import time import torch diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py index af1d706d..64d95971 100755 --- a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py @@ -1,3 +1,5 @@ +# The implementation is adopted from Pytorch_Retinaface, made pubicly available under the MIT License +# at https://github.com/biubug6/Pytorch_Retinaface/tree/master/models/retinaface.py from collections import OrderedDict import torch diff --git a/modelscope/models/cv/image_to_image_generation/__init__.py b/modelscope/models/cv/image_to_image_generation/__init__.py index fb408086..1af3e55f 100644 --- a/modelscope/models/cv/image_to_image_generation/__init__.py +++ b/modelscope/models/cv/image_to_image_generation/__init__.py @@ -1,2 +1,2 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from . import data, models, ops diff --git a/modelscope/models/cv/image_to_image_generation/data/__init__.py b/modelscope/models/cv/image_to_image_generation/data/__init__.py index 33c8cf44..22b9d22c 100644 --- a/modelscope/models/cv/image_to_image_generation/data/__init__.py +++ b/modelscope/models/cv/image_to_image_generation/data/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule diff --git a/modelscope/models/cv/image_to_image_generation/data/transforms.py b/modelscope/models/cv/image_to_image_generation/data/transforms.py index 5376d813..29a25b4b 100644 --- a/modelscope/models/cv/image_to_image_generation/data/transforms.py +++ b/modelscope/models/cv/image_to_image_generation/data/transforms.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import random diff --git a/modelscope/models/cv/image_to_image_generation/models/__init__.py b/modelscope/models/cv/image_to_image_generation/models/__init__.py index ec6a46fd..e98421f2 100644 --- a/modelscope/models/cv/image_to_image_generation/models/__init__.py +++ b/modelscope/models/cv/image_to_image_generation/models/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule diff --git a/modelscope/models/cv/image_to_image_generation/ops/__init__.py b/modelscope/models/cv/image_to_image_generation/ops/__init__.py index 49674b49..e3dac584 100644 --- a/modelscope/models/cv/image_to_image_generation/ops/__init__.py +++ b/modelscope/models/cv/image_to_image_generation/ops/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule diff --git a/modelscope/models/cv/image_to_image_translation/__init__.py b/modelscope/models/cv/image_to_image_translation/__init__.py index e69de29b..35aab6be 100644 --- a/modelscope/models/cv/image_to_image_translation/__init__.py +++ b/modelscope/models/cv/image_to_image_translation/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .model_translation import UNet + +else: + _import_structure = { + 'image_to_image_translation_model': ['UNet'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_to_image_translation/data/__init__.py b/modelscope/models/cv/image_to_image_translation/data/__init__.py index 72450016..724bca04 100644 --- a/modelscope/models/cv/image_to_image_translation/data/__init__.py +++ b/modelscope/models/cv/image_to_image_translation/data/__init__.py @@ -1 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from .transforms import * # noqa F403 diff --git a/modelscope/models/cv/image_to_image_translation/models/__init__.py b/modelscope/models/cv/image_to_image_translation/models/__init__.py index 322d78f2..7fdd8189 100644 --- a/modelscope/models/cv/image_to_image_translation/models/__init__.py +++ b/modelscope/models/cv/image_to_image_translation/models/__init__.py @@ -1,2 +1,3 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from .autoencoder import * # noqa F403 from .clip import * # noqa F403 diff --git a/modelscope/models/cv/image_to_image_translation/ops/__init__.py b/modelscope/models/cv/image_to_image_translation/ops/__init__.py index 59082d72..474c811b 100644 --- a/modelscope/models/cv/image_to_image_translation/ops/__init__.py +++ b/modelscope/models/cv/image_to_image_translation/ops/__init__.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
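The lazy `__init__` pattern repeated in these packages defers importing the heavy submodules until one of the exported names is first accessed, so the package can be imported without pulling in every model dependency. A rough standalone illustration of the same idea using module-level `__getattr__` (PEP 562); this is a generic sketch, not ModelScope's LazyImportModule, and the module and class names are placeholders:

    # lazy_pkg/__init__.py (illustrative only)
    import importlib
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from .model_translation import UNet  # type checkers resolve the real symbol

    _import_structure = {'model_translation': ['UNet']}
    _attr_to_module = {attr: mod
                       for mod, attrs in _import_structure.items()
                       for attr in attrs}

    def __getattr__(name):
        # the first access of `UNet` imports .model_translation on demand
        if name in _attr_to_module:
            module = importlib.import_module('.' + _attr_to_module[name], __name__)
            return getattr(module, name)
        raise AttributeError(f'module {__name__!r} has no attribute {name!r}')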
from .degradation import * # noqa F403 from .diffusion import * # noqa F403 from .losses import * # noqa F403 diff --git a/modelscope/models/cv/movie_scene_segmentation/model.py b/modelscope/models/cv/movie_scene_segmentation/model.py index 1232d427..8117961a 100644 --- a/modelscope/models/cv/movie_scene_segmentation/model.py +++ b/modelscope/models/cv/movie_scene_segmentation/model.py @@ -67,7 +67,6 @@ class MovieSceneSegmentationModel(TorchModel): mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) - self.infer_result = {'vid': [], 'sid': [], 'pred': []} sampling_method = self.cfg.dataset.sampling_method.name self.neighbor_size = self.cfg.dataset.sampling_method.params[ sampling_method].neighbor_size @@ -104,6 +103,8 @@ class MovieSceneSegmentationModel(TorchModel): shot_num = len(sids) cnt = shot_num // bs + 1 + infer_sid, infer_pred = [], [] + infer_result = {} for i in range(cnt): start = i * bs end = (i + 1) * bs if (i + 1) * bs < shot_num else shot_num @@ -112,13 +113,14 @@ class MovieSceneSegmentationModel(TorchModel): input_ = torch.stack(input_) outputs = self.shared_step(input_) # shape [b,2] prob = F.softmax(outputs, dim=1) - self.infer_result['sid'].extend(sid_.cpu().detach().numpy()) - self.infer_result['pred'].extend(prob[:, 1].cpu().detach().numpy()) - self.infer_result['pred'] = np.stack(self.infer_result['pred']) + infer_sid.extend(sid_.cpu().detach().numpy()) + infer_pred.extend(prob[:, 1].cpu().detach().numpy()) + infer_result.update({'pred': np.stack(infer_pred)}) + infer_result.update({'sid': infer_sid}) - assert len(self.infer_result['sid']) == len(sids) - assert len(self.infer_result['pred']) == len(inputs) - return self.infer_result + assert len(infer_result['sid']) == len(sids) + assert len(infer_result['pred']) == len(inputs) + return infer_result def shared_step(self, inputs): with torch.no_grad(): @@ -162,11 +164,12 @@ class MovieSceneSegmentationModel(TorchModel): thres = self.cfg.pipeline.save_threshold anno_dict = get_pred_boundary(pred_dict, thres) - scene_dict_lst, scene_list = pred2scene(self.shot2keyf, anno_dict) + scene_dict_lst, scene_list, shot_num, shot_dict_lst = pred2scene( + self.shot2keyf, anno_dict) if self.cfg.pipeline.save_split_scene: re_dir = scene2video(inputs['input_video_pth'], scene_list, thres) print(f'Split scene video saved to {re_dir}') - return len(scene_list), scene_dict_lst + return len(scene_list), scene_dict_lst, shot_num, shot_dict_lst def preprocess(self, inputs): logger.info('Begin shot detect......') diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py b/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py index b350ff13..3339e1a3 100644 --- a/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py +++ b/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py @@ -22,15 +22,23 @@ def pred2scene(shot2keyf, anno_dict): scene_list, pair_list = get_demo_scene_list(shot2keyf, anno_dict) scene_dict_lst = [] + shot_num = len(shot2keyf) + shot_dict_lst = [] + for item in shot2keyf: + tmp = item.split(' ') + shot_dict_lst.append({ + 'frame': [tmp[0], tmp[1]], + 'timestamps': [tmp[-2], tmp[-1]] + }) assert len(scene_list) == len(pair_list) for scene_ind, scene_item in enumerate(scene_list): scene_dict_lst.append({ 'shot': pair_list[scene_ind], 'frame': scene_item[0], - 'timestamp': scene_item[1] + 'timestamps': scene_item[1] }) - return scene_dict_lst, scene_list + return scene_dict_lst, scene_list, shot_num, shot_dict_lst def scene2video(source_movie_fn, scene_list, thres): diff 
--git a/modelscope/models/cv/object_detection/__init__.py b/modelscope/models/cv/object_detection/__init__.py index 974375ce..0c782d7b 100644 --- a/modelscope/models/cv/object_detection/__init__.py +++ b/modelscope/models/cv/object_detection/__init__.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: else: _import_structure = { 'mmdet_model': ['DetectionModel'], - 'yolox_pai': ['YOLOX'] + 'yolox_pai': ['YOLOX'], } import sys diff --git a/modelscope/models/cv/object_detection/yolox_pai.py b/modelscope/models/cv/object_detection/yolox_pai.py index 985cc136..46bd4e3c 100644 --- a/modelscope/models/cv/object_detection/yolox_pai.py +++ b/modelscope/models/cv/object_detection/yolox_pai.py @@ -9,6 +9,9 @@ from modelscope.utils.constant import Tasks @MODELS.register_module( group_key=Tasks.image_object_detection, module_name=Models.yolox) +@MODELS.register_module( + group_key=Tasks.image_object_detection, + module_name=Models.image_object_detection_auto) class YOLOX(EasyCVBaseModel, _YOLOX): def __init__(self, model_dir=None, *args, **kwargs): diff --git a/modelscope/models/cv/product_retrieval_embedding/__init__.py b/modelscope/models/cv/product_retrieval_embedding/__init__.py index 7a02a60f..2cbc9099 100644 --- a/modelscope/models/cv/product_retrieval_embedding/__init__.py +++ b/modelscope/models/cv/product_retrieval_embedding/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule diff --git a/modelscope/models/cv/product_retrieval_embedding/item_detection.py b/modelscope/models/cv/product_retrieval_embedding/item_detection.py index d5589969..2002c6cb 100644 --- a/modelscope/models/cv/product_retrieval_embedding/item_detection.py +++ b/modelscope/models/cv/product_retrieval_embedding/item_detection.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import cv2 import numpy as np diff --git a/modelscope/models/cv/product_retrieval_embedding/item_embedding.py b/modelscope/models/cv/product_retrieval_embedding/item_embedding.py index 0444596c..ea9ec846 100644 --- a/modelscope/models/cv/product_retrieval_embedding/item_embedding.py +++ b/modelscope/models/cv/product_retrieval_embedding/item_embedding.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import cv2 import numpy as np import torch.nn as nn diff --git a/modelscope/models/cv/product_retrieval_embedding/item_model.py b/modelscope/models/cv/product_retrieval_embedding/item_model.py index 85a636c0..3964efbe 100644 --- a/modelscope/models/cv/product_retrieval_embedding/item_model.py +++ b/modelscope/models/cv/product_retrieval_embedding/item_model.py @@ -1,3 +1,5 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
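Stacking two `@MODELS.register_module(...)` decorators on YOLOX above simply records the same class under a second module name, so a configuration that asks for either `Models.yolox` or `Models.image_object_detection_auto` resolves to the same implementation. A simplified sketch of that behaviour (a toy registry for illustration, not ModelScope's actual Registry class):

    # toy registry showing what double registration buys
    _registry = {}

    def register_module(group_key, module_name):
        def decorator(cls):
            _registry[(group_key, module_name)] = cls
            return cls
        return decorator

    @register_module('image-object-detection', 'yolox')
    @register_module('image-object-detection', 'image-object-detection-auto')
    class ToyYOLOX:
        pass

    # both configuration keys map to the same class object
    assert _registry[('image-object-detection', 'yolox')] is ToyYOLOX
    assert _registry[('image-object-detection', 'image-object-detection-auto')] is ToyYOLOX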
+ import os.path as osp from typing import Any, Dict diff --git a/modelscope/models/cv/product_segmentation/seg_infer.py b/modelscope/models/cv/product_segmentation/seg_infer.py index 876fac66..8814d619 100644 --- a/modelscope/models/cv/product_segmentation/seg_infer.py +++ b/modelscope/models/cv/product_segmentation/seg_infer.py @@ -59,9 +59,8 @@ mean, std = np.array([[[124.55, 118.90, 102.94]]]), np.array([[[56.77, 55.97, 57.50]]]) -def inference(model, device, input_path): - img = Image.open(input_path) - img = np.array(img.convert('RGB')).astype(np.float32) +def inference(model, device, img): + img = img.cpu().numpy() img = (img - mean) / std img = cv2.resize(img, dsize=(448, 448), interpolation=cv2.INTER_LINEAR) img = torch.from_numpy(img) diff --git a/modelscope/models/cv/realtime_object_detection/__init__.py b/modelscope/models/cv/realtime_object_detection/__init__.py index aed13cec..66156977 100644 --- a/modelscope/models/cv/realtime_object_detection/__init__.py +++ b/modelscope/models/cv/realtime_object_detection/__init__.py @@ -5,9 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .realtime_detector import RealtimeDetector + from .realtime_video_detector import RealtimeVideoDetector else: _import_structure = { 'realtime_detector': ['RealtimeDetector'], + 'realtime_video_detector': ['RealtimeVideoDetector'], } import sys diff --git a/modelscope/models/cv/realtime_object_detection/realtime_video_detector.py b/modelscope/models/cv/realtime_object_detection/realtime_video_detector.py new file mode 100644 index 00000000..3830fb42 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/realtime_video_detector.py @@ -0,0 +1,121 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import logging as logger +import os +import os.path as osp +import time + +import cv2 +import json +import torch +from tqdm import tqdm + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.preprocessors import LoadImage +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .utils import timestamp_format +from .yolox.data.data_augment import ValTransform +from .yolox.exp import get_exp_by_name +from .yolox.utils import postprocess + + +@MODELS.register_module( + group_key=Tasks.video_object_detection, + module_name=Models.realtime_video_object_detection) +class RealtimeVideoDetector(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + + # model type + self.exp = get_exp_by_name(self.config.model_type) + + # build model + self.model = self.exp.get_model() + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + ckpt = torch.load(model_path, map_location='cpu') + + # load the model state dict + self.model.load_state_dict(ckpt['model']) + self.model.eval() + + # params setting + self.exp.num_classes = self.config.num_classes + self.confthre = self.config.conf_thr + self.num_classes = self.exp.num_classes + self.nmsthre = self.exp.nmsthre + self.test_size = self.exp.test_size + self.preproc = ValTransform(legacy=False) + self.current_buffer = None + self.label_mapping = self.config['labels'] + + def inference(self, img): + with torch.no_grad(): + outputs, self.current_buffer = self.model( + img, 
buffer=self.current_buffer, mode='on_pipe') + return outputs + + def forward(self, inputs): + return self.inference_video(inputs) + + def preprocess(self, img): + img = LoadImage.convert_to_ndarray(img) + height, width = img.shape[:2] + self.ratio = min(self.test_size[0] / img.shape[0], + self.test_size[1] / img.shape[1]) + + img, _ = self.preproc(img, None, self.test_size) + img = torch.from_numpy(img).unsqueeze(0) + img = img.float() + + # Video decoding and preprocessing automatically are not supported by Pipeline/Model + # Sending preprocessed video frame tensor to GPU buffer self-adaptively + if next(self.model.parameters()).is_cuda: + img = img.to(next(self.model.parameters()).device) + return img + + def postprocess(self, input): + outputs = postprocess( + input, + self.num_classes, + self.confthre, + self.nmsthre, + class_agnostic=True) + + if len(outputs) == 1: + bboxes = outputs[0][:, 0:4].cpu().numpy() / self.ratio + scores = outputs[0][:, 5].cpu().numpy() + labels = outputs[0][:, 6].cpu().int().numpy() + pred_label_names = [] + for lab in labels: + pred_label_names.append(self.label_mapping[lab]) + + return bboxes, scores, pred_label_names + + def inference_video(self, v_path): + outputs = [] + desc = 'Detecting video: {}'.format(v_path) + for frame_idx, (frame, result) in enumerate( + tqdm(self.inference_video_iter(v_path), desc=desc)): + result = result + (timestamp_format(seconds=frame_idx + / self.fps), ) + outputs.append(result) + + return outputs + + def inference_video_iter(self, v_path): + capture = cv2.VideoCapture(v_path) + self.fps = capture.get(cv2.CAP_PROP_FPS) + while capture.isOpened(): + ret, frame = capture.read() + if not ret: + break + output = self.preprocess(frame) + output = self.inference(output) + output = self.postprocess(output) + yield frame, output diff --git a/modelscope/models/cv/realtime_object_detection/utils.py b/modelscope/models/cv/realtime_object_detection/utils.py new file mode 100644 index 00000000..c3d7a4c6 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/utils.py @@ -0,0 +1,9 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
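The per-frame flow in `inference_video_iter` above is: decode a frame with OpenCV, run the preprocess/inference/postprocess chain, and stamp the detection with `frame_idx / fps`. A condensed, self-contained sketch of that decode-and-timestamp loop; `detect_frame` is a placeholder for the model's per-frame processing:

    import cv2

    def iterate_video(video_path, detect_frame):
        capture = cv2.VideoCapture(video_path)
        fps = capture.get(cv2.CAP_PROP_FPS)
        frame_idx = 0
        while capture.isOpened():
            ret, frame = capture.read()
            if not ret:
                break
            seconds = frame_idx / fps
            m, s = divmod(seconds, 60)
            h, m = divmod(m, 60)
            # HH:MM:SS.mmm, the same format produced by timestamp_format() below
            yield detect_frame(frame), '%02d:%02d:%06.3f' % (h, m, s)
            frame_idx += 1
        capture.release()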
+import math + + +def timestamp_format(seconds): + m, s = divmod(seconds, 60) + h, m = divmod(m, 60) + time = '%02d:%02d:%06.3f' % (h, m, s) + return time diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/build.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/build.py index 4858100c..5865c53b 100644 --- a/modelscope/models/cv/realtime_object_detection/yolox/exp/build.py +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/build.py @@ -13,6 +13,8 @@ def get_exp_by_name(exp_name): from .default import YoloXNanoExp as YoloXExp elif exp == 'yolox_tiny': from .default import YoloXTinyExp as YoloXExp + elif exp == 'streamyolo': + from .default import StreamYoloExp as YoloXExp else: pass return YoloXExp() diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py index 552bbccd..cfec836c 100644 --- a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py @@ -1,5 +1,5 @@ # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX - +from .streamyolo import StreamYoloExp from .yolox_nano import YoloXNanoExp from .yolox_s import YoloXSExp from .yolox_tiny import YoloXTinyExp diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/streamyolo.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/streamyolo.py new file mode 100644 index 00000000..5a62c8fc --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/streamyolo.py @@ -0,0 +1,43 @@ +# The implementation is based on StreamYOLO, available at https://github.com/yancie-yjr/StreamYOLO +import os +import sys + +import torch + +from ..yolox_base import Exp as YoloXExp + + +class StreamYoloExp(YoloXExp): + + def __init__(self): + super(YoloXExp, self).__init__() + self.depth = 1.0 + self.width = 1.0 + self.num_classes = 8 + self.test_size = (600, 960) + self.test_conf = 0.3 + self.nmsthre = 0.65 + + def get_model(self): + from ...models import StreamYOLO, DFPPAFPN, TALHead + + def init_yolo(M): + for m in M.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + + if getattr(self, 'model', None) is None: + in_channels = [256, 512, 1024] + backbone = DFPPAFPN( + self.depth, self.width, in_channels=in_channels) + head = TALHead( + self.num_classes, + self.width, + in_channels=in_channels, + gamma=1.0, + ignore_thr=0.5, + ignore_value=1.6) + self.model = StreamYOLO(backbone, head) + + return self.model diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py index a2a41535..c5159a9f 100644 --- a/modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py @@ -1,5 +1,4 @@ # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX - import os import random diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py index 20b1a0d1..d2e889f1 100644 --- a/modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py @@ -1,6 +1,9 @@ # The implementation is based on 
YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX from .darknet import CSPDarknet, Darknet +from .dfp_pafpn import DFPPAFPN +from .streamyolo import StreamYOLO +from .tal_head import TALHead from .yolo_fpn import YOLOFPN from .yolo_head import YOLOXHead from .yolo_pafpn import YOLOPAFPN diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/dfp_pafpn.py b/modelscope/models/cv/realtime_object_detection/yolox/models/dfp_pafpn.py new file mode 100644 index 00000000..01284791 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/dfp_pafpn.py @@ -0,0 +1,307 @@ +# The implementation is based on StreamYOLO, available at https://github.com/yancie-yjr/StreamYOLO +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .darknet import CSPDarknet +from .network_blocks import BaseConv, CSPLayer, DWConv + + +class DFPPAFPN(nn.Module): + """ + YOLOv3 model. Darknet 53 is the default backbone of this model. + """ + + def __init__( + self, + depth=1.0, + width=1.0, + in_features=('dark3', 'dark4', 'dark5'), + in_channels=[256, 512, 1024], + depthwise=False, + act='silu', + ): + super().__init__() + self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act) + self.in_features = in_features + self.in_channels = in_channels + Conv = DWConv if depthwise else BaseConv + + self.lateral_conv0 = BaseConv( + int(in_channels[2] * width), + int(in_channels[1] * width), + 1, + 1, + act=act) + self.C3_p4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) # cat + + self.reduce_conv1 = BaseConv( + int(in_channels[1] * width), + int(in_channels[0] * width), + 1, + 1, + act=act) + self.C3_p3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[0] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv2 = Conv( + int(in_channels[0] * width), + int(in_channels[0] * width), + 3, + 2, + act=act) + self.C3_n3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv1 = Conv( + int(in_channels[1] * width), + int(in_channels[1] * width), + 3, + 2, + act=act) + self.C3_n4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[2] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + self.jian2 = Conv( + in_channels=int(in_channels[0] * width), + out_channels=int(in_channels[0] * width) // 2, + ksize=1, + stride=1, + act=act, + ) + + self.jian1 = Conv( + in_channels=int(in_channels[1] * width), + out_channels=int(in_channels[1] * width) // 2, + ksize=1, + stride=1, + act=act, + ) + + self.jian0 = Conv( + in_channels=int(in_channels[2] * width), + out_channels=int(in_channels[2] * width) // 2, + ksize=1, + stride=1, + act=act, + ) + + def off_forward(self, input): + """ + Args: + inputs: input images. + + Returns: + Tuple[Tensor]: FPN feature. 
+ """ + + # backbone + rurrent_out_features = self.backbone(torch.split(input, 3, dim=1)[0]) + rurrent_features = [rurrent_out_features[f] for f in self.in_features] + [rurrent_x2, rurrent_x1, rurrent_x0] = rurrent_features + + rurrent_fpn_out0 = self.lateral_conv0(rurrent_x0) # 1024->512/32 + rurrent_f_out0 = F.interpolate( + rurrent_fpn_out0, size=rurrent_x1.shape[2:4], + mode='nearest') # 512/16 + rurrent_f_out0 = torch.cat([rurrent_f_out0, rurrent_x1], + 1) # 512->1024/16 + rurrent_f_out0 = self.C3_p4(rurrent_f_out0) # 1024->512/16 + + rurrent_fpn_out1 = self.reduce_conv1(rurrent_f_out0) # 512->256/16 + rurrent_f_out1 = F.interpolate( + rurrent_fpn_out1, size=rurrent_x2.shape[2:4], + mode='nearest') # 256/8 + rurrent_f_out1 = torch.cat([rurrent_f_out1, rurrent_x2], + 1) # 256->512/8 + rurrent_pan_out2 = self.C3_p3(rurrent_f_out1) # 512->256/8 + + rurrent_p_out1 = self.bu_conv2(rurrent_pan_out2) # 256->256/16 + rurrent_p_out1 = torch.cat([rurrent_p_out1, rurrent_fpn_out1], + 1) # 256->512/16 + rurrent_pan_out1 = self.C3_n3(rurrent_p_out1) # 512->512/16 + + rurrent_p_out0 = self.bu_conv1(rurrent_pan_out1) # 512->512/32 + rurrent_p_out0 = torch.cat([rurrent_p_out0, rurrent_fpn_out0], + 1) # 512->1024/32 + rurrent_pan_out0 = self.C3_n4(rurrent_p_out0) # 1024->1024/32 + + ##### + + support_out_features = self.backbone(torch.split(input, 3, dim=1)[1]) + support_features = [support_out_features[f] for f in self.in_features] + [support_x2, support_x1, support_x0] = support_features + + support_fpn_out0 = self.lateral_conv0(support_x0) # 1024->512/32 + support_f_out0 = F.interpolate( + support_fpn_out0, size=support_x1.shape[2:4], + mode='nearest') # 512/16 + support_f_out0 = torch.cat([support_f_out0, support_x1], + 1) # 512->1024/16 + support_f_out0 = self.C3_p4(support_f_out0) # 1024->512/16 + + support_fpn_out1 = self.reduce_conv1(support_f_out0) # 512->256/16 + support_f_out1 = F.interpolate( + support_fpn_out1, size=support_x2.shape[2:4], + mode='nearest') # 256/8 + support_f_out1 = torch.cat([support_f_out1, support_x2], + 1) # 256->512/8 + support_pan_out2 = self.C3_p3(support_f_out1) # 512->256/8 + + support_p_out1 = self.bu_conv2(support_pan_out2) # 256->256/16 + support_p_out1 = torch.cat([support_p_out1, support_fpn_out1], + 1) # 256->512/16 + support_pan_out1 = self.C3_n3(support_p_out1) # 512->512/16 + + support_p_out0 = self.bu_conv1(support_pan_out1) # 512->512/32 + support_p_out0 = torch.cat([support_p_out0, support_fpn_out0], + 1) # 512->1024/32 + support_pan_out0 = self.C3_n4(support_p_out0) # 1024->1024/32 + + # 0.5 channel + pan_out2 = torch.cat( + [self.jian2(rurrent_pan_out2), + self.jian2(support_pan_out2)], + dim=1) + rurrent_pan_out2 + pan_out1 = torch.cat( + [self.jian1(rurrent_pan_out1), + self.jian1(support_pan_out1)], + dim=1) + rurrent_pan_out1 + pan_out0 = torch.cat( + [self.jian0(rurrent_pan_out0), + self.jian0(support_pan_out0)], + dim=1) + rurrent_pan_out0 + + outputs = (pan_out2, pan_out1, pan_out0) + + return outputs + + def online_forward(self, input, buffer=None, node='star'): + """ + Args: + inputs: input images. + + Returns: + Tuple[Tensor]: FPN feature. 
+ """ + + # backbone + rurrent_out_features = self.backbone(input) + rurrent_features = [rurrent_out_features[f] for f in self.in_features] + [rurrent_x2, rurrent_x1, rurrent_x0] = rurrent_features + + rurrent_fpn_out0 = self.lateral_conv0(rurrent_x0) # 1024->512/32 + rurrent_f_out0 = F.interpolate( + rurrent_fpn_out0, size=rurrent_x1.shape[2:4], + mode='nearest') # 512/16 + rurrent_f_out0 = torch.cat([rurrent_f_out0, rurrent_x1], + 1) # 512->1024/16 + rurrent_f_out0 = self.C3_p4(rurrent_f_out0) # 1024->512/16 + + rurrent_fpn_out1 = self.reduce_conv1(rurrent_f_out0) # 512->256/16 + rurrent_f_out1 = F.interpolate( + rurrent_fpn_out1, size=rurrent_x2.shape[2:4], + mode='nearest') # 256/8 + rurrent_f_out1 = torch.cat([rurrent_f_out1, rurrent_x2], + 1) # 256->512/8 + rurrent_pan_out2 = self.C3_p3(rurrent_f_out1) # 512->256/8 + + rurrent_p_out1 = self.bu_conv2(rurrent_pan_out2) # 256->256/16 + rurrent_p_out1 = torch.cat([rurrent_p_out1, rurrent_fpn_out1], + 1) # 256->512/16 + rurrent_pan_out1 = self.C3_n3(rurrent_p_out1) # 512->512/16 + + rurrent_p_out0 = self.bu_conv1(rurrent_pan_out1) # 512->512/32 + rurrent_p_out0 = torch.cat([rurrent_p_out0, rurrent_fpn_out0], + 1) # 512->1024/32 + rurrent_pan_out0 = self.C3_n4(rurrent_p_out0) # 1024->1024/32 + + ##### + if node == 'star': + pan_out2 = torch.cat( + [self.jian2(rurrent_pan_out2), + self.jian2(rurrent_pan_out2)], + dim=1) + rurrent_pan_out2 + pan_out1 = torch.cat( + [self.jian1(rurrent_pan_out1), + self.jian1(rurrent_pan_out1)], + dim=1) + rurrent_pan_out1 + pan_out0 = torch.cat( + [self.jian0(rurrent_pan_out0), + self.jian0(rurrent_pan_out0)], + dim=1) + rurrent_pan_out0 + elif node == 'buffer': + + [support_pan_out2, support_pan_out1, support_pan_out0] = buffer + + pan_out2 = torch.cat( + [self.jian2(rurrent_pan_out2), + self.jian2(support_pan_out2)], + dim=1) + rurrent_pan_out2 + pan_out1 = torch.cat( + [self.jian1(rurrent_pan_out1), + self.jian1(support_pan_out1)], + dim=1) + rurrent_pan_out1 + pan_out0 = torch.cat( + [self.jian0(rurrent_pan_out0), + self.jian0(support_pan_out0)], + dim=1) + rurrent_pan_out0 + + outputs = (pan_out2, pan_out1, pan_out0) + + buffer_ = (rurrent_pan_out2, rurrent_pan_out1, rurrent_pan_out0) + + return outputs, buffer_ + + def forward(self, input, buffer=None, mode='off_pipe'): + + if mode == 'off_pipe': + # Glops caculate mode + if input.size()[1] == 3: + input = torch.cat([input, input], dim=1) + output = self.off_forward(input) + # offline train mode + elif input.size()[1] == 6: + output = self.off_forward(input) + + return output + + elif mode == 'on_pipe': + # online star state + if buffer is None: + output, buffer_ = self.online_forward(input, node='star') + # online inference + else: + assert len(buffer) == 3 + assert input.size()[1] == 3 + output, buffer_ = self.online_forward( + input, buffer=buffer, node='buffer') + + return output, buffer_ diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py b/modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py index fd15c1c1..88bd55c7 100644 --- a/modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py @@ -1,5 +1,4 @@ # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX - import torch import torch.nn as nn diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/streamyolo.py 
b/modelscope/models/cv/realtime_object_detection/yolox/models/streamyolo.py new file mode 100644 index 00000000..b3ec3504 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/streamyolo.py @@ -0,0 +1,41 @@ +# The implementation is based on StreamYOLO, available at https://github.com/yancie-yjr/StreamYOLO +import torch.nn as nn + +from .dfp_pafpn import DFPPAFPN +from .tal_head import TALHead + + +class StreamYOLO(nn.Module): + """ + YOLOX model module. The module list is defined by create_yolov3_modules function. + The network returns loss values from three YOLO layers during training + and detection results during test. + """ + + def __init__(self, backbone=None, head=None): + super().__init__() + if backbone is None: + backbone = DFPPAFPN() + if head is None: + head = TALHead(20) + + self.backbone = backbone + self.head = head + + def forward(self, x, targets=None, buffer=None, mode='off_pipe'): + # fpn output content features of [dark3, dark4, dark5] + assert mode in ['off_pipe', 'on_pipe'] + + if mode == 'off_pipe': + fpn_outs = self.backbone(x, buffer=buffer, mode='off_pipe') + if self.training: + pass + else: + outputs = self.head(fpn_outs, imgs=x) + + return outputs + elif mode == 'on_pipe': + fpn_outs, buffer_ = self.backbone(x, buffer=buffer, mode='on_pipe') + outputs = self.head(fpn_outs) + + return outputs, buffer_ diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/tal_head.py b/modelscope/models/cv/realtime_object_detection/yolox/models/tal_head.py new file mode 100644 index 00000000..7a82f8c6 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/tal_head.py @@ -0,0 +1,170 @@ +# The implementation is based on StreamYOLO, available at https://github.com/yancie-yjr/StreamYOLO +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .network_blocks import BaseConv, DWConv + + +class TALHead(nn.Module): + + def __init__( + self, + num_classes, + width=1.0, + strides=[8, 16, 32], + in_channels=[256, 512, 1024], + act='silu', + depthwise=False, + gamma=1.5, + ignore_thr=0.2, + ignore_value=0.2, + ): + """ + Args: + act (str): activation type of conv. Defalut value: "silu". + depthwise (bool): wheather apply depthwise conv in conv branch. Defalut value: False. 
+ """ + super().__init__() + + self.gamma = gamma + self.ignore_thr = ignore_thr + self.ignore_value = ignore_value + + self.n_anchors = 1 + self.num_classes = num_classes + self.decode_in_inference = True # for deploy, set to False + + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.cls_preds = nn.ModuleList() + self.reg_preds = nn.ModuleList() + self.obj_preds = nn.ModuleList() + self.stems = nn.ModuleList() + Conv = DWConv if depthwise else BaseConv + + for i in range(len(in_channels)): + self.stems.append( + BaseConv( + in_channels=int(in_channels[i] * width), + out_channels=int(256 * width), + ksize=1, + stride=1, + act=act, + )) + self.cls_convs.append( + nn.Sequential(*[ + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + ])) + self.reg_convs.append( + nn.Sequential(*[ + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + ])) + self.cls_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=self.n_anchors * self.num_classes, + kernel_size=1, + stride=1, + padding=0, + )) + self.reg_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=4, + kernel_size=1, + stride=1, + padding=0, + )) + self.obj_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=self.n_anchors * 1, + kernel_size=1, + stride=1, + padding=0, + )) + + self.strides = strides + self.grids = [torch.zeros(1)] * len(in_channels) + self.expanded_strides = [None] * len(in_channels) + + def forward(self, xin, labels=None, imgs=None): + outputs = [] + for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate( + zip(self.cls_convs, self.reg_convs, self.strides, xin)): + x = self.stems[k](x) + cls_x = x + reg_x = x + + cls_feat = cls_conv(cls_x) + cls_output = self.cls_preds[k](cls_feat) + + reg_feat = reg_conv(reg_x) + reg_output = self.reg_preds[k](reg_feat) + obj_output = self.obj_preds[k](reg_feat) + + if self.training: + pass + + else: + output = torch.cat( + [reg_output, + obj_output.sigmoid(), + cls_output.sigmoid()], 1) + + outputs.append(output) + + if self.training: + pass + else: + self.hw = [x.shape[-2:] for x in outputs] + outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], + dim=2).permute(0, 2, 1) + if self.decode_in_inference: + return self.decode_outputs(outputs, dtype=xin[0].type()) + else: + return outputs + + def decode_outputs(self, outputs, dtype): + grids = [] + strides = [] + for (hsize, wsize), stride in zip(self.hw, self.strides): + yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)]) + grid = torch.stack((xv, yv), 2).view(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + strides.append(torch.full((*shape, 1), stride)) + + grids = torch.cat(grids, dim=1).type(dtype) + strides = torch.cat(strides, dim=1).type(dtype) + + outputs[..., :2] = (outputs[..., :2] + grids) * strides + outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides + return outputs diff --git a/modelscope/msdatasets/image_denoise_data/__init__.py b/modelscope/models/cv/referring_video_object_segmentation/__init__.py similarity index 77% rename from modelscope/msdatasets/image_denoise_data/__init__.py rename to 
modelscope/models/cv/referring_video_object_segmentation/__init__.py index ba1d2df8..58dbf7b0 100644 --- a/modelscope/msdatasets/image_denoise_data/__init__.py +++ b/modelscope/models/cv/referring_video_object_segmentation/__init__.py @@ -4,11 +4,12 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .image_denoise_dataset import PairedImageDataset + + from .model import MovieSceneSegmentation else: _import_structure = { - 'image_denoise_dataset': ['PairedImageDataset'], + 'model': ['MovieSceneSegmentation'], } import sys diff --git a/modelscope/models/cv/referring_video_object_segmentation/model.py b/modelscope/models/cv/referring_video_object_segmentation/model.py new file mode 100644 index 00000000..902a3416 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/model.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .utils import (MTTR, A2DSentencesPostProcess, ReferYoutubeVOSPostProcess, + nested_tensor_from_videos_list) + +logger = get_logger() + + +@MODELS.register_module( + Tasks.referring_video_object_segmentation, + module_name=Models.referring_video_object_segmentation) +class ReferringVideoObjectSegmentation(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, *args, **kwargs) + + config_path = osp.join(model_dir, ModelFile.CONFIGURATION) + self.cfg = Config.from_file(config_path) + self.model = MTTR(**self.cfg.model) + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + params_dict = torch.load(model_path, map_location='cpu') + if 'model_state_dict' in params_dict.keys(): + params_dict = params_dict['model_state_dict'] + self.model.load_state_dict(params_dict, strict=True) + + dataset_name = self.cfg.pipeline.dataset_name + if dataset_name == 'a2d_sentences' or dataset_name == 'jhmdb_sentences': + self.postprocessor = A2DSentencesPostProcess() + elif dataset_name == 'ref_youtube_vos': + self.postprocessor = ReferYoutubeVOSPostProcess() + else: + assert False, f'postprocessing for dataset: {dataset_name} is not supported' + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: + return inputs + + def inference(self, **kwargs): + window = kwargs['window'] + text_query = kwargs['text_query'] + video_metadata = kwargs['metadata'] + + window = nested_tensor_from_videos_list([window]) + valid_indices = torch.arange(len(window.tensors)) + if self._device_name == 'gpu': + valid_indices = valid_indices.cuda() + outputs = self.model(window, valid_indices, [text_query]) + window_masks = self.postprocessor( + outputs, [video_metadata], + window.tensors.shape[-2:])[0]['pred_masks'] + return window_masks + + def postprocess(self, inputs: Dict[str, Any], **kwargs): + return inputs diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py b/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py new file mode 100644 index 00000000..796bd6f4 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py @@ -0,0 +1,4 @@ +# 
Copyright (c) Alibaba, Inc. and its affiliates. +from .misc import nested_tensor_from_videos_list +from .mttr import MTTR +from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/backbone.py b/modelscope/models/cv/referring_video_object_segmentation/utils/backbone.py new file mode 100644 index 00000000..afa384c1 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/backbone.py @@ -0,0 +1,198 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR + +import torch +import torch.nn.functional as F +import torchvision +from einops import rearrange +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter + +from .misc import NestedTensor, is_main_process +from .swin_transformer import SwinTransformer3D + + +class VideoSwinTransformerBackbone(nn.Module): + """ + A wrapper which allows using Video-Swin Transformer as a temporal encoder for MTTR. + Check out video-swin's original paper at: https://arxiv.org/abs/2106.13230 for more info about this architecture. + Only the 'tiny' version of video swin was tested and is currently supported in our project. + Additionally, we slightly modify video-swin to make it output per-frame embeddings as required by MTTR (check our + paper's supplementary for more details), and completely discard of its 4th block. + """ + + def __init__(self, backbone_pretrained, backbone_pretrained_path, + train_backbone, running_mode, **kwargs): + super(VideoSwinTransformerBackbone, self).__init__() + # patch_size is (1, 4, 4) instead of the original (2, 4, 4). + # this prevents swinT's original temporal downsampling so we can get per-frame features. + swin_backbone = SwinTransformer3D( + patch_size=(1, 4, 4), + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_size=(8, 7, 7), + drop_path_rate=0.1, + patch_norm=True) + if backbone_pretrained and running_mode == 'train': + state_dict = torch.load(backbone_pretrained_path)['state_dict'] + # extract swinT's kinetics-400 pretrained weights and ignore the rest (prediction head etc.) + state_dict = { + k[9:]: v + for k, v in state_dict.items() if 'backbone.' 
in k + } + + # sum over the patch embedding weight temporal dim [96, 3, 2, 4, 4] --> [96, 3, 1, 4, 4] + patch_embed_weight = state_dict['patch_embed.proj.weight'] + patch_embed_weight = patch_embed_weight.sum(dim=2, keepdims=True) + state_dict['patch_embed.proj.weight'] = patch_embed_weight + swin_backbone.load_state_dict(state_dict) + + self.patch_embed = swin_backbone.patch_embed + self.pos_drop = swin_backbone.pos_drop + self.layers = swin_backbone.layers[:-1] + self.downsamples = nn.ModuleList() + for layer in self.layers: + self.downsamples.append(layer.downsample) + layer.downsample = None + self.downsamples[ + -1] = None # downsampling after the last layer is not necessary + + self.layer_output_channels = [ + swin_backbone.embed_dim * 2**i for i in range(len(self.layers)) + ] + self.train_backbone = train_backbone + if not train_backbone: + for parameter in self.parameters(): + parameter.requires_grad_(False) + + def forward(self, samples: NestedTensor): + vid_frames = rearrange(samples.tensors, 't b c h w -> b c t h w') + + vid_embeds = self.patch_embed(vid_frames) + vid_embeds = self.pos_drop(vid_embeds) + layer_outputs = [] # layer outputs before downsampling + for layer, downsample in zip(self.layers, self.downsamples): + vid_embeds = layer(vid_embeds.contiguous()) + layer_outputs.append(vid_embeds) + if downsample: + vid_embeds = rearrange(vid_embeds, 'b c t h w -> b t h w c') + vid_embeds = downsample(vid_embeds) + vid_embeds = rearrange(vid_embeds, 'b t h w c -> b c t h w') + layer_outputs = [ + rearrange(o, 'b c t h w -> t b c h w') for o in layer_outputs + ] + + outputs = [] + orig_pad_mask = samples.mask + for l_out in layer_outputs: + pad_mask = F.interpolate( + orig_pad_mask.float(), size=l_out.shape[-2:]).to(torch.bool) + outputs.append(NestedTensor(l_out, pad_mask)) + return outputs + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + Modified from DETR https://github.com/facebookresearch/detr + BatchNorm2d where the batch statistics and the affine parameters are fixed. + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer('weight', torch.ones(n)) + self.register_buffer('bias', torch.zeros(n)) + self.register_buffer('running_mean', torch.zeros(n)) + self.register_buffer('running_var', torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, + self)._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, + unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class ResNetBackbone(nn.Module): + """ + Modified from DETR https://github.com/facebookresearch/detr + ResNet backbone with frozen BatchNorm. 
+ """ + + def __init__(self, + backbone_name: str = 'resnet50', + train_backbone: bool = True, + dilation: bool = True, + **kwargs): + super(ResNetBackbone, self).__init__() + backbone = getattr(torchvision.models, backbone_name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), + norm_layer=FrozenBatchNorm2d) + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + return_layers = { + 'layer1': '0', + 'layer2': '1', + 'layer3': '2', + 'layer4': '3' + } + self.body = IntermediateLayerGetter( + backbone, return_layers=return_layers) + output_channels = 512 if backbone_name in ('resnet18', + 'resnet34') else 2048 + self.layer_output_channels = [ + output_channels // 8, output_channels // 4, output_channels // 2, + output_channels + ] + + def forward(self, tensor_list: NestedTensor): + t, b, _, _, _ = tensor_list.tensors.shape + video_frames = rearrange(tensor_list.tensors, + 't b c h w -> (t b) c h w') + padding_masks = rearrange(tensor_list.mask, 't b h w -> (t b) h w') + features_list = self.body(video_frames) + out = [] + for _, f in features_list.items(): + resized_padding_masks = F.interpolate( + padding_masks[None].float(), + size=f.shape[-2:]).to(torch.bool)[0] + f = rearrange(f, '(t b) c h w -> t b c h w', t=t, b=b) + resized_padding_masks = rearrange( + resized_padding_masks, '(t b) h w -> t b h w', t=t, b=b) + out.append(NestedTensor(f, resized_padding_masks)) + return out + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +def init_backbone(backbone_name, **kwargs): + if backbone_name == 'swin-t': + return VideoSwinTransformerBackbone(**kwargs) + elif 'resnet' in backbone_name: + return ResNetBackbone(backbone_name, **kwargs) + assert False, f'error: backbone "{backbone_name}" is not supported' diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/misc.py b/modelscope/models/cv/referring_video_object_segmentation/utils/misc.py new file mode 100644 index 00000000..ecf34b8c --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/misc.py @@ -0,0 +1,234 @@ +# Modified from DETR https://github.com/facebookresearch/detr +# Misc functions. +# Mostly copy-paste from torchvision references. 
+ +import pickle +from typing import List, Optional + +import torch +import torch.distributed as dist +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +from torch import Tensor + +if float(torchvision.__version__.split('.')[1]) < 7.0: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to('cuda') + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device='cuda') + size_list = [torch.tensor([0], device='cuda') for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append( + torch.empty((max_size, ), dtype=torch.uint8, device='cuda')) + if local_size != max_size: + padding = torch.empty( + size=(max_size - local_size, ), dtype=torch.uint8, device='cuda') + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + """ + This function receives a list of image tensors and returns a NestedTensor of the padded images, along with their + padding masks (true for padding areas, false otherwise). 
+ """ + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img) + m[:img.shape[1], :img.shape[2]] = False + return NestedTensor(tensor, mask) + + +def nested_tensor_from_videos_list(videos_list: List[Tensor]): + """ + This function receives a list of videos (each of shape [T, C, H, W]) and returns a NestedTensor of the padded + videos (shape [T, B, C, PH, PW], along with their padding masks (true for padding areas, false otherwise, of shape + [T, B, PH, PW]. + """ + max_size = _max_by_axis([list(img.shape) for img in videos_list]) + padded_batch_shape = [len(videos_list)] + max_size + b, t, c, h, w = padded_batch_shape + dtype = videos_list[0].dtype + device = videos_list[0].device + padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device) + videos_pad_masks = torch.ones((b, t, h, w), + dtype=torch.bool, + device=device) + for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, + padded_videos, + videos_pad_masks): + pad_vid_frames[:vid_frames.shape[0], :, :vid_frames. + shape[2], :vid_frames.shape[3]].copy_(vid_frames) + vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames. + shape[3]] = False + # transpose the temporal and batch dims and create a NestedTensor: + return NestedTensor( + padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1)) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def interpolate(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. 
+ """ + if float(torchvision.__version__.split('.')[1]) < 7.0: + if input.numel() > 0: + return torch.nn.functional.interpolate(input, size, scale_factor, + mode, align_corners) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, + mode, align_corners) diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py b/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py new file mode 100644 index 00000000..e603df6c --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py @@ -0,0 +1,128 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from .backbone import init_backbone +from .misc import NestedTensor +from .multimodal_transformer import MultimodalTransformer +from .segmentation import FPNSpatialDecoder + + +class MTTR(nn.Module): + """ The main module of the Multimodal Tracking Transformer """ + + def __init__(self, + num_queries, + mask_kernels_dim=8, + aux_loss=False, + **kwargs): + """ + Parameters: + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + MTTR can detect in a single image. In our paper we use 50 in all settings. + mask_kernels_dim: dim of the segmentation kernels and of the feature maps outputted by the spatial decoder. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.backbone = init_backbone(**kwargs) + self.transformer = MultimodalTransformer(**kwargs) + d_model = self.transformer.d_model + self.is_referred_head = nn.Linear( + d_model, + 2) # binary 'is referred?' prediction head for object queries + self.instance_kernels_head = MLP( + d_model, d_model, output_dim=mask_kernels_dim, num_layers=2) + self.obj_queries = nn.Embedding( + num_queries, d_model) # pos embeddings for the object queries + self.vid_embed_proj = nn.Conv2d( + self.backbone.layer_output_channels[-1], d_model, kernel_size=1) + self.spatial_decoder = FPNSpatialDecoder( + d_model, self.backbone.layer_output_channels[:-1][::-1], + mask_kernels_dim) + self.aux_loss = aux_loss + + def forward(self, samples: NestedTensor, valid_indices, text_queries): + """The forward expects a NestedTensor, which consists of: + - samples.tensor: Batched frames of shape [time x batch_size x 3 x H x W] + - samples.mask: A binary mask of shape [time x batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_is_referred": The reference prediction logits for all queries. + Shape: [time x batch_size x num_queries x 2] + - "pred_masks": The mask logits for all queries. + Shape: [time x batch_size x num_queries x H_mask x W_mask] + - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of + dictionaries containing the two above keys for each decoder layer. + """ + backbone_out = self.backbone(samples) + # keep only the valid frames (frames which are annotated): + # (for example, in a2d-sentences only the center frame in each window is annotated). 
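+        # Illustrative shape bookkeeping: with a window of T frames and only the center frame
+        # annotated, valid_indices might be e.g. torch.tensor([T // 2], device=...), so each
+        # backbone output is narrowed along the temporal (first) dim from [T, B, ...] to
+        # [len(valid_indices), B, ...] by the index_select calls below.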
+ for layer_out in backbone_out: + layer_out.tensors = layer_out.tensors.index_select( + 0, valid_indices) + layer_out.mask = layer_out.mask.index_select(0, valid_indices) + bbone_final_layer_output = backbone_out[-1] + vid_embeds, vid_pad_mask = bbone_final_layer_output.decompose() + + T, B, _, _, _ = vid_embeds.shape + vid_embeds = rearrange(vid_embeds, 't b c h w -> (t b) c h w') + vid_embeds = self.vid_embed_proj(vid_embeds) + vid_embeds = rearrange( + vid_embeds, '(t b) c h w -> t b c h w', t=T, b=B) + + transformer_out = self.transformer(vid_embeds, vid_pad_mask, + text_queries, + self.obj_queries.weight) + # hs is: [L, T, B, N, D] where L is number of decoder layers + # vid_memory is: [T, B, D, H, W] + # txt_memory is a list of length T*B of [S, C] where S might be different for each sentence + # encoder_middle_layer_outputs is a list of [T, B, H, W, D] + hs, vid_memory, txt_memory = transformer_out + + vid_memory = rearrange(vid_memory, 't b d h w -> (t b) d h w') + bbone_middle_layer_outputs = [ + rearrange(o.tensors, 't b d h w -> (t b) d h w') + for o in backbone_out[:-1][::-1] + ] + decoded_frame_features = self.spatial_decoder( + vid_memory, bbone_middle_layer_outputs) + decoded_frame_features = rearrange( + decoded_frame_features, '(t b) d h w -> t b d h w', t=T, b=B) + instance_kernels = self.instance_kernels_head(hs) # [L, T, B, N, C] + # output masks is: [L, T, B, N, H_mask, W_mask] + output_masks = torch.einsum('ltbnc,tbchw->ltbnhw', instance_kernels, + decoded_frame_features) + outputs_is_referred = self.is_referred_head(hs) # [L, T, B, N, 2] + + layer_outputs = [] + for pm, pir in zip(output_masks, outputs_is_referred): + layer_out = {'pred_masks': pm, 'pred_is_referred': pir} + layer_outputs.append(layer_out) + out = layer_outputs[ + -1] # the output for the last decoder layer is used by default + if self.aux_loss: + out['aux_outputs'] = layer_outputs[:-1] + return out + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py new file mode 100644 index 00000000..8c24e397 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py @@ -0,0 +1,440 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# MTTR Multimodal Transformer class. 
+# Modified from DETR https://github.com/facebookresearch/detr + +import copy +import os +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import Tensor, nn +from transformers import RobertaModel, RobertaTokenizerFast + +from .position_encoding_2d import PositionEmbeddingSine2D + +os.environ[ + 'TOKENIZERS_PARALLELISM'] = 'false' # this disables a huggingface tokenizer warning (printed every epoch) + + +class MultimodalTransformer(nn.Module): + + def __init__(self, + num_encoder_layers=3, + num_decoder_layers=3, + text_encoder_type='roberta-base', + freeze_text_encoder=True, + **kwargs): + super().__init__() + self.d_model = kwargs['d_model'] + encoder_layer = TransformerEncoderLayer(**kwargs) + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers) + decoder_layer = TransformerDecoderLayer(**kwargs) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + norm=nn.LayerNorm(self.d_model), + return_intermediate=True) + self.pos_encoder_2d = PositionEmbeddingSine2D() + self._reset_parameters() + + self.text_encoder = RobertaModel.from_pretrained(text_encoder_type) + self.text_encoder.pooler = None # this pooler is never used, this is a hack to avoid DDP problems... + self.tokenizer = RobertaTokenizerFast.from_pretrained( + text_encoder_type) + self.freeze_text_encoder = freeze_text_encoder + if freeze_text_encoder: + for p in self.text_encoder.parameters(): + p.requires_grad_(False) + + self.txt_proj = FeatureResizer( + input_feat_size=self.text_encoder.config.hidden_size, + output_feat_size=self.d_model, + dropout=kwargs['dropout'], + ) + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, vid_embeds, vid_pad_mask, text_queries, obj_queries): + device = vid_embeds.device + t, b, _, h, w = vid_embeds.shape + + txt_memory, txt_pad_mask = self.forward_text(text_queries, device) + # add temporal dim to txt memory & padding mask: + txt_memory = repeat(txt_memory, 's b c -> s (t b) c', t=t) + txt_pad_mask = repeat(txt_pad_mask, 'b s -> (t b) s', t=t) + + vid_embeds = rearrange(vid_embeds, 't b c h w -> (h w) (t b) c') + # Concat the image & text embeddings on the sequence dimension + encoder_src_seq = torch.cat((vid_embeds, txt_memory), dim=0) + seq_mask = torch.cat( + (rearrange(vid_pad_mask, 't b h w -> (t b) (h w)'), txt_pad_mask), + dim=1) + # vid_pos_embed is: [T*B, H, W, d_model] + vid_pos_embed = self.pos_encoder_2d( + rearrange(vid_pad_mask, 't b h w -> (t b) h w'), self.d_model) + # use zeros in place of pos embeds for the text sequence: + pos_embed = torch.cat( + (rearrange(vid_pos_embed, 't_b h w c -> (h w) t_b c'), + torch.zeros_like(txt_memory)), + dim=0) + + memory = self.encoder( + encoder_src_seq, src_key_padding_mask=seq_mask, + pos=pos_embed) # [S, T*B, C] + vid_memory = rearrange( + memory[:h * w, :, :], + '(h w) (t b) c -> t b c h w', + h=h, + w=w, + t=t, + b=b) + txt_memory = memory[h * w:, :, :] + txt_memory = rearrange(txt_memory, 's t_b c -> t_b s c') + txt_memory = [ + t_mem[~pad_mask] + for t_mem, pad_mask in zip(txt_memory, txt_pad_mask) + ] # remove padding + + # add T*B dims to query embeds (was: [N, C], where N is the number of object queries): + obj_queries = repeat(obj_queries, 'n c -> n (t b) c', t=t, b=b) + tgt = torch.zeros_like(obj_queries) # [N, T*B, C] + + # hs is [L, N, T*B, C] where L is number of layers in the decoder + hs = self.decoder( + tgt, + memory, + 
memory_key_padding_mask=seq_mask, + pos=pos_embed, + query_pos=obj_queries) + hs = rearrange(hs, 'l n (t b) c -> l t b n c', t=t, b=b) + return hs, vid_memory, txt_memory + + def forward_text(self, text_queries, device): + tokenized_queries = self.tokenizer.batch_encode_plus( + text_queries, padding='longest', return_tensors='pt') + tokenized_queries = tokenized_queries.to(device) + with torch.inference_mode(mode=self.freeze_text_encoder): + encoded_text = self.text_encoder(**tokenized_queries) + # Transpose memory because pytorch's attention expects sequence first + txt_memory = rearrange(encoded_text.last_hidden_state, + 'b s c -> s b c') + txt_memory = self.txt_proj( + txt_memory) # change text embeddings dim to model dim + # Invert attention mask that we get from huggingface because its the opposite in pytorch transformer + txt_pad_mask = tokenized_queries.attention_mask.ne(1).bool() # [B, S] + return txt_memory, txt_pad_mask + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, + decoder_layer, + num_layers, + norm=None, + return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, + d_model, + nheads, + dim_feedforward=2048, + dropout=0.1, + activation='relu', + normalize_before=False, + **kwargs): + super().__init__() + self.self_attn = nn.MultiheadAttention( + d_model, nheads, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = 
normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn( + q, + k, + value=src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, + k, + value=src2, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, + d_model, + nheads, + dim_feedforward=2048, + dropout=0.1, + activation='relu', + normalize_before=False, + **kwargs): + super().__init__() + self.self_attn = nn.MultiheadAttention( + d_model, nheads, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention( + d_model, nheads, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, + k, + value=tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + 
memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, + k, + value=tgt2, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, + pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class FeatureResizer(nn.Module): + """ + This class takes as input a set of embeddings of dimension C1 and outputs a set of + embedding of dimension C2, after a linear transformation, dropout and normalization (LN). + """ + + def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True): + super().__init__() + self.do_ln = do_ln + # Object feature encoding + self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True) + self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12) + self.dropout = nn.Dropout(dropout) + + def forward(self, encoder_features): + x = self.fc(encoder_features) + if self.do_ln: + x = self.layer_norm(x) + output = self.dropout(x) + return output + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == 'relu': + return F.relu + if activation == 'gelu': + return F.gelu + if activation == 'glu': + return F.glu + raise RuntimeError(F'activation should be relu/gelu, not {activation}.') diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/position_encoding_2d.py b/modelscope/models/cv/referring_video_object_segmentation/utils/position_encoding_2d.py new file mode 100644 index 00000000..f9ef05a1 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/position_encoding_2d.py @@ -0,0 +1,57 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr +# 2D sine positional encodings for the visual features in the multimodal transformer. + +import math + +import torch +from torch import Tensor, nn + + +class PositionEmbeddingSine2D(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
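+    Roughly, for each spatial axis the normalized cumulative position p is mapped to
+    sin(p / temperature^(2i / d)) on even channels and cos(p / temperature^(2i / d)) on odd
+    channels, with d = hidden_dim // 2 per axis, and the y and x encodings are then
+    concatenated along the channel dimension.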
+ """ + + def __init__(self, temperature=10000, normalize=True, scale=None): + super().__init__() + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError('normalize should be True if scale is passed') + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, mask: Tensor, hidden_dim: int): + """ + @param mask: a tensor of shape [B, H, W] + @param hidden_dim: int + @return: + """ + num_pos_feats = hidden_dim // 2 + + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature**(2 * (dim_t // 2) / num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3) + return pos diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/postprocessing.py b/modelscope/models/cv/referring_video_object_segmentation/utils/postprocessing.py new file mode 100644 index 00000000..64582140 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/postprocessing.py @@ -0,0 +1,119 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR + +import numpy as np +import pycocotools.mask as mask_util +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class A2DSentencesPostProcess(nn.Module): + """ + This module converts the model's output into the format expected by the coco api for the given task + """ + + def __init__(self): + super(A2DSentencesPostProcess, self).__init__() + + @torch.inference_mode() + def forward(self, outputs, resized_padded_sample_size, + resized_sample_sizes, orig_sample_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + resized_padded_sample_size: size of samples (input to model) after size augmentation + padding. + resized_sample_sizes: size of samples after size augmentation but without padding. 
+ orig_sample_sizes: original size of the samples (no augmentations or padding) + """ + pred_is_referred = outputs['pred_is_referred'] + prob = F.softmax(pred_is_referred, dim=-1) + scores = prob[..., 0] + pred_masks = outputs['pred_masks'] + pred_masks = F.interpolate( + pred_masks, + size=resized_padded_sample_size, + mode='bilinear', + align_corners=False) + pred_masks = (pred_masks.sigmoid() > 0.5) + processed_pred_masks, rle_masks = [], [] + for f_pred_masks, resized_size, orig_size in zip( + pred_masks, resized_sample_sizes, orig_sample_sizes): + f_mask_h, f_mask_w = resized_size # resized shape without padding + # remove the samples' padding + f_pred_masks_no_pad = f_pred_masks[:, :f_mask_h, : + f_mask_w].unsqueeze(1) + # resize the samples back to their original dataset (target) size for evaluation + f_pred_masks_processed = F.interpolate( + f_pred_masks_no_pad.float(), size=orig_size, mode='nearest') + f_pred_rle_masks = [ + mask_util.encode( + np.array( + mask[0, :, :, np.newaxis], dtype=np.uint8, + order='F'))[0] + for mask in f_pred_masks_processed.cpu() + ] + processed_pred_masks.append(f_pred_masks_processed) + rle_masks.append(f_pred_rle_masks) + predictions = [{ + 'scores': s, + 'masks': m, + 'rle_masks': rle + } for s, m, rle in zip(scores, processed_pred_masks, rle_masks)] + return predictions + + +class ReferYoutubeVOSPostProcess(nn.Module): + """ + This module converts the model's output into the format expected by the coco api for the given task + """ + + def __init__(self): + super(ReferYoutubeVOSPostProcess, self).__init__() + + @torch.inference_mode() + def forward(self, outputs, videos_metadata, samples_shape_with_padding): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + videos_metadata: a dictionary with each video's metadata. + samples_shape_with_padding: size of the batch frames with padding. 
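+        Returns:
+            a list with one dict per video, containing that video's metadata plus its predicted
+            binary masks (uint8 tensors resized back to the original frame size) under 'pred_masks'.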
+ """ + pred_is_referred = outputs['pred_is_referred'] + prob_is_referred = F.softmax(pred_is_referred, dim=-1) + # note we average on the temporal dim to compute score per trajectory: + trajectory_scores = prob_is_referred[..., 0].mean(dim=0) + pred_trajectory_indices = torch.argmax(trajectory_scores, dim=-1) + pred_masks = rearrange(outputs['pred_masks'], + 't b nq h w -> b t nq h w') + # keep only the masks of the chosen trajectories: + b = pred_masks.shape[0] + pred_masks = pred_masks[torch.arange(b), :, pred_trajectory_indices] + # resize the predicted masks to the size of the model input (which might include padding) + pred_masks = F.interpolate( + pred_masks, + size=samples_shape_with_padding, + mode='bilinear', + align_corners=False) + # apply a threshold to create binary masks: + pred_masks = (pred_masks.sigmoid() > 0.5) + # remove the padding per video (as videos might have different resolutions and thus different padding): + preds_by_video = [] + for video_pred_masks, video_metadata in zip(pred_masks, + videos_metadata): + # size of the model input batch frames without padding: + resized_h, resized_w = video_metadata['resized_frame_size'] + video_pred_masks = video_pred_masks[:, :resized_h, : + resized_w].unsqueeze( + 1) # remove the padding + # resize the masks back to their original frames dataset size for evaluation: + original_frames_size = video_metadata['original_frame_size'] + tuple_size = tuple(original_frames_size.cpu().numpy()) + video_pred_masks = F.interpolate( + video_pred_masks.float(), size=tuple_size, mode='nearest') + video_pred_masks = video_pred_masks.to(torch.uint8).cpu() + # combine the predicted masks and the video metadata to create a final predictions dict: + video_pred = {**video_metadata, **{'pred_masks': video_pred_masks}} + preds_by_video.append(video_pred) + return preds_by_video diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/segmentation.py b/modelscope/models/cv/referring_video_object_segmentation/utils/segmentation.py new file mode 100644 index 00000000..b3228820 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/segmentation.py @@ -0,0 +1,137 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class FPNSpatialDecoder(nn.Module): + """ + An FPN-like spatial decoder. Generates high-res, semantically rich features which serve as the base for creating + instance segmentation masks. 
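+    The coarse context features are passed through 3x3 conv + GroupNorm + ReLU blocks, upsampled
+    (nearest) to each higher-resolution backbone feature map and fused with it through a 1x1
+    adapter conv and element-wise addition, and finally projected by a 3x3 conv to
+    mask_kernels_dim output channels.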
+ """ + + def __init__(self, context_dim, fpn_dims, mask_kernels_dim=8): + super().__init__() + + inter_dims = [ + context_dim, context_dim // 2, context_dim // 4, context_dim // 8, + context_dim // 16 + ] + self.lay1 = torch.nn.Conv2d(context_dim, inter_dims[0], 3, padding=1) + self.gn1 = torch.nn.GroupNorm(8, inter_dims[0]) + self.lay2 = torch.nn.Conv2d(inter_dims[0], inter_dims[1], 3, padding=1) + self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) + self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) + self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) + self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.context_dim = context_dim + + self.add_extra_layer = len(fpn_dims) == 3 + if self.add_extra_layer: + self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + self.lay5 = torch.nn.Conv2d( + inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) + self.out_lay = torch.nn.Conv2d( + inter_dims[4], mask_kernels_dim, 3, padding=1) + else: + self.out_lay = torch.nn.Conv2d( + inter_dims[3], mask_kernels_dim, 3, padding=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor, layer_features: List[Tensor]): + x = self.lay1(x) + x = self.gn1(x) + x = F.relu(x) + x = self.lay2(x) + x = self.gn2(x) + x = F.relu(x) + + cur_fpn = self.adapter1(layer_features[0]) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode='nearest') + x = self.lay3(x) + x = self.gn3(x) + x = F.relu(x) + + cur_fpn = self.adapter2(layer_features[1]) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode='nearest') + x = self.lay4(x) + x = self.gn4(x) + x = F.relu(x) + + if self.add_extra_layer: + cur_fpn = self.adapter3(layer_features[2]) + x = cur_fpn + F.interpolate( + x, size=cur_fpn.shape[-2:], mode='nearest') + x = self.lay5(x) + x = self.gn5(x) + x = F.relu(x) + + x = self.out_lay(x) + return x + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +def dice_loss(inputs, targets, num_masks): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + +def sigmoid_focal_loss(inputs, + targets, + num_masks, + alpha: float = 0.25, + gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). 
+ gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits( + inputs, targets, reduction='none') + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t)**gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_masks diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py new file mode 100644 index 00000000..9a08ef48 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py @@ -0,0 +1,731 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from Video-Swin-Transformer https://github.com/SwinTransformer/Video-Swin-Transformer + +from functools import lru_cache, reduce +from operator import mul + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.models.layers import DropPath, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, D, H, W, C) + window_size (tuple[int]): window size + + Returns: + windows: (B*num_windows, window_size*window_size, C) + """ + B, D, H, W, C = x.shape + x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], + window_size[1], W // window_size[2], window_size[2], C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, + 7).contiguous().view(-1, reduce(mul, window_size), C) + return windows + + +def window_reverse(windows, window_size, B, D, H, W): + """ + Args: + windows: (B*num_windows, window_size, window_size, C) + window_size (tuple[int]): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, D, H, W, C) + """ + x = windows.view(B, D // window_size[0], H // window_size[1], + W // window_size[2], window_size[0], window_size[1], + window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +class WindowAttention3D(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. 
+ window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wd, Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + wd, wh, ww = window_size + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, + coords_w)) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.window_size[1] + - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, N, N) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:N, :N].reshape(-1)].reshape( + N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock3D(nn.Module): + """ Swin Transformer Block. 
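+    Each block applies window-based multi-head self-attention (W-MSA, or its shifted variant
+    SW-MSA when shift_size is non-zero) followed by an MLP, both wrapped with LayerNorm,
+    residual connections and optional stochastic depth (DropPath).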
+ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=(2, 7, 7), + shift_size=(0, 0, 0), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_checkpoint=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + + assert 0 <= self.shift_size[0] < self.window_size[ + 0], 'shift_size must in 0-window_size' + assert 0 <= self.shift_size[1] < self.window_size[ + 1], 'shift_size must in 0-window_size' + assert 0 <= self.shift_size[2] < self.window_size[ + 2], 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention3D( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward_part1(self, x, mask_matrix): + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, + self.shift_size) + + x = self.norm1(x) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, Dp, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll( + x, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + # partition windows + x_windows = window_partition(shifted_x, + window_size) # B*nW, Wd*Wh*Ww, C + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C + # merge windows + attn_windows = attn_windows.view(-1, *(window_size + (C, ))) + shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, + Wp) # B D' H' W' C + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll( + shifted_x, + shifts=(shift_size[0], shift_size[1], shift_size[2]), + dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :].contiguous() + return x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). + mask_matrix: Attention mask for cyclic shift. + """ + + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). 
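+        Returns:
+            Tensor of size (B, D, H/2, W/2, 2*C): 2x2 spatial neighbours are concatenated into 4*C
+            channels (odd H or W is padded first), normalized, and linearly reduced to 2*C.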
+ """ + B, D, H, W, C = x.shape + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C + x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C + x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C + x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +# cache each stage results +@lru_cache() +def compute_mask(D, H, W, window_size, shift_size, device): + img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], + -shift_size[0]), slice( + -shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], + -shift_size[1]), slice( + -shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], + -shift_size[2]), slice( + -shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, + window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + return attn_mask + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (1,7,7). + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(1, 7, 7), + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock3D( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) for i in range(depth) + ]) + + self.downsample = downsample + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, C, D, H, W). 
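+        Returns:
+            Tensor of size (B, C, D, H, W), or (B, 2*C, D, H/2, W/2) when a downsample
+            (PatchMerging) layer is attached to this stage.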
+ """ + # calculate attention mask for SW-MSA + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, + self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, 'b d h w c -> b c d h w') + return x + + +class PatchEmbed3D(nn.Module): + """ Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad( + x, + (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + x = self.proj(x) # B C D Wh Ww + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + + return x + + +class SwinTransformer3D(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + patch_size (int | tuple(int)): Patch size. Default: (4,4,4). + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer: Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. 
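+        use_checkpoint (bool): If True, use torch.utils.checkpoint inside the transformer blocks
+            to trade compute for memory. Default: False.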
+ """ + + def __init__(self, + pretrained=None, + pretrained2d=True, + patch_size=(4, 4, 4), + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(2, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=False, + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrained = pretrained + self.pretrained2d = pretrained2d + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + self.window_size = window_size + self.patch_size = patch_size + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging + if i_layer < self.num_layers - 1 else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.num_features = int(embed_dim * 2**(self.num_layers - 1)) + + # add a norm layer for each output + self.norm = norm_layer(self.num_features) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1: + self.pos_drop.eval() + for i in range(0, self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def inflate_weights(self, logger): + """Inflate the swin2d parameters to swin3d. + + The differences between swin3d and swin2d mainly lie in an extra + axis. To utilize the pretrained parameters in 2d model, + the weight of swin2d models should be inflated to fit in the shapes of + the 3d counterpart. + + Args: + logger (logging.Logger): The logger used to print + debugging infomation. 
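+        Note: the checkpoint is expected to hold the 2D weights under the 'model' key; relative
+        position bias tables are bicubic-resized to the 3D window size and repeated along the
+        temporal axis, while relative_position_index and attn_mask buffers are re-initialized.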
+ """ + checkpoint = torch.load(self.pretrained, map_location='cpu') + state_dict = checkpoint['model'] + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_position_index' in k + ] + for k in relative_position_index_keys: + del state_dict[k] + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in state_dict.keys() if 'attn_mask' in k] + for k in attn_mask_keys: + del state_dict[k] + + state_dict['patch_embed.proj.weight'] = state_dict[ + 'patch_embed.proj.weight'].unsqueeze(2).repeat( + 1, 1, self.patch_size[0], 1, 1) / self.patch_size[0] + + # bicubic interpolate relative_position_bias_table if not match + relative_position_bias_table_keys = [ + k for k in state_dict.keys() if 'relative_position_bias_table' in k + ] + for k in relative_position_bias_table_keys: + relative_position_bias_table_pretrained = state_dict[k] + relative_position_bias_table_current = self.state_dict()[k] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + wd = self.window_size[0] + if nH1 != nH2: + logger.warning(f'Error in loading {k}, passing') + else: + if L1 != L2: + S1 = int(L1**0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pretrained.permute( + 1, 0).view(1, nH1, S1, S1), + size=(2 * self.window_size[1] - 1, + 2 * self.window_size[2] - 1), + mode='bicubic') + relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view( + nH2, L2).permute(1, 0) + state_dict[k] = relative_position_bias_table_pretrained.repeat( + 2 * wd - 1, 1) + + msg = self.load_state_dict(state_dict, strict=False) + logger.info(msg) + logger.info(f"=> loaded successfully '{self.pretrained}'") + del checkpoint + torch.cuda.empty_cache() + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x.contiguous()) + + x = rearrange(x, 'n c d h w -> n d h w c') + x = self.norm(x) + x = rearrange(x, 'n d h w c -> n c d h w') + + return x + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer3D, self).train(mode) + self._freeze_stages() diff --git a/modelscope/models/cv/super_resolution/arch_util.py b/modelscope/models/cv/super_resolution/arch_util.py index 4b87c877..99711a11 100644 --- a/modelscope/models/cv/super_resolution/arch_util.py +++ b/modelscope/models/cv/super_resolution/arch_util.py @@ -1,3 +1,5 @@ +# The implementation is adopted from BasicSR, made public available under the Apache 2.0 License +# at https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/arch_util.py import collections.abc import math import warnings diff --git a/modelscope/models/cv/super_resolution/rrdbnet_arch.py b/modelscope/models/cv/super_resolution/rrdbnet_arch.py index 44947de1..8c84f796 100644 --- a/modelscope/models/cv/super_resolution/rrdbnet_arch.py +++ b/modelscope/models/cv/super_resolution/rrdbnet_arch.py @@ -1,3 +1,5 @@ +# The implementation is adopted from BasicSR, made public available under the Apache 2.0 License +# at https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/rrdbnet_arch.py import torch from torch import nn as nn from torch.nn import functional as F diff --git 
a/modelscope/models/cv/text_driven_segmentation/lseg_model.py b/modelscope/models/cv/text_driven_segmentation/lseg_model.py index 9a5754c6..ec381356 100644 --- a/modelscope/models/cv/text_driven_segmentation/lseg_model.py +++ b/modelscope/models/cv/text_driven_segmentation/lseg_model.py @@ -93,7 +93,7 @@ class TextDrivenSeg(TorchModel): """ with torch.no_grad(): if self.device_id == -1: - output = self.model(image) + output = self.model(image, [text]) else: device = torch.device('cuda', self.device_id) output = self.model(image.to(device), [text]) diff --git a/modelscope/models/cv/tinynas_detection/__init__.py b/modelscope/models/cv/tinynas_detection/__init__.py index 13532d10..6d696ac4 100644 --- a/modelscope/models/cv/tinynas_detection/__init__.py +++ b/modelscope/models/cv/tinynas_detection/__init__.py @@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .tinynas_detector import Tinynas_detector + from .tinynas_damoyolo import DamoYolo else: _import_structure = { 'tinynas_detector': ['TinynasDetector'], + 'tinynas_damoyolo': ['DamoYolo'], } import sys diff --git a/modelscope/models/cv/tinynas_detection/backbone/tinynas.py b/modelscope/models/cv/tinynas_detection/backbone/tinynas.py index 814ee550..87a28a2f 100755 --- a/modelscope/models/cv/tinynas_detection/backbone/tinynas.py +++ b/modelscope/models/cv/tinynas_detection/backbone/tinynas.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from modelscope.utils.file_utils import read_file from ..core.base_ops import Focus, SPPBottleneck, get_activation from ..core.repvgg_block import RepVggBlock @@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module): kernel_size, stride, force_resproj=False, - act='silu'): + act='silu', + reparam=False): super(ResConvK1KX, self).__init__() self.stride = stride self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) - self.conv2 = RepVggBlock( - btn_c, out_c, kernel_size, stride, act='identity') + if not reparam: + self.conv2 = ConvKXBN(btn_c, out_c, 3, stride) + else: + self.conv2 = RepVggBlock( + btn_c, out_c, kernel_size, stride, act='identity') if act is None: self.activation_function = torch.relu @@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module): stride, num_blocks, with_spp=False, - act='silu'): + act='silu', + reparam=False): super(SuperResConvK1KX, self).__init__() if act is None: self.act = torch.relu @@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module): this_kernel_size, this_stride, force_resproj, - act=act) + act=act, + reparam=reparam) self.block_list.append(the_block) if block_id == 0 and with_spp: self.block_list.append( @@ -248,7 +255,8 @@ class TinyNAS(nn.Module): with_spp=False, use_focus=False, need_conv1=True, - act='silu'): + act='silu', + reparam=False): super(TinyNAS, self).__init__() assert len(out_indices) == len(out_channels) self.out_indices = out_indices @@ -281,7 +289,8 @@ class TinyNAS(nn.Module): block_info['s'], block_info['L'], spp, - act=act) + act=act, + reparam=reparam) self.block_list.append(the_block) elif the_block_class == 'SuperResConvKXKX': spp = with_spp if idx == len(structure_info) - 1 else False @@ -325,8 +334,8 @@ class TinyNAS(nn.Module): def load_tinynas_net(backbone_cfg): # load masternet model to path import ast - - struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str]) + net_structure_str = read_file(backbone_cfg.structure_file) + struct_str = ''.join([x.strip() for x in net_structure_str]) struct_info = ast.literal_eval(struct_str) for layer in struct_info: if 'nbitsA' in layer: @@ -342,6 +351,6 @@ def 
load_tinynas_net(backbone_cfg): use_focus=backbone_cfg.use_focus, act=backbone_cfg.act, need_conv1=backbone_cfg.need_conv1, - ) + reparam=backbone_cfg.reparam) return model diff --git a/modelscope/models/cv/tinynas_detection/detector.py b/modelscope/models/cv/tinynas_detection/detector.py index 615b13a8..42a71381 100644 --- a/modelscope/models/cv/tinynas_detection/detector.py +++ b/modelscope/models/cv/tinynas_detection/detector.py @@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel): """ super().__init__(model_dir, *args, **kwargs) - config_path = osp.join(model_dir, 'airdet_s.py') + config_path = osp.join(model_dir, self.config_name) config = parse_config(config_path) self.cfg = config model_path = osp.join(model_dir, config.model.name) @@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel): self.conf_thre = config.model.head.nms_conf_thre self.nms_thre = config.model.head.nms_iou_thre + if self.cfg.model.backbone.name == 'TinyNAS': + self.cfg.model.backbone.structure_file = osp.join( + model_dir, self.cfg.model.backbone.structure_file) self.backbone = build_backbone(self.cfg.model.backbone) self.neck = build_neck(self.cfg.model.neck) self.head = build_head(self.cfg.model.head) diff --git a/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py b/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py index 41f35968..66904ed1 100644 --- a/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py +++ b/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py @@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module): simOTA_iou_weight=3.0, octbase=8, simlqe=False, + use_lqe=True, **kwargs): self.simlqe = simlqe self.num_classes = num_classes self.in_channels = in_channels self.strides = strides + self.use_lqe = use_lqe self.feat_channels = feat_channels if isinstance(feat_channels, list) \ else [feat_channels] * len(self.strides) @@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module): groups=self.conv_groups, norm=self.norm, act=self.act)) - if not self.simlqe: - conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)] + if self.use_lqe: + if not self.simlqe: + conf_vector = [ + nn.Conv2d(4 * self.total_dim, self.reg_channels, 1) + ] + else: + conf_vector = [ + nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) + ] + conf_vector += [self.relu] + conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] + reg_conf = nn.Sequential(*conf_vector) else: - conf_vector = [ - nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) - ] - conf_vector += [self.relu] - conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] - reg_conf = nn.Sequential(*conf_vector) + reg_conf = None return cls_convs, reg_convs, reg_conf @@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module): N, C, H, W = bbox_pred.size() prob = F.softmax( bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) - if not self.simlqe: - prob_topk, _ = prob.topk(self.reg_topk, dim=2) - - if self.add_mean: - stat = torch.cat( - [prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2) + if self.use_lqe: + if not self.simlqe: + prob_topk, _ = prob.topk(self.reg_topk, dim=2) + + if self.add_mean: + stat = torch.cat( + [prob_topk, + prob_topk.mean(dim=2, keepdim=True)], + dim=2) + else: + stat = prob_topk + + quality_score = reg_conf( + stat.reshape(N, 4 * self.total_dim, H, W)) else: - stat = prob_topk + quality_score = reg_conf( + bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) - quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W)) + cls_score = 
gfl_cls(cls_feat).sigmoid() * quality_score else: - quality_score = reg_conf( - bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) - - cls_score = gfl_cls(cls_feat).sigmoid() * quality_score + cls_score = gfl_cls(cls_feat).sigmoid() flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py index b710572f..b88c39f2 100644 --- a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py +++ b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py @@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module): self, depth=1.0, width=1.0, - in_features=[2, 3, 4], in_channels=[256, 512, 1024], out_channels=[256, 512, 1024], depthwise=False, @@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module): block_name='BasicBlock', ): super().__init__() - self.in_features = in_features self.in_channels = in_channels Conv = DWConv if depthwise else BaseConv @@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module): """ # backbone - features = [out_features[f] for f in self.in_features] - [x2, x1, x0] = features + [x2, x1, x0] = out_features # node x3 x13 = self.bu_conv13(x1) diff --git a/modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py b/modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py new file mode 100644 index 00000000..9effad3a --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py @@ -0,0 +1,15 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks +from .detector import SingleStageDetector + + +@MODELS.register_module( + Tasks.image_object_detection, module_name=Models.tinynas_damoyolo) +class DamoYolo(SingleStageDetector): + + def __init__(self, model_dir, *args, **kwargs): + self.config_name = 'damoyolo_s.py' + super(DamoYolo, self).__init__(model_dir, *args, **kwargs) diff --git a/modelscope/models/cv/tinynas_detection/tinynas_detector.py b/modelscope/models/cv/tinynas_detection/tinynas_detector.py index e6f144df..92acf3fa 100644 --- a/modelscope/models/cv/tinynas_detection/tinynas_detector.py +++ b/modelscope/models/cv/tinynas_detection/tinynas_detector.py @@ -12,5 +12,5 @@ from .detector import SingleStageDetector class TinynasDetector(SingleStageDetector): def __init__(self, model_dir, *args, **kwargs): - + self.config_name = 'airdet_s.py' super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) diff --git a/modelscope/models/cv/video_summarization/__init__.py b/modelscope/models/cv/video_summarization/__init__.py index 064110f7..15ad61b4 100644 --- a/modelscope/models/cv/video_summarization/__init__.py +++ b/modelscope/models/cv/video_summarization/__init__.py @@ -1 +1,22 @@ -from .summarizer import PGLVideoSummarization +# Copyright (c) Alibaba, Inc. and its affiliates. 
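Note on the detector changes above: SingleStageDetector now reads its config file name from `self.config_name`, which each registered subclass sets before calling the base constructor (DamoYolo uses damoyolo_s.py, TinynasDetector keeps airdet_s.py). A minimal sketch of that pattern follows; the stand-in classes and the placeholder for parse_config are illustrative only, not the repo's real implementations, and the model path is made up.

import os.path as osp


class SingleStageDetectorSketch:
    """Stripped-down stand-in for detector.SingleStageDetector."""

    def __init__(self, model_dir):
        # Subclasses are expected to set self.config_name before super().__init__().
        config_path = osp.join(model_dir, self.config_name)
        self.cfg = {'config_path': config_path}  # placeholder for parse_config(config_path)


class DamoYoloSketch(SingleStageDetectorSketch):

    def __init__(self, model_dir):
        self.config_name = 'damoyolo_s.py'
        super().__init__(model_dir)


class TinynasDetectorSketch(SingleStageDetectorSketch):

    def __init__(self, model_dir):
        self.config_name = 'airdet_s.py'
        super().__init__(model_dir)


print(DamoYoloSketch('/path/to/model').cfg)  # {'config_path': '/path/to/model/damoyolo_s.py'}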
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .summarizer import (PGLVideoSummarization, summary_format) + +else: + _import_structure = { + 'summarizer': ['PGLVideoSummarization', 'summary_format'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/video_summarization/base_model.py b/modelscope/models/cv/video_summarization/base_model.py index 670da251..912ba68d 100644 --- a/modelscope/models/cv/video_summarization/base_model.py +++ b/modelscope/models/cv/video_summarization/base_model.py @@ -1,4 +1,5 @@ -# The implementation is based on pytorch-caffe-models, available at https://github.com/crowsonkb/pytorch-caffe-models. +# Part of the implementation is borrowed and modified from pytorch-caffe-models, +# publicly available at https://github.com/crowsonkb/pytorch-caffe-models import cv2 import numpy as np diff --git a/modelscope/models/cv/video_summarization/kts/cpd_auto.py b/modelscope/models/cv/video_summarization/kts/cpd_auto.py index a794ca26..58281df8 100644 --- a/modelscope/models/cv/video_summarization/kts/cpd_auto.py +++ b/modelscope/models/cv/video_summarization/kts/cpd_auto.py @@ -1,4 +1,5 @@ -# The implementation is based on KTS, available at https://github.com/TatsuyaShirakawa/KTS. +# Part of the implementation is borrowed and modified from KTS, +# publicly available at https://github.com/TatsuyaShirakawa/KTS import numpy as np diff --git a/modelscope/models/cv/video_summarization/kts/cpd_nonlin.py b/modelscope/models/cv/video_summarization/kts/cpd_nonlin.py index ef2eb6ef..55e279e9 100644 --- a/modelscope/models/cv/video_summarization/kts/cpd_nonlin.py +++ b/modelscope/models/cv/video_summarization/kts/cpd_nonlin.py @@ -1,4 +1,5 @@ -# The implementation is based on KTS, available at https://github.com/TatsuyaShirakawa/KTS. +# Part of the implementation is borrowed and modified from KTS, +# publicly available at https://github.com/TatsuyaShirakawa/KTS import numpy as np diff --git a/modelscope/models/cv/video_summarization/pgl_sum.py b/modelscope/models/cv/video_summarization/pgl_sum.py index ab3010c9..2d27501d 100644 --- a/modelscope/models/cv/video_summarization/pgl_sum.py +++ b/modelscope/models/cv/video_summarization/pgl_sum.py @@ -1,4 +1,5 @@ -# The implementation is based on PGL-SUM, available at https://github.com/e-apostolidis/PGL-SUM. +# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM import math diff --git a/modelscope/models/cv/video_summarization/summarizer.py b/modelscope/models/cv/video_summarization/summarizer.py index c95da025..c9987670 100644 --- a/modelscope/models/cv/video_summarization/summarizer.py +++ b/modelscope/models/cv/video_summarization/summarizer.py @@ -1,4 +1,5 @@ -# The implementation is based on PGL-SUM, available at https://github.com/e-apostolidis/PGL-SUM. 
+# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM import os.path as osp from copy import deepcopy @@ -23,7 +24,8 @@ logger = get_logger() def get_change_points(video_feat, n_frame): video_feat = np.array(video_feat, np.float32) K = np.dot(video_feat, video_feat.T) - change_points, _ = cpd_auto(K, ncp=120, vmax=2.2 / 4.0, lmin=1) + change_points, _ = cpd_auto( + K, ncp=min(K.shape[0] - 1, 120), vmax=2.2 / 4.0, lmin=1) change_points = change_points * 15 change_points = np.concatenate(([0], change_points, [n_frame - 1])) @@ -135,6 +137,46 @@ def generate_summary(all_shot_bound, all_scores, all_nframes, all_positions): return all_summaries +def transform_time(seconds): + m, s = divmod(seconds, 60) + h, m = divmod(m, 60) + time = '%02d:%02d:%06.3f' % (h, m, s) + return time + + +def summary_format(summary, fps): + frames_list = [] + start_frame = -1 + end_frame = -1 + is_summary_frame = False + for i, idx in enumerate(summary): + if idx: + if is_summary_frame is False: + start_frame = i + is_summary_frame = True + else: + if is_summary_frame: + end_frame = i - 1 + frames_list.append([start_frame, end_frame]) + is_summary_frame = False + + if is_summary_frame and summary[-1] == 1: + end_frame = len(summary) - 1 + frames_list.append([start_frame, end_frame]) + + output = [] + for seg in frames_list: + output.append({ + 'frame': + seg, + 'timestamps': [ + transform_time(seg[0] / float(fps)), + transform_time(seg[1] / float(fps)) + ] + }) + return output + + @MODELS.register_module( Tasks.video_summarization, module_name=Models.video_summarization) class PGLVideoSummarization(TorchModel): diff --git a/modelscope/models/multi_modal/clip/__init__.py b/modelscope/models/multi_modal/clip/__init__.py index 3fd492b9..e2e925ce 100644 --- a/modelscope/models/multi_modal/clip/__init__.py +++ b/modelscope/models/multi_modal/clip/__init__.py @@ -1 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from .model import CLIPForMultiModalEmbedding diff --git a/modelscope/models/multi_modal/clip/model.py b/modelscope/models/multi_modal/clip/model.py index 2fb0d7e3..92d9e11a 100644 --- a/modelscope/models/multi_modal/clip/model.py +++ b/modelscope/models/multi_modal/clip/model.py @@ -1,3 +1,18 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from collections import OrderedDict from typing import Any, Dict, Iterable, List, Tuple, Union diff --git a/modelscope/models/multi_modal/diffusion/__init__.py b/modelscope/models/multi_modal/diffusion/__init__.py index 28813cc9..e7e374b6 100644 --- a/modelscope/models/multi_modal/diffusion/__init__.py +++ b/modelscope/models/multi_modal/diffusion/__init__.py @@ -1 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
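A quick illustration of the summary_format/transform_time helpers added to summarizer.py above: they turn a per-frame 0/1 keyframe mask into [start, end] frame segments with HH:MM:SS.mmm timestamps. The mask and fps below are made-up sample values, and the import assumes the patched package (which now exports summary_format from the video_summarization module) is installed.

from modelscope.models.cv.video_summarization import summary_format

summary = [0, 1, 1, 1, 0, 0, 1, 1]  # toy per-frame keyframe mask
fps = 25.0

segments = summary_format(summary, fps)
# Expected, given the implementation above:
# [{'frame': [1, 3], 'timestamps': ['00:00:00.040', '00:00:00.120']},
#  {'frame': [6, 7], 'timestamps': ['00:00:00.240', '00:00:00.280']}]
print(segments)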
from .model import DiffusionForTextToImageSynthesis diff --git a/modelscope/models/multi_modal/gemm/__init__.py b/modelscope/models/multi_modal/gemm/__init__.py index b920628e..fe5df1fe 100644 --- a/modelscope/models/multi_modal/gemm/__init__.py +++ b/modelscope/models/multi_modal/gemm/__init__.py @@ -1 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from .gemm_model import GEMMForMultiModalEmbedding diff --git a/modelscope/models/multi_modal/gemm/gemm_base.py b/modelscope/models/multi_modal/gemm/gemm_base.py index 09ef2480..806c469c 100644 --- a/modelscope/models/multi_modal/gemm/gemm_base.py +++ b/modelscope/models/multi_modal/gemm/gemm_base.py @@ -543,6 +543,7 @@ class GEMMModel(nn.Module): img_feature, text_feature, caption = None, None, None if captioning and image is not None: img_feature, caption = self.model.image_to_text(image) + img_feature = self.parse_feat(img_feature) elif image is not None: img_feature = self.parse_feat(self.model.encode_image(image)) if text is not None: diff --git a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py index 5e8e2e7a..0cc040c6 100644 --- a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py +++ b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py @@ -1,9 +1,13 @@ -# The implementation is adopated from the CLIP4Clip implementation, +# The implementation is adopted from the CLIP4Clip implementation, # made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip +import os import random +import uuid from os.path import exists +from tempfile import TemporaryDirectory from typing import Any, Dict +from urllib.parse import urlparse import json import numpy as np @@ -11,6 +15,7 @@ import torch from decord import VideoReader, cpu from PIL import Image +from modelscope.hub.file_download import http_get_file from modelscope.metainfo import Models from modelscope.models import TorchModel from modelscope.models.builder import MODELS @@ -68,12 +73,16 @@ class VideoCLIPForMultiModalEmbedding(TorchModel): self.model.to(self.device) def _get_text(self, caption, tokenizer, enable_zh=False): - if len(caption) == 3: - _caption_text, s, e = caption - elif len(caption) == 4: - _caption_text, s, e, pos = caption - else: - NotImplementedError + + if type(caption) is str: + _caption_text, s, e = caption, None, None + elif type(caption) is tuple: + if len(caption) == 3: + _caption_text, s, e = caption + elif len(caption) == 4: + _caption_text, s, e, pos = caption + else: + NotImplementedError if isinstance(_caption_text, list): caption_text = random.choice(_caption_text) @@ -137,11 +146,25 @@ class VideoCLIPForMultiModalEmbedding(TorchModel): elif start_time == end_time: end_time = end_time + 1 - if exists(video_path): + url_parsed = urlparse(video_path) + if url_parsed.scheme in ('file', '') and exists( + url_parsed.path): # Possibly a local file vreader = VideoReader(video_path, ctx=cpu(0)) else: - logger.error('non video input, output is wrong!!!') - return video, video_mask + try: + with TemporaryDirectory() as temporary_cache_dir: + random_str = uuid.uuid4().hex + http_get_file( + url=video_path, + local_dir=temporary_cache_dir, + file_name=random_str, + cookies=None) + temp_file_path = os.path.join(temporary_cache_dir, + random_str) + vreader = VideoReader(temp_file_path, ctx=cpu(0)) + except Exception as ex: + logger.error('non video input, 
output is {}!!!'.format(ex)) + return video, video_mask fps = vreader.get_avg_fps() f_start = 0 if start_time is None else int(start_time * fps) @@ -213,8 +236,10 @@ class VideoCLIPForMultiModalEmbedding(TorchModel): logger.info('text feature: {}'.format(sequence_output[0][0][0])) logger.info('video feature: {}'.format(visual_output[0][0][0])) - output[OutputKeys.VIDEO_EMBEDDING] = visual_output - output[OutputKeys.TEXT_EMBEDDING] = sequence_output + output[ + OutputKeys.VIDEO_EMBEDDING] = visual_output.cpu().detach().numpy() + output[OutputKeys.TEXT_EMBEDDING] = sequence_output.cpu().detach( + ).numpy() return output def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py b/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py index 253a847c..c2d96275 100644 --- a/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py +++ b/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py @@ -1,4 +1,4 @@ -# The implementation is adopated from the CLIP4Clip implementation, +# The implementation is adopted from the CLIP4Clip implementation, # made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip import numpy as np diff --git a/modelscope/models/multi_modal/mmr/models/tokenization_clip.py b/modelscope/models/multi_modal/mmr/models/tokenization_clip.py index 4e2c9b15..97ee7156 100644 --- a/modelscope/models/multi_modal/mmr/models/tokenization_clip.py +++ b/modelscope/models/multi_modal/mmr/models/tokenization_clip.py @@ -1,4 +1,4 @@ -# The implementation is adopated from the CLIP4Clip implementation, +# The implementation is adopted from the CLIP4Clip implementation, # made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip import gzip diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/__init__.py b/modelscope/models/multi_modal/multi_stage_diffusion/__init__.py index accbb56e..1b3f445b 100644 --- a/modelscope/models/multi_modal/multi_stage_diffusion/__init__.py +++ b/modelscope/models/multi_modal/multi_stage_diffusion/__init__.py @@ -1 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from .model import MultiStageDiffusionForTextToImageSynthesis diff --git a/modelscope/models/multi_modal/ofa/__init__.py b/modelscope/models/multi_modal/ofa/__init__.py index 16de7fff..3e8e59f4 100644 --- a/modelscope/models/multi_modal/ofa/__init__.py +++ b/modelscope/models/multi_modal/ofa/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
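The video-loading change above (clip_for_mm_video_embedding.py) accepts either a local path or a URL: local or file-scheme paths are read directly, anything else is downloaded into a temporary directory with http_get_file before being handed to decord. A condensed sketch of that branch is below; open_video is a hypothetical wrapper written for illustration, but the imports and calls mirror those in the patch.

import os
import uuid
from os.path import exists
from tempfile import TemporaryDirectory
from urllib.parse import urlparse

from decord import VideoReader, cpu

from modelscope.hub.file_download import http_get_file


def open_video(video_path):
    parsed = urlparse(video_path)
    if parsed.scheme in ('file', '') and exists(parsed.path):
        # Local file: read it in place.
        return VideoReader(video_path, ctx=cpu(0))
    # Remote input: fetch into a temporary directory first.
    with TemporaryDirectory() as cache_dir:
        file_name = uuid.uuid4().hex
        http_get_file(url=video_path, local_dir=cache_dir,
                      file_name=file_name, cookies=None)
        return VideoReader(os.path.join(cache_dir, file_name), ctx=cpu(0))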
+ from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel from .tokenization_ofa import OFATokenizer, OFATokenizerZH from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast diff --git a/modelscope/models/multi_modal/ofa/adaptor/__init__.py b/modelscope/models/multi_modal/ofa/adaptor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/ofa/generate/search.py b/modelscope/models/multi_modal/ofa/generate/search.py index 63ecb0a9..0dcaf6b3 100644 --- a/modelscope/models/multi_modal/ofa/generate/search.py +++ b/modelscope/models/multi_modal/ofa/generate/search.py @@ -148,7 +148,7 @@ class BeamSearch(Search): scores_buf = top_prediction[0] indices_buf = top_prediction[1] # Project back into relative indices and beams - beams_buf = indices_buf // vocab_size + beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor') indices_buf = indices_buf.fmod(vocab_size) # At this point, beams_buf and indices_buf are single-dim and contain relative indices diff --git a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py index 590fb67b..e42d3c8e 100644 --- a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py +++ b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py @@ -385,12 +385,7 @@ class SequenceGenerator(nn.Module): attn = torch.empty(bsz * beam_size, avg_attn_scores.size(1), max_len + 2).to(scores) - # print("+++++++ debug attention shape +++++++") - # print("attn", attn.shape) - # print("avg_attn_scores", avg_attn_scores.shape) attn[:, :, step + 1].copy_(avg_attn_scores) - # print("attn[:, :, step + 1]", attn[:, :, step + 1].shape) - # print("attn", attn.shape) scores = scores.type_as(lprobs) eos_bbsz_idx = torch.empty(0).to( @@ -404,8 +399,28 @@ class SequenceGenerator(nn.Module): self.search.set_src_lengths(src_lengths) if self.repeat_ngram_blocker is not None: - lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, - beam_size, step) + # process prefix_tokens + p_toks_len = prefix_tokens.ne(self.pad).sum( + dim=1) if prefix_tokens is not None else None + if p_toks_len is not None: + p_toks_len_beam = p_toks_len.unsqueeze(-1).repeat( + 1, beam_size).view(-1) + no_repeat_ngram_size = self.repeat_ngram_blocker.no_repeat_ngram_size + out_prefix = p_toks_len_beam < ( + step + no_repeat_ngram_size - 1) + else: + out_prefix = torch.ones(bsz * beam_size).bool() + ngram_blocker_tokens = tokens[out_prefix] + ngram_blocker_lprobs = lprobs[out_prefix] + ngram_blocker_bsz = torch.div( + out_prefix.sum(), beam_size, rounding_mode='trunc') + + lprobs[out_prefix] = self.repeat_ngram_blocker( + tokens=ngram_blocker_tokens, + lprobs=ngram_blocker_lprobs, + bsz=ngram_blocker_bsz, + beam_size=beam_size, + step=step) # Shape: (batch, cand_size) cand_scores, cand_indices, cand_beams = self.search.step( @@ -415,7 +430,6 @@ class SequenceGenerator(nn.Module): tokens[:, :step + 1], original_batch_idxs, ) - # cand_bbsz_idx contains beam indices for the top candidate # hypotheses, with a range of values: [0, bsz*beam_size), # and dimensions: [bsz, cand_size] @@ -671,7 +685,7 @@ class SequenceGenerator(nn.Module): cum_unfin.append(prev) cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) - unfin_idx = bbsz_idx // beam_size + unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode='floor') sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) # Create a set of "{sent}{unfin_idx}", where diff 
--git a/modelscope/models/multi_modal/ofa/modeling_ofa.py b/modelscope/models/multi_modal/ofa/modeling_ofa.py index 01cc02f9..0a7a2ce6 100755 --- a/modelscope/models/multi_modal/ofa/modeling_ofa.py +++ b/modelscope/models/multi_modal/ofa/modeling_ofa.py @@ -19,6 +19,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch +from packaging import version from torch import Tensor, nn from torch.nn import functional as F from transformers.activations import ACT2FN @@ -40,6 +41,8 @@ logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = 'ofa-base' _CONFIG_FOR_DOC = 'OFAConfig' _TOKENIZER_FOR_DOC = 'OFATokenizer' +TORCH_VERSION = version.parse(torch.__version__) +TORCH_MESH_GRID_WARNING_VERSION = version.parse('1.9.1') DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -51,6 +54,7 @@ OFA_PRETRAINED_MODEL_ARCHIVE_LIST = [ 'ofa-medium', 'ofa-base', 'ofa-large', + 'ofa-huge', ] try: @@ -114,7 +118,11 @@ def make_image_bucket_position(bucket_size, num_relative_distance): """ coords_h = torch.arange(bucket_size) coords_w = torch.arange(bucket_size) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + if TORCH_VERSION > TORCH_MESH_GRID_WARNING_VERSION: + coords = torch.stack( + torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww + else: + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - \ coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww diff --git a/modelscope/models/multi_modal/ofa/resnet.py b/modelscope/models/multi_modal/ofa/resnet.py index de6444ab..aad0f002 100644 --- a/modelscope/models/multi_modal/ofa/resnet.py +++ b/modelscope/models/multi_modal/ofa/resnet.py @@ -1,3 +1,17 @@ +# Copyright 2022 OFA-Sys Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import torch.nn as nn diff --git a/modelscope/models/multi_modal/ofa/utils/__init__.py b/modelscope/models/multi_modal/ofa/utils/__init__.py index e69de29b..b937315b 100644 --- a/modelscope/models/multi_modal/ofa/utils/__init__.py +++ b/modelscope/models/multi_modal/ofa/utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
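Two of the OFA edits above are PyTorch-version compatibility fixes: torch.meshgrid is called with indexing='ij' when torch is newer than 1.9.1, and integer floor division on tensors is replaced by torch.div(..., rounding_mode='floor') in search.py and sequence_generator.py. A small standalone reproduction of both patterns, using sample tensors only:

import torch
from packaging import version

# meshgrid: newer torch warns unless `indexing` is given; 'ij' keeps the old layout.
if version.parse(torch.__version__) > version.parse('1.9.1'):
    coords = torch.stack(
        torch.meshgrid([torch.arange(4), torch.arange(4)], indexing='ij'))
else:
    coords = torch.stack(torch.meshgrid([torch.arange(4), torch.arange(4)]))
print(coords.shape)  # torch.Size([2, 4, 4])

# floor division on integer tensors: make the rounding explicit.
indices = torch.tensor([0, 7, 15])
beams = torch.div(indices, 8, rounding_mode='floor')
print(beams)  # tensor([0, 0, 1])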
diff --git a/modelscope/models/multi_modal/ofa/utils/constant.py b/modelscope/models/multi_modal/ofa/utils/constant.py index 984da443..b3776f8f 100644 --- a/modelscope/models/multi_modal/ofa/utils/constant.py +++ b/modelscope/models/multi_modal/ofa/utils/constant.py @@ -3,11 +3,12 @@ from modelscope.outputs import OutputKeys from modelscope.utils.constant import Tasks OFA_TASK_KEY_MAPPING = { + Tasks.ocr_recognition: OutputKeys.TEXT, Tasks.image_captioning: OutputKeys.CAPTION, - Tasks.summarization: OutputKeys.TEXT, + Tasks.text_summarization: OutputKeys.TEXT, Tasks.visual_question_answering: OutputKeys.TEXT, Tasks.visual_grounding: OutputKeys.BOXES, - Tasks.text_classification: (OutputKeys.SCORES, OutputKeys.LABELS), + Tasks.text_classification: OutputKeys.LABELS, Tasks.image_classification: OutputKeys.LABELS, - Tasks.visual_entailment: (OutputKeys.SCORES, OutputKeys.LABELS), + Tasks.visual_entailment: OutputKeys.LABELS, } diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py index 45bafde9..56d19ad8 100644 --- a/modelscope/models/multi_modal/ofa_for_all_tasks.py +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -1,8 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import math +import os import string +from functools import partial from os import path as osp -from typing import Any, Dict +from typing import Any, Callable, Dict, List, Optional, Union import json import torch.cuda @@ -10,7 +12,6 @@ import torch.nn.functional as F from modelscope.metainfo import Models from modelscope.models import TorchModel -from modelscope.models.base import Tensor from modelscope.models.builder import MODELS from modelscope.outputs import OutputKeys from modelscope.preprocessors.ofa.utils.collate import collate_tokens @@ -27,12 +28,13 @@ __all__ = ['OfaForAllTasks'] @MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) +@MODELS.register_module(Tasks.ocr_recognition, module_name=Models.ofa) @MODELS.register_module(Tasks.visual_grounding, module_name=Models.ofa) @MODELS.register_module( Tasks.visual_question_answering, module_name=Models.ofa) @MODELS.register_module(Tasks.visual_entailment, module_name=Models.ofa) @MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) -@MODELS.register_module(Tasks.summarization, module_name=Models.ofa) +@MODELS.register_module(Tasks.text_summarization, module_name=Models.ofa) @MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) class OfaForAllTasks(TorchModel): @@ -65,10 +67,9 @@ class OfaForAllTasks(TorchModel): self.gen_type = self.cfg.model.get('gen_type', 'generation') assert self.gen_type in ['generation', 'traverse'], \ 'model.gen_type must be in ["generation", "traverse"]' - self._device = torch.device('cuda') if torch.cuda.is_available() \ - else torch.device('cpu') - self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id - ]).to(self._device) + self.bos_item = torch.LongTensor([self.tokenizer.bos_token_id]) + self.pad_item = torch.LongTensor([self.tokenizer.pad_token_id]) + self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id]) self.index2ans = {} self.ans2label_dict = {} self.load_ans2label() @@ -89,15 +90,17 @@ class OfaForAllTasks(TorchModel): self.val_masks_l = [] self.build_trie() sg_args['constraint_trie'] = self.constraint_trie - self.model.to(self._device) + else: + self.constraint_trie = None self.generator = sg.SequenceGenerator(**sg_args) inference_d = { 'generation': self._text_gen_inference, 
'traverse': self._traverse_inference, } self.task_inference_mapping = { + Tasks.ocr_recognition: self._text_gen_inference, Tasks.image_captioning: self._text_gen_inference, - Tasks.summarization: self._text_gen_inference, + Tasks.text_summarization: self._text_gen_inference, Tasks.visual_grounding: self._visual_grounding_inference, Tasks.visual_entailment: inference_d[self.gen_type], Tasks.visual_question_answering: inference_d[self.gen_type], @@ -106,8 +109,16 @@ class OfaForAllTasks(TorchModel): } def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + input = move_to_device(input, self.model.device) + if self.model.training: + return self.model(**input['net_input']) + else: + return self.inference(input) + + def inference(self, input: Dict[str, Any]) -> Dict[str, Any]: ret = self.task_inference_mapping[self.cfg.task](input) - ret['samples'] = input['samples'] + if 'samples' in input: + ret['samples'] = input['samples'] for key in [ OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, OutputKeys.LABELS, OutputKeys.SCORES @@ -116,21 +127,33 @@ class OfaForAllTasks(TorchModel): ret[key] = None return ret - def postprocess(self, input: Dict[str, Tensor], - **kwargs) -> Dict[str, Tensor]: - if self.cfg.task == Tasks.image_captioning: - caption = [ - cap.translate(self.transtab).strip() - for cap in input[OutputKeys.CAPTION] - ] - input[OutputKeys.CAPTION] = caption + def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + if not self.model.training and self.cfg.task == Tasks.image_captioning: + caption = input[OutputKeys.CAPTION] + result_l = list() + for cap in caption: + result_l.append(cap.translate(self.transtab).strip()) + input[OutputKeys.CAPTION] = result_l return input def _text_gen_inference(self, input): - input = move_to_device(input, self._device) - gen_output = self.generator.generate([self.model], input) - gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] - result = self.tokenizer.batch_decode(gen, skip_special_tokens=True) + gen_outputs = self.generator.generate([self.model], + input, + prefix_tokens=input.get( + 'prefix_tokens', None)) + gen_l = list() + for idx, gen_out in enumerate(gen_outputs): + if len(gen_out) > 0: + decode_tokens = gen_out[0]['tokens'] + if 'prefix_tokens' in input: + prefix_len = input['prefix_tokens'][idx].ne( + self.pad_item.to(self.model.device)).sum() + decode_tokens = decode_tokens[prefix_len:] + gen_l.append(decode_tokens) + else: + gen_l.append('') + result = self.tokenizer.batch_decode(gen_l, skip_special_tokens=True) + result = [item.strip() for item in result] # text generation tasks have no score ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result} if self.cfg.task.endswith('classification'): @@ -138,7 +161,6 @@ class OfaForAllTasks(TorchModel): return ret def _visual_grounding_inference(self, input): - input = move_to_device(input, self._device) gen_output = self.generator.generate([self.model], input) tokens = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] region_coord_l = list() @@ -158,7 +180,6 @@ class OfaForAllTasks(TorchModel): } def _traverse_inference(self, input): - input = move_to_device(input, self._device) encoder_input = dict() for key in input['net_input'].keys(): encoder_input[key] = input['net_input'][key] @@ -168,13 +189,14 @@ class OfaForAllTasks(TorchModel): valid_size = len(val_ans) valid_tgt_items = [ torch.cat([ - torch.tensor(decoder_prompt[1:]), valid_answer, + torch.tensor(decoder_prompt[1:]).to('cpu'), valid_answer, self.eos_item ]) for decoder_prompt in 
input['decoder_prompts'] for valid_answer in val_ans ] valid_prev_items = [ - torch.cat([torch.tensor(decoder_prompt), valid_answer]) + torch.cat( + [torch.tensor(decoder_prompt).to('cpu'), valid_answer]) for decoder_prompt in input['decoder_prompts'] for valid_answer in val_ans ] @@ -182,19 +204,19 @@ class OfaForAllTasks(TorchModel): torch.cat([ torch.zeros( len(decoder_prompt) - 1, - valid_constraint_mask.size(1)).bool().to(self._device), + valid_constraint_mask.size(1)).bool(), valid_constraint_mask], dim=0) # yapf: disable for decoder_prompt in input['decoder_prompts'] # yapf: disable for valid_constraint_mask in val_masks] # yapf: disable valid_tgt = collate_tokens( valid_tgt_items, - pad_idx=self.tokenizer.pad_token_id).to(self._device) + pad_idx=self.tokenizer.pad_token_id).to(self.model.device) valid_prev_output = collate_tokens( valid_prev_items, - pad_idx=self.tokenizer.pad_token_id).to(self._device) + pad_idx=self.tokenizer.pad_token_id).to(self.model.device) val_masks = collate_tokens( valid_constraint_mask_items, - pad_idx=self.tokenizer.pad_token_id).to(self._device) + pad_idx=self.tokenizer.pad_token_id).to(self.model.device) new_encoder_out = { 'last_hidden_state': encoder_out['last_hidden_state'].repeat_interleave( @@ -269,10 +291,23 @@ class OfaForAllTasks(TorchModel): self.val_masks_l += [ constraint_mask_list[i:i + self.val_batch_size] ] - self.val_ans_l = move_to_device(self.val_ans_l, self._device) - self.val_masks_l = move_to_device(self.val_masks_l, self._device) def load_ans2label(self): if self.cfg.model.get('answer2label', None): - filename = osp.join(self.model_dir, self.cfg.model.answer2label) - self.ans2label_dict = json.load(open(filename)) + ans2label_file = osp.join(self.model_dir, + self.cfg.model.answer2label) + with open(ans2label_file, 'r') as reader: + self.ans2label_dict = json.load(reader) + + def save_pretrained(self, + target_folder: Union[str, os.PathLike], + save_checkpoint_names: Union[str, List[str]] = None, + save_function: Callable = None, + config: Optional[dict] = None, + **kwargs): + super(OfaForAllTasks, self). \ + save_pretrained(target_folder=target_folder, + save_checkpoint_names=save_checkpoint_names, + save_function=partial(save_function, with_meta=False), + config=config, + **kwargs) diff --git a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py index b942e3fa..8110a0f7 100644 --- a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py +++ b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict diff --git a/modelscope/models/multi_modal/team/__init__.py b/modelscope/models/multi_modal/team/__init__.py index 0597040c..58bbdca5 100644 --- a/modelscope/models/multi_modal/team/__init__.py +++ b/modelscope/models/multi_modal/team/__init__.py @@ -1 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from .team_model import TEAMForMultiModalSimilarity diff --git a/modelscope/models/nlp/T5/__init__.py b/modelscope/models/nlp/T5/__init__.py index 7c1cea36..cb0921c6 100644 --- a/modelscope/models/nlp/T5/__init__.py +++ b/modelscope/models/nlp/T5/__init__.py @@ -1,13 +1,17 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
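Back in ofa_for_all_tasks.py above, _text_gen_inference now forwards prefix_tokens to the generator and strips the (padded) prefix from each decoded sequence before batch_decode. The trimming step in isolation, with made-up token ids and pad id:

import torch

pad_id = 1
prefix_tokens = torch.tensor([5, 6, pad_id, pad_id])  # padded prefix for one sample
decode_tokens = torch.tensor([5, 6, 17, 23, 2])       # generator output repeats the prefix

prefix_len = prefix_tokens.ne(pad_id).sum()           # tensor(2)
decode_tokens = decode_tokens[prefix_len:]            # tensor([17, 23,  2])
print(decode_tokens)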
+ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .t5_for_text_generation import T5ForConditionalGeneration + from .backbone import T5Model + from .text2text_generation import T5ForConditionalGeneration else: _import_structure = { - 't5_for_text_generation': ['T5ForConditionalGeneration'], + 'backbone': ['T5Model'], + 'text2text_generation': ['T5ForConditionalGeneration'], } import sys diff --git a/modelscope/models/nlp/T5/modeling_t5.py b/modelscope/models/nlp/T5/backbone.py similarity index 73% rename from modelscope/models/nlp/T5/modeling_t5.py rename to modelscope/models/nlp/T5/backbone.py index da50741e..9a46d980 100644 --- a/modelscope/models/nlp/T5/modeling_t5.py +++ b/modelscope/models/nlp/T5/backbone.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,12 +22,8 @@ from typing import Optional, Tuple, Union import torch from torch import nn -from torch.nn import CrossEntropyLoss from torch.utils.checkpoint import checkpoint from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqLMOutput, Seq2SeqModelOutput) from transformers.modeling_utils import (PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer) @@ -36,30 +33,20 @@ from transformers.utils import (DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, from transformers.utils.model_parallel_utils import (assert_device_map, get_device_map) +from modelscope.metainfo import Models +from modelscope.models.base import Model, Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import (BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqModelOutput) +from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger -from .configuration_t5 import T5Config +from .configuration import T5Config logger = get_logger(__name__) -_CONFIG_FOR_DOC = 'T5Config' -_TOKENIZER_FOR_DOC = 'T5Tokenizer' -_CHECKPOINT_FOR_DOC = 't5-small' -#################################################### -# This dict contains ids and associated url -# for the pretrained weights provided with the models -#################################################### -T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ - 't5-small', - 't5-base', - 't5-large', - 't5-3b', - 't5-11b', - # See all T5 models at https://huggingface.co/models?filter=t5 -] - - -#################################################### +################################################### # This is a conversion method from TF 1.0 to PyTorch # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 #################################################### @@ -173,65 +160,6 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): return model -#################################################### -# PyTorch Models are constructed by sub-classing -# - torch.nn.Module for the layers and -# - PreTrainedModel for the models (it-self a sub-class of nn.Module) -#################################################### -PARALLELIZE_DOCSTRING = r""" - This is an experimental feature and is a subject to change at a moment's notice. - - Uses a device map to distribute attention modules of the model across several devices. 
If no device map is given, - it will evenly distribute blocks across all devices. - - Args: - device_map (`Dict[int, list]`, optional, defaults to None): - A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always - automatically mapped to the first device (for esoteric reasons). That means that the first device should - have fewer attention modules mapped to it than other devices. For reference, the t5 models have the - following number of attention modules: - - - t5-small: 6 - - t5-base: 12 - - t5-large: 24 - - t5-3b: 24 - - t5-11b: 24 - - Example: - - ```python - # Here is an example of a device map on a machine with 4 GPUs - # using t5-3b, which has a total of 24 attention modules: - model = T5ForConditionalGeneration.from_pretrained("t5-3b") - device_map = { - 0: [0, 1, 2], - 1: [3, 4, 5, 6, 7, 8, 9], - 2: [10, 11, 12, 13, 14, 15, 16], - 3: [17, 18, 19, 20, 21, 22, 23], - } - model.parallelize(device_map) - ``` -""" -DEPARALLELIZE_DOCSTRING = r""" - Moves the model to cpu from a model parallel state. - - Example: - - ```python - # On a 4 GPU machine with t5-3b: - model = T5ForConditionalGeneration.from_pretrained("t5-3b") - device_map = { - 0: [0, 1, 2], - 1: [3, 4, 5, 6, 7, 8, 9], - 2: [10, 11, 12, 13, 14, 15, 16], - 3: [17, 18, 19, 20, 21, 22, 23], - } - model.parallelize(device_map) # Splits the model across several devices - model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() - ``` -""" - - class T5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -261,23 +189,6 @@ class T5LayerNorm(nn.Module): return self.weight * hidden_states -try: - from apex.normalization import FusedRMSNorm - - T5LayerNorm = FusedRMSNorm # noqa - - logger.info( - 'Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm' - ) -except ImportError: - # using the normal T5LayerNorm - pass -except Exception: - logger.warning( - 'discovered apex but it failed to load, falling back to T5LayerNorm') - pass - - class T5DenseReluDense(nn.Module): def __init__(self, config: T5Config): @@ -791,7 +702,7 @@ class T5Block(nn.Module): return outputs -class T5PreTrainedModel(PreTrainedModel): +class T5PreTrainedModel(TorchModel, PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. @@ -803,6 +714,10 @@ class T5PreTrainedModel(PreTrainedModel): is_parallelizable = True supports_gradient_checkpointing = True + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + @property def dummy_inputs(self): input_ids = torch.tensor(DUMMY_INPUTS) @@ -819,8 +734,7 @@ class T5PreTrainedModel(PreTrainedModel): factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, T5LayerNorm): module.weight.data.fill_(factor * 1.0) - elif isinstance(module, - (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + elif isinstance(module, T5Model): # Mesh TensorFlow embeddings initialization See # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) @@ -902,6 +816,36 @@ class T5PreTrainedModel(PreTrainedModel): return shifted_input_ids + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. 
+ model_dir: The model dir used to load the checkpoint and the + label information. num_labels: An optional arg to tell the + model how many classes to initialize. + Method will call utils.parse_label_mapping + if num_labels not supplied. If num_labels is + not found, the model will use the default + setting (2 classes). + + Returns: + The loaded model, which is initialized by + transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.get('model_dir', None) + if model_dir is None: + config = T5Config(**kwargs) + model = cls(config) + else: + model_kwargs = {} + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + model.model_dir = model_dir + return model + class T5Stack(T5PreTrainedModel): @@ -926,8 +870,42 @@ class T5Stack(T5PreTrainedModel): self.device_map = None self.gradient_checkpointing = False - @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): + r""" + This is an experimental feature and is a subject to change at a + moment's notice. + + Uses a device map to distribute attention modules of the model + across several devices. If no device map is given, it will evenly + distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note + that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric + reasons). That means that the first device should have fewer + attention modules mapped to it than other devices. For + reference, the t5 models have the following number of + attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example: + + ```python # Here is an example of a device map on a machine with 4 + GPUs # using t5-3b, which has a total of 24 attention modules: model + = T5ForConditionalGeneration.from_pretrained("t5-3b") device_map = { + 0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9], 2: [10, 11, 12, 13, 14, + 15, 16], 3: [17, 18, 19, 20, 21, 22, 23], + } model.parallelize(device_map) ``` all of the parallelize methods + in this file are the same + + """ # Check validity of device_map self.device_map = ( get_device_map(len(self.block), range(torch.cuda.device_count())) @@ -948,8 +926,22 @@ class T5Stack(T5PreTrainedModel): # Set final layer norm to last device self.final_layer_norm = self.final_layer_norm.to(self.last_device) - @add_start_docstrings(PARALLELIZE_DOCSTRING) def deparallelize(self): + r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python # On a 4 GPU machine with t5-3b: model = + T5ForConditionalGeneration.from_pretrained("t5-3b") device_map = { + 0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9], 2: [10, 11, 12, 13, 14, + 15, 16], 3: [17, 18, 19, 20, 21, 22, 23], + } model.parallelize(device_map) # Splits the model across several + devices model.deparallelize() # Put the model back on cpu and + cleans memory by calling torch.cuda.empty_cache() ``` + + all of the deparallelize methods in this file are the same + """ self.model_parallel = False self.device_map = None self.first_device = 'cpu' @@ -1199,7 +1191,20 @@ class T5Stack(T5PreTrainedModel): ) -T5_START_DOCSTRING = r""" +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and +`decoder_head_mask`. 
Currently, `decoder_head_mask` is set to copy `head_mask`, +but this feature is deprecated and will be removed in future versions. If you do +not want to use any `decoder_head_mask` now, please set `decoder_head_mask = +torch.ones(num_layers, num_heads)`. +""" + + +@MODELS.register_module(group_key=Tasks.backbone, module_name=Models.T5) +class T5Model(T5PreTrainedModel): + """The bare T5 Model transformer outputting raw hidden-states without any + specific head on top. The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by @@ -1224,10 +1229,99 @@ T5_START_DOCSTRING = r""" with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" + """ + _keys_to_ignore_on_load_missing = [ + r'encoder\.embed_tokens\.weight', + r'decoder\.embed_tokens\.weight', + ] + _keys_to_ignore_on_load_unexpected = [ + r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight', + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) -T5_INPUTS_DOCSTRING = r""" - Args: + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map( + len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None else device_map) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to('cpu') + self.decoder = self.decoder.to('cpu') + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of + heads to prune in this layer} See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you should be able to pad the @@ -1343,244 +1437,84 @@ T5_INPUTS_DOCSTRING = r""" return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" + Returns: -T5_ENCODER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. T5 is a model - with relative position embeddings so you should be able to pad the - inputs on both the right and the left. + Example: - Indices can be obtained using [`T5Tokenizer`]. See - [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] - for detail. + ```python >>> from transformers import T5Tokenizer, T5Model - To know more on how to prepare `input_ids` for pretraining take a - look a [T5 Training](./t5#training). - attention_mask (`torch.FloatTensor` of shape `(batch_size, - sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask - values selected in `[0, 1]`: + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5Model.from_pretrained("t5-small") - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + >>> ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, - num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask - values selected in `[0, 1]`: + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. 
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, - sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to - directly pass an embedded representation. This is useful if you want - more control over how to convert `input_ids` indices into associated - vectors than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention - layers. See `attentions` under returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See - `hidden_states` under returned tensors for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain - tuple. -""" + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] + if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 else None, + ) -# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask -__HEAD_MASK_WARNING_MSG = """ -The input argument `head_mask` was split into two arguments `head_mask` and -`decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`, -but this feature is deprecated and will be removed in future versions. If you do -not want to use any `decoder_head_mask` now, please set `decoder_head_mask = -torch.ones(num_layers, num_heads)`. 
-""" + hidden_states = encoder_outputs[0] + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to( + self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device) - -@add_start_docstrings( - 'The bare T5 Model transformer outputting raw hidden-states without any specific head on top.', - T5_START_DOCSTRING, -) -class T5Model(T5PreTrainedModel): - _keys_to_ignore_on_load_missing = [ - r'encoder\.embed_tokens\.weight', - r'decoder\.embed_tokens\.weight', - ] - _keys_to_ignore_on_load_unexpected = [ - r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight', - ] - - def __init__(self, config: T5Config): - super().__init__(config) - self.shared = nn.Embedding(config.vocab_size, config.d_model) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = T5Stack(encoder_config, self.shared) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack(decoder_config, self.shared) - - # Initialize weights and apply final processing - self.post_init() - - # Model parallel - self.model_parallel = False - self.device_map = None - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map( - len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None else device_map) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.decoder.parallelize(self.device_map) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - self.encoder.deparallelize() - self.decoder.deparallelize() - self.encoder = self.encoder.to('cpu') - self.decoder = self.decoder.to('cpu') - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - self.decoder.set_input_embeddings(new_embeddings) - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of - heads to prune in this layer} See base class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - decoder_inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: - r""" - Returns: - - Example: - - ```python >>> from transformers import T5Tokenizer, T5Model - - >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") - >>> model = T5Model.from_pretrained("t5-small") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" - >>> ).input_ids # Batch size 1 - >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 - - >>> # forward pass - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] - if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] - if len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - hidden_states = hidden_states.to(self.decoder.first_device) - if decoder_input_ids is not None: - decoder_input_ids = decoder_input_ids.to( - self.decoder.first_device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.decoder.first_device) - if decoder_attention_mask is not None: - decoder_attention_mask = 
decoder_attention_mask.to( - self.decoder.first_device) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=past_key_values, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) if not return_dict: return decoder_outputs + encoder_outputs @@ -1595,409 +1529,3 @@ class T5Model(T5PreTrainedModel): encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) - - -@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", - T5_START_DOCSTRING) -class T5ForConditionalGeneration(T5PreTrainedModel): - _keys_to_ignore_on_load_missing = [ - r'encoder\.embed_tokens\.weight', - r'decoder\.embed_tokens\.weight', - r'lm_head\.weight', - ] - _keys_to_ignore_on_load_unexpected = [ - r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight', - ] - - def __init__(self, config: T5Config): - super().__init__(config) - self.model_dim = config.d_model - - self.shared = nn.Embedding(config.vocab_size, config.d_model) - - encoder_config = copy.deepcopy(config) - encoder_config.is_decoder = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = T5Stack(encoder_config, self.shared) - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack(decoder_config, self.shared) - - self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - # Model parallel - self.model_parallel = False - self.device_map = None - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map( - len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None else device_map) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.decoder.parallelize(self.device_map) - self.lm_head = self.lm_head.to(self.decoder.first_device) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - self.encoder.deparallelize() - self.decoder.deparallelize() - self.encoder = self.encoder.to('cpu') - self.decoder = self.decoder.to('cpu') - self.lm_head = self.lm_head.to('cpu') - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - 
self.decoder.set_input_embeddings(new_embeddings) - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def get_output_embeddings(self): - return self.lm_head - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. - Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All - labels set to `-100` are ignored (masked), the loss is only computed - for labels in `[0, ..., config.vocab_size]` - - Returns: - - Examples: - - ```python >>> from transformers import T5Tokenizer, - T5ForConditionalGeneration - - >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") - >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") - - >>> # training - >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids - >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids - >>> outputs = model(input_ids=input_ids, labels=labels) - >>> loss = outputs.loss - >>> logits = outputs.logits - - >>> # inference - >>> input_ids = tokenizer( - ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" - >>> ).input_ids # Batch size 1 - >>> outputs = model.generate(input_ids) - >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) - >>> # studies have shown that owning a dog is good for you. 
- ```""" - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - # Convert encoder inputs in embeddings if needed - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] - if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] - if len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - - if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: - # get decoder inputs from shifting lm labels to the right - decoder_input_ids = self._shift_right(labels) - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - hidden_states = hidden_states.to(self.decoder.first_device) - if decoder_input_ids is not None: - decoder_input_ids = decoder_input_ids.to( - self.decoder.first_device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.decoder.first_device) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to( - self.decoder.first_device) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=past_key_values, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = decoder_outputs[0] - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.encoder.first_device) - self.lm_head = self.lm_head.to(self.encoder.first_device) - sequence_output = sequence_output.to(self.lm_head.weight.device) - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab See - # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) - - lm_logits = self.lm_head(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct( - lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss - # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 - - if not return_dict: - output = (lm_logits, ) + decoder_outputs[1:] + encoder_outputs - return ((loss, ) + output) if loss is not None 
else output - - return Seq2SeqLMOutput( - loss=loss, - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def prepare_inputs_for_generation(self, - input_ids, - past=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs): - - # cut decoder_input_ids if past is used - if past is not None: - input_ids = input_ids[:, -1:] - - return { - 'decoder_input_ids': input_ids, - 'past_key_values': past, - 'encoder_outputs': encoder_outputs, - 'attention_mask': attention_mask, - 'head_mask': head_mask, - 'decoder_head_mask': decoder_head_mask, - 'cross_attn_head_mask': cross_attn_head_mask, - 'use_cache': use_cache, - } - - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return self._shift_right(labels) - - def _reorder_cache(self, past, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past is None: - logger.warning( - 'You might want to consider setting `use_cache=True` to speed up decoding' - ) - return past - - reordered_decoder_past = () - for layer_past_states in past: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select( - 0, beam_idx.to(layer_past_state.device)), ) - - assert reordered_layer_past_states[0].shape == layer_past_states[ - 0].shape - assert len(reordered_layer_past_states) == len(layer_past_states) - - reordered_decoder_past = reordered_decoder_past + ( - reordered_layer_past_states, ) - return reordered_decoder_past - - -@add_start_docstrings( - "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", - T5_START_DOCSTRING, -) -class T5EncoderModel(T5PreTrainedModel): - authorized_missing_keys = [ - r'encoder\.embed_tokens\.weight', - ] - - def __init__(self, config: T5Config): - super().__init__(config) - self.shared = nn.Embedding(config.vocab_size, config.d_model) - - encoder_config = copy.deepcopy(config) - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = T5Stack(encoder_config, self.shared) - - # Initialize weights and apply final processing - self.post_init() - - # Model parallel - self.model_parallel = False - self.device_map = None - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map( - len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None else device_map) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - self.encoder.deparallelize() - self.encoder = self.encoder.to('cpu') - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - - def get_input_embeddings(self): - 
return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - - def get_encoder(self): - return self.encoder - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of - heads to prune in this layer} See base class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: - r""" - Returns: - - Example: - - ```python - >>> from transformers import T5Tokenizer, T5EncoderModel - - >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") - >>> model = T5EncoderModel.from_pretrained("t5-small") - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" - >>> ).input_ids # Batch size 1 - >>> outputs = model(input_ids=input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - return encoder_outputs diff --git a/modelscope/models/nlp/T5/configuration_t5.py b/modelscope/models/nlp/T5/configuration.py similarity index 99% rename from modelscope/models/nlp/T5/configuration_t5.py rename to modelscope/models/nlp/T5/configuration.py index 117a6bc1..1f9a965e 100644 --- a/modelscope/models/nlp/T5/configuration_t5.py +++ b/modelscope/models/nlp/T5/configuration.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright 2020, The T5 Authors and HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/modelscope/models/nlp/T5/t5_for_text_generation.py b/modelscope/models/nlp/T5/t5_for_text_generation.py deleted file mode 100644 index 27f077d8..00000000 --- a/modelscope/models/nlp/T5/t5_for_text_generation.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Optional, Tuple - -import torch - -from modelscope.metainfo import Models -from modelscope.models.base import Tensor, TorchModel -from modelscope.models.builder import MODELS -from modelscope.outputs import OutputKeys -from modelscope.utils.constant import Tasks -from .modeling_t5 import T5Config -from .modeling_t5 import T5ForConditionalGeneration as T5ForGeneration - - -@MODELS.register_module( - group_key=Tasks.text2text_generation, - module_name=Models.T5, -) -class T5ForConditionalGeneration(TorchModel): - - def __init__(self, model_dir=None, *args, **kwargs): - """initialize the text generation model from the `model_dir` path. - - Args: - model_dir (str): the model path. 
- model_cls (Optional[Any], optional): model loader, if None, use the - default loader to load model weights, by default None. - """ - super().__init__(model_dir, *args, **kwargs) - self.model = T5ForGeneration.from_pretrained(model_dir) - self.generate = self.model.generate - self.config = self.model.config - - def forward(self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs): - return self.model.forward( - self, input_ids, attention_mask, decoder_input_ids, - decoder_attention_mask, head_mask, decoder_head_mask, - cross_attn_head_mask, encoder_outputs, past_key_values, - inputs_embeds, decoder_inputs_embeds, labels, use_cache, - output_attentions, output_hidden_states, return_dict, **kwargs) diff --git a/modelscope/models/nlp/T5/text2text_generation.py b/modelscope/models/nlp/T5/text2text_generation.py new file mode 100644 index 00000000..c4dcdfdb --- /dev/null +++ b/modelscope/models/nlp/T5/text2text_generation.py @@ -0,0 +1,455 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.utils.model_parallel_utils import (assert_device_map, + get_device_map) + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import BaseModelOutput, Seq2SeqLMOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from .backbone import T5PreTrainedModel, T5Stack +from .configuration import T5Config + +logger = get_logger(__name__) + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and +`decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`, +but this feature is deprecated and will be removed in future versions. If you do +not want to use any `decoder_head_mask` now, please set `decoder_head_mask = +torch.ones(num_layers, num_heads)`. 
+""" + + +@MODELS.register_module( + group_key=Tasks.text2text_generation, + module_name=Models.T5, +) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r'encoder\.embed_tokens\.weight', + r'decoder\.embed_tokens\.weight', + r'lm_head\.weight', + ] + _keys_to_ignore_on_load_unexpected = [ + r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight', + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map( + len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None else device_map) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to('cpu') + self.decoder = self.decoder.to('cpu') + self.lm_head = self.lm_head.to('cpu') + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def forward(self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
T5 is a model + with relative position embeddings so you should be able to pad the + inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a + look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, + sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask + values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, + target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`T5Tokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for + `decoder_input_ids` generation. If `past_key_values` is used, + optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining + take a look at [T5 Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, + target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in + `decoder_input_ids`. Causal mask will also be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, + num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the + encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or + `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the + decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or + `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in + the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, + `optional`: *attentions*) `last_hidden_state` of shape `(batch_size, + sequence_length, hidden_size)` is a sequence of hidden states at the + output of the last layer of the encoder. Used in the cross-attention + of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length + `config.n_layers` with each tuple having 4 tensors of shape + `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention + blocks. Can be used to speed up decoding. 
+ + If `past_key_values` are used, the user can optionally input only + the last `decoder_input_ids` (those that don't have their past key + value states given to this model) of shape `(batch_size, 1)` instead + of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. This is useful if you want + more control over how to convert `input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, + target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to + directly pass an embedded representation. If `past_key_values` is + used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more + control over how to convert `decoder_input_ids` indices into + associated vectors than the model's internal embedding lookup + matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, + `decoder_inputs_embeds` takes the value of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned + and can be used to speed up decoding (see `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain + tuple. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. + Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All + labels set to `-100` are ignored (masked), the loss is only computed + for labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python >>> from transformers import T5Tokenizer, + T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") + + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + >>> ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. 
+ ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] + if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to( + self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab See + # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct( + lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss + # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits, ) + decoder_outputs[1:] + encoder_outputs + return ((loss, ) + output) if loss is not None 
else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'decoder_input_ids': input_ids, + 'past_key_values': past, + 'encoder_outputs': encoder_outputs, + 'attention_mask': attention_mask, + 'head_mask': head_mask, + 'decoder_head_mask': decoder_head_mask, + 'cross_attn_head_mask': cross_attn_head_mask, + 'use_cache': use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning( + 'You might want to consider setting `use_cache=True` to speed up decoding' + ) + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device)), ) + + assert reordered_layer_past_states[0].shape == layer_past_states[ + 0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + ( + reordered_layer_past_states, ) + return reordered_decoder_past diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index 8ef96365..d4562f10 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -4,77 +4,109 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .backbones import SbertModel - from .bart_for_text_error_correction import BartForTextErrorCorrection - from .bert_for_document_segmentation import BertForDocumentSegmentation - from .csanmt_for_translation import CsanmtForTranslation - from .heads import SequenceClassificationHead + from .bart import BartForTextErrorCorrection + from .bert import ( + BertForMaskedLM, + BertForTextRanking, + BertForSentenceEmbedding, + BertForSequenceClassification, + BertForTokenClassification, + BertForDocumentSegmentation, + BertModel, + BertConfig, + ) + from .csanmt import CsanmtForTranslation + from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model + from .gpt_neo import GPTNeoModel from .gpt3 import GPT3ForTextGeneration - from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM, - BertForMaskedLM, DebertaV2ForMaskedLM) - from .ponet_for_masked_language import PoNetForMaskedLM - from .nncrf_for_named_entity_recognition import ( - TransformerCRFForNamedEntityRecognition, - LSTMCRFForNamedEntityRecognition) + from 
.heads import SequenceClassificationHead from .palm_v2 import PalmForTextGeneration - from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering - from .star_text_to_sql import StarForTextToSql - from .sequence_classification import (VecoForSequenceClassification, - SbertForSequenceClassification, - BertForSequenceClassification) - from .space import SpaceForDialogIntent - from .space import SpaceForDialogModeling - from .space import SpaceForDialogStateTracking - from .table_question_answering import TableQuestionAnswering - from .task_models import (FeatureExtractionModel, - InformationExtractionModel, - SequenceClassificationModel, - SingleBackboneTaskModelBase, - TokenClassificationModel) - from .token_classification import SbertForTokenClassification - from .sentence_embedding import SentenceEmbedding - from .passage_ranking import PassageRanking + from .ponet import PoNetForMaskedLM, PoNetModel, PoNetConfig + from .space import SpaceForDialogIntent, SpaceForDialogModeling, SpaceForDST + from .space_T_cn import TableQuestionAnswering + from .space_T_en import StarForTextToSql + from .structbert import ( + SbertForFaqQuestionAnswering, + SbertForMaskedLM, + SbertForSequenceClassification, + SbertForTokenClassification, + SbertTokenizer, + SbertModel, + SbertTokenizerFast, + ) from .T5 import T5ForConditionalGeneration + from .task_models import ( + FeatureExtractionModel, + InformationExtractionModel, + LSTMCRFForNamedEntityRecognition, + SequenceClassificationModel, + SingleBackboneTaskModelBase, + TaskModelForTextGeneration, + TokenClassificationModel, + TransformerCRFForNamedEntityRecognition, + ) + from .veco import (VecoConfig, VecoForMaskedLM, + VecoForSequenceClassification, + VecoForTokenClassification, VecoModel, VecoTokenizer, + VecoTokenizerFast) + else: _import_structure = { 'backbones': ['SbertModel'], - 'bart_for_text_error_correction': ['BartForTextErrorCorrection'], - 'bert_for_document_segmentation': ['BertForDocumentSegmentation'], - 'csanmt_for_translation': ['CsanmtForTranslation'], + 'bart': ['BartForTextErrorCorrection'], + 'csanmt': ['CsanmtForTranslation'], 'heads': ['SequenceClassificationHead'], 'gpt3': ['GPT3ForTextGeneration'], - 'masked_language': [ - 'StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM', - 'DebertaV2ForMaskedLM' - ], - 'nncrf_for_named_entity_recognition': [ - 'TransformerCRFForNamedEntityRecognition', - 'LSTMCRFForNamedEntityRecognition' + 'structbert': [ + 'SbertForFaqQuestionAnswering', + 'SbertForMaskedLM', + 'SbertForSequenceClassification', + 'SbertForTokenClassification', + 'SbertTokenizer', + 'SbertTokenizerFast', + 'SbertModel', ], - 'ponet_for_masked_language': ['PoNetForMaskedLM'], - 'palm_v2': ['PalmForTextGeneration'], - 'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], - 'star_text_to_sql': ['StarForTextToSql'], - 'sequence_classification': [ - 'VecoForSequenceClassification', 'SbertForSequenceClassification', - 'BertForSequenceClassification' + 'veco': [ + 'VecoConfig', + 'VecoForMaskedLM', + 'VecoForSequenceClassification', + 'VecoForTokenClassification', + 'VecoModel', + 'VecoTokenizer', + 'VecoTokenizerFast', ], - 'space': [ - 'SpaceForDialogIntent', 'SpaceForDialogModeling', - 'SpaceForDialogStateTracking' + 'bert': [ + 'BertForMaskedLM', + 'BertForTextRanking', + 'BertForSentenceEmbedding', + 'BertForSequenceClassification', + 'BertForTokenClassification', + 'BertForDocumentSegmentation', + 'BertModel', + 'BertConfig', ], + 'ponet': ['PoNetForMaskedLM', 'PoNetModel', 
'PoNetConfig'], + 'palm_v2': ['PalmForTextGeneration'], + 'deberta_v2': ['DebertaV2ForMaskedLM', 'DebertaV2Model'], + 'space_T_en': ['StarForTextToSql'], + 'space_T_cn': ['TableQuestionAnswering'], + 'space': + ['SpaceForDialogIntent', 'SpaceForDialogModeling', 'SpaceForDST'], 'task_models': [ 'FeatureExtractionModel', 'InformationExtractionModel', + 'LSTMCRFForNamedEntityRecognition', + 'LSTMCRFForWordSegmentation', 'SequenceClassificationModel', 'SingleBackboneTaskModelBase', + 'TaskModelForTextGeneration', 'TokenClassificationModel', + 'TransformerCRFForNamedEntityRecognition', + 'TransformerCRFForWordSegmentation', ], - 'token_classification': ['SbertForTokenClassification'], - 'table_question_answering': ['TableQuestionAnswering'], 'sentence_embedding': ['SentenceEmbedding'], - 'passage_ranking': ['PassageRanking'], 'T5': ['T5ForConditionalGeneration'], + 'gpt_neo': ['GPTNeoModel'], } import sys diff --git a/modelscope/models/nlp/backbones/bert.py b/modelscope/models/nlp/backbones/bert.py deleted file mode 100644 index aa513944..00000000 --- a/modelscope/models/nlp/backbones/bert.py +++ /dev/null @@ -1,7 +0,0 @@ -from modelscope.metainfo import Models -from modelscope.models.builder import BACKBONES -from modelscope.models.nlp.bert import BertModel -from modelscope.utils.constant import Fields - -BACKBONES.register_module( - group_key=Fields.nlp, module_name=Models.bert, module_cls=BertModel) diff --git a/modelscope/models/nlp/backbones/structbert.py b/modelscope/models/nlp/backbones/structbert.py deleted file mode 100644 index 74735520..00000000 --- a/modelscope/models/nlp/backbones/structbert.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from modelscope.metainfo import Models -from modelscope.models.base import TorchModel -from modelscope.models.builder import BACKBONES -from modelscope.models.nlp.structbert import SbertConfig -from modelscope.models.nlp.structbert import SbertModel as SbertModelTransform -from modelscope.utils.constant import Fields -from modelscope.utils.logger import get_logger - -logger = get_logger(__name__) - - -@BACKBONES.register_module(Fields.nlp, module_name=Models.structbert) -class SbertModel(TorchModel, SbertModelTransform): - - def __init__(self, model_dir=None, add_pooling_layer=True, **config): - """ - Args: - model_dir (str, optional): The model checkpoint directory. Defaults to None. - add_pooling_layer (bool, optional): to decide if pool the output from hidden layer. Defaults to True. 
- """ - config = SbertConfig(**config) - super().__init__(model_dir) - self.config = config - SbertModelTransform.__init__(self, config, add_pooling_layer) - - def extract_sequence_outputs(self, outputs): - return outputs['last_hidden_state'] - - def extract_pooled_outputs(self, outputs): - return outputs['pooler_output'] - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs): - return SbertModelTransform.forward( - self, input_ids, attention_mask, token_type_ids, position_ids, - head_mask, inputs_embeds, encoder_hidden_states, - encoder_attention_mask, past_key_values, use_cache, - output_attentions, output_hidden_states, return_dict, **kwargs) diff --git a/modelscope/models/nlp/bart/__init__.py b/modelscope/models/nlp/bart/__init__.py new file mode 100644 index 00000000..31912efc --- /dev/null +++ b/modelscope/models/nlp/bart/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .text_error_correction import BartForTextErrorCorrection diff --git a/modelscope/models/nlp/bart_for_text_error_correction.py b/modelscope/models/nlp/bart/text_error_correction.py similarity index 100% rename from modelscope/models/nlp/bart_for_text_error_correction.py rename to modelscope/models/nlp/bart/text_error_correction.py diff --git a/modelscope/models/nlp/bert/__init__.py b/modelscope/models/nlp/bert/__init__.py index 705d9519..28a10f57 100644 --- a/modelscope/models/nlp/bert/__init__.py +++ b/modelscope/models/nlp/bert/__init__.py @@ -4,50 +4,32 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .modeling_bert import ( - BERT_PRETRAINED_MODEL_ARCHIVE_LIST, - BertForMaskedLM, - BertForMultipleChoice, - BertForNextSentencePrediction, - BertForPreTraining, - BertForQuestionAnswering, - BertForSequenceClassification, - BertForTokenClassification, + from .backbone import ( BertLayer, - BertLMHeadModel, BertModel, BertPreTrainedModel, - load_tf_weights_in_bert, ) - - from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig - from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer - from .tokenization_bert_fast import BertTokenizerFast - + from .configuration import BertConfig + from .fill_mask import BertForMaskedLM + from .text_ranking import BertForTextRanking + from .sentence_embedding import BertForSentenceEmbedding + from .text_classification import BertForSequenceClassification + from .token_classification import BertForTokenClassification + from .document_segmentation import BertForDocumentSegmentation else: _import_structure = { - 'configuration_bert': - ['BERT_PRETRAINED_CONFIG_ARCHIVE_MAP', 'BertConfig', 'BertOnnxConfig'], - 'tokenization_bert': - ['BasicTokenizer', 'BertTokenizer', 'WordpieceTokenizer'], + 'backbone': [ + 'BertModel', + 'BertPreTrainedModel', + ], + 'configuration': ['BertConfig'], + 'fill_mask': ['BertForMaskedLM'], + 'text_ranking': ['BertForTextRanking'], + 'sentence_embedding': ['BertForSentenceEmbedding'], + 'text_classification': ['BertForSequenceClassification'], + 'token_classification': ['BertForTokenClassification'], + 'document_segmentation': ['BertForDocumentSegmentation'], } - _import_structure['tokenization_bert_fast'] = 
['BertTokenizerFast'] - - _import_structure['modeling_bert'] = [ - 'BERT_PRETRAINED_MODEL_ARCHIVE_LIST', - 'BertForMaskedLM', - 'BertForMultipleChoice', - 'BertForNextSentencePrediction', - 'BertForPreTraining', - 'BertForQuestionAnswering', - 'BertForSequenceClassification', - 'BertForTokenClassification', - 'BertLayer', - 'BertLMHeadModel', - 'BertModel', - 'BertPreTrainedModel', - 'load_tf_weights_in_bert', - ] import sys diff --git a/modelscope/models/nlp/bert/backbone.py b/modelscope/models/nlp/bert/backbone.py new file mode 100755 index 00000000..df0aebd2 --- /dev/null +++ b/modelscope/models/nlp/bert/backbone.py @@ -0,0 +1,952 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model. """ + +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import (BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions) +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping +from modelscope.utils.logger import get_logger +from .configuration import BertConfig + +logger = get_logger(__name__) + +_CONFIG_FOR_DOC = 'BertConfig' + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model + # variable name and be able to load any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and + # exported when serialized + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse('1.6.0'): + self.register_buffer( + 'token_type_ids', + torch.zeros(self.position_ids.size(), dtype=torch.long), + 
persistent=False, + ) + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, + past_key_values_length:seq_length + + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor + # where it is all zeros, which usually occurs when its auto-generated, + # registered buffer helps users when tracing the model without passing + # token_type_ids, solves issue #5664 + if token_type_ids is None: + if hasattr(self, 'token_type_ids'): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, + dtype=torch.long, + device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' + f'heads ({config.num_attention_heads})') + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, 'position_embedding_type', 'absolute') + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
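+        # The four branches below cover: (1) cross-attention reusing cached
+        # encoder key/value states, (2) fresh cross-attention projecting the
+        # encoder hidden states, (3) incremental self-attention concatenating
+        # new key/value states onto the cache, and (4) plain self-attention.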
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all + # cross attention key/value_states. Further calls to cross_attention + # layer can then reuse all cross-attention key/value_states (first + # "if" case) if uni-directional self-attention (decoder) save + # Tuple(torch.Tensor, torch.Tensor) of all previous decoder + # key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected + # key/value_states (third "elif" case) if encoder bi-directional + # self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
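+        # Note that because of this dropout the attention rows are no longer
+        # guaranteed to sum to 1 during training.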
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + if self.is_decoder: + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = BertSelfAttention( + config, position_embedding_type=position_embedding_type) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states 
= self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError( + f'{self} should be used as a decoder model if cross attention is added' + ) + self.crossattention = BertAttention( + config, position_embedding_type='absolute') + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[ + 1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, 'crossattention'): + raise ValueError( + f'If `encoder_hidden_states` are passed, {self} has to be instantiated ' + f'with cross-attention layers by setting `config.add_cross_attention=True`' + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + 
attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + ( + layer_outputs[2], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface + for downloading and loading pretrained models. 
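+
+    It inherits from both ModelScope's TorchModel and transformers'
+    PreTrainedModel, so subclasses can be created through the ModelScope model
+    registry while still reusing the Hugging Face weight-initialization and
+    `from_pretrained` loading machinery.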
+ """ + + config_class = BertConfig + base_model_prefix = 'bert' + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels not supplied. + If num_labels is not found, the model will use the default setting (2 classes). + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.get('model_dir', None) + if model_dir is None: + config = BertConfig(**kwargs) + model = cls(config) + else: + model_kwargs = {} + label2id = kwargs.get('label2id', parse_label_mapping(model_dir)) + id2label = kwargs.get( + 'id2label', None if label2id is None else + {id: label + for label, id in label2id.items()}) + if id2label is not None and label2id is None: + label2id = {label: id for id, label in id2label.items()} + + num_labels = kwargs.get( + 'num_labels', None if label2id is None else len(label2id)) + if num_labels is not None: + model_kwargs['num_labels'] = num_labels + if label2id is not None: + model_kwargs['label2id'] = label2id + if id2label is not None: + model_kwargs['id2label'] = id2label + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + model.model_dir = model_dir + return model + + +@MODELS.register_module(group_key=Tasks.backbone, module_name=Models.bert) +class BertModel(BertPreTrainedModel): + """The Bert Model transformer outputting raw hidden-states without any + specific head on top. + + This model inherits from [`PreTrainedModel`]. Check the superclass + documentation for the generic methods the library implements for all its + model (such as downloading or saving, resizing the input embeddings, pruning + heads etc.) + + This model is also a PyTorch + [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch + documentation for all matter related to general usage and behavior. + + Parameters: + config ([`BertConfig`]): Model configuration class with all the + parameters of the model. + Initializing with a config file does not load the weights associated + with the model, only the configuration. 
Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + + The model can behave as an encoder (with only self-attention) as well as a + decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam + Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + `is_decoder` argument of the configuration set to `True`. To be used in a + Seq2Seq model, the model needs to initialized with both `is_decoder` + argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` + is then expected as an input to the forward pass. + + + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def _instantiate(cls, model_dir=None, add_pooling_layer=True, **config): + config = BertConfig(**config) + model = cls(config, add_pooling_layer) + return model + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs): + r""" + Args: + input_ids (`torch.LongTensor` of shape `((batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`BertTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `((batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask + values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `((batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the + inputs. Indices are selected in `[0, 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `((batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position + embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+
+                [What are position IDs?](../glossary#position-ids)
+            head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers,
+                num_heads)`, *optional*):
+                Mask to nullify selected heads of the self-attention modules. Mask
+                values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`,
+                *optional*):
+                Optionally, instead of passing `input_ids` you can choose to
+                directly pass an embedded representation. This is useful if you want
+                more control over how to convert `input_ids` indices into associated
+                vectors than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention
+                layers. See `attentions` under returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See
+                `hidden_states` under returned tensors for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~file_utils.ModelOutput`] instead of a
+                plain tuple.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size,
+                sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the
+                encoder. Used in the cross-attention if the model is configured as a
+                decoder.
+            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size,
+                sequence_length)`, *optional*):
+                Mask to avoid performing attention on the padding token indices of
+                the encoder input. This mask is used in the cross-attention if the
+                model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+            past_key_values (`tuple(tuple(torch.FloatTensor))` of length
+                `config.n_layers` with each tuple having 4 tensors of shape
+                `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+                Contains precomputed key and value hidden states of the attention
+                blocks. Can be used to speed up decoding.
+
+                If `past_key_values` are used, the user can optionally input only
+                the last `decoder_input_ids` (those that don't have their past key
+                value states given to this model) of shape `(batch_size, 1)` instead
+                of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned
+                and can be used to speed up decoding (see `past_key_values`).
+            Others (**kwargs):
+                Some additional parameters might be passed in from the upstream
+                pipeline; they do not influence the results.
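+
+        Examples:
+            >>> # A minimal sketch, assuming the module paths added in this patch
+            >>> # are importable directly and that the default BertConfig mirrors
+            >>> # the standard bert-base sizes. It builds a randomly initialized
+            >>> # backbone from a config, mirroring the model_dir-less branch of
+            >>> # `_instantiate`.
+            >>> import torch
+            >>> from modelscope.models.nlp.bert.backbone import BertModel
+            >>> from modelscope.models.nlp.bert.configuration import BertConfig
+            >>> model = BertModel(BertConfig())
+            >>> input_ids = torch.tensor([[101, 7592, 102]])
+            >>> outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
+            >>> model.extract_sequence_outputs(outputs).shape
+            torch.Size([1, 3, 768])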
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, 'token_type_ids'): + buffered_token_type_ids = self.embeddings.token_type_ids[:, : + seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
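+        # get_extended_attention_mask (inherited from PreTrainedModel) broadcasts
+        # the mask to [batch_size, 1, 1, seq_length] (or a causal shape for
+        # decoders) and converts 1/0 entries into 0.0 / large negative values, so
+        # it can simply be added to the raw attention scores.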
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def extract_sequence_outputs(self, outputs): + return outputs['last_hidden_state'] + + def extract_pooled_outputs(self, outputs): + return outputs['pooler_output'] diff --git a/modelscope/models/nlp/bert/configuration_bert.py b/modelscope/models/nlp/bert/configuration.py similarity index 99% rename from modelscope/models/nlp/bert/configuration_bert.py rename to modelscope/models/nlp/bert/configuration.py index 2c9293ec..1e2cef95 100644 --- a/modelscope/models/nlp/bert/configuration_bert.py +++ b/modelscope/models/nlp/bert/configuration.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
# diff --git a/modelscope/models/nlp/bert_for_document_segmentation.py b/modelscope/models/nlp/bert/document_segmentation.py similarity index 99% rename from modelscope/models/nlp/bert_for_document_segmentation.py rename to modelscope/models/nlp/bert/document_segmentation.py index dfa57597..b46c77e4 100644 --- a/modelscope/models/nlp/bert_for_document_segmentation.py +++ b/modelscope/models/nlp/bert/document_segmentation.py @@ -2,6 +2,7 @@ from typing import Any, Dict +import torch from torch import nn from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import TokenClassifierOutput diff --git a/modelscope/models/nlp/bert/fill_mask.py b/modelscope/models/nlp/bert/fill_mask.py new file mode 100644 index 00000000..4f81f62d --- /dev/null +++ b/modelscope/models/nlp/bert/fill_mask.py @@ -0,0 +1,299 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionFillMaskModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .backbone import BertModel, BertPreTrainedModel +from .configuration import BertConfig + +logger = logging.get_logger(__name__) + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
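+        # The weight tying itself is done by transformers' tie_weights() (run from
+        # post_init() in BertForMaskedLM below), which points the decoder weight
+        # exposed through get_output_embeddings() at the input word embeddings,
+        # provided config.tie_word_embeddings is left enabled.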
+        self.decoder = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+class BertOnlyNSPHead(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, pooled_output):
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return seq_relationship_score
+
+
+class BertPreTrainingHeads(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, sequence_output, pooled_output):
+        prediction_scores = self.predictions(sequence_output)
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return prediction_scores, seq_relationship_score
+
+
+@MODELS.register_module(Tasks.fill_mask, module_name=Models.bert)
+class BertForMaskedLM(BertPreTrainedModel):
+    r"""Bert Model with a `language modeling` head on top.
+
+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
+    pruning heads etc.)
+
+    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
+    general usage and behavior.
+
+    Preprocessor:
+        This is the fill_mask model of BERT, the preprocessor of this model
+        is `modelscope.preprocessors.NLPPreprocessor`.
+
+    Parameters:
+        config (:class:`~modelscope.models.nlp.bert.BertConfig`): Model configuration class with
+            all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+            weights.
+ """ + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config: BertConfig, **kwargs): + super().__init__(config) + + if config.is_decoder: + logger.warning( + 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for ' + 'bi-directional self-attention.') + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. 
+ return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, + *optional*): + Labels for computing the masked language modeling loss. Indices + should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` + docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., + config.vocab_size]` + + Returns: + Returns `modelscope.outputs.AttentionFillMaskModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_bert_backbone_base_std') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_bert_backbone_base_std') + >>> print(model(**preprocessor(('This is a test', 'This is also a test')))) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + return AttentionFillMaskModelOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + input_ids=input_ids, + ) + + def prepare_inputs_for_generation(self, + input_ids, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError('The PAD token should be defined for generation') + + padding_mask = attention_mask.new_zeros((attention_mask.shape[0], 1)) + attention_mask = torch.cat([attention_mask, padding_mask], dim=-1) + dummy_token = torch.full((effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {'input_ids': input_ids, 'attention_mask': attention_mask} diff --git a/modelscope/models/nlp/bert/modeling_bert.py b/modelscope/models/nlp/bert/modeling_bert.py deleted file mode 100755 index f8fd5994..00000000 --- a/modelscope/models/nlp/bert/modeling_bert.py +++ /dev/null @@ -1,2040 +0,0 @@ -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch BERT model. """ - -import math -import os -import warnings -from dataclasses import dataclass -from typing import Optional, Tuple - -import torch -import torch.utils.checkpoint -from packaging import version -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.activations import ACT2FN -from transformers.file_utils import (ModelOutput, add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings) -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, MaskedLMOutput, - MultipleChoiceModelOutput, NextSentencePredictorOutput, - QuestionAnsweringModelOutput, SequenceClassifierOutput, - TokenClassifierOutput) -from transformers.modeling_utils import (PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer) - -from modelscope.models.base import TorchModel -from modelscope.utils.logger import get_logger -from .configuration_bert import BertConfig - -logger = get_logger(__name__) - -_CONFIG_FOR_DOC = 'BertConfig' - - -def load_tf_weights_in_bert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' - 'https://www.tensorflow.org/install/ for installation instructions.' 
- ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f'Converting TensorFlow checkpoint from {tf_path}') - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f'Loading TF weight {name} with shape {shape}') - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split('/') - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any(n in [ - 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', - 'AdamWeightDecayOptimizer_1', 'global_step' - ] for n in name): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r'[A-Za-z]+_\d+', m_name): - scope_names = re.split(r'_(\d+)', m_name) - else: - scope_names = [m_name] - if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': - pointer = getattr(pointer, 'weight') - elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': - pointer = getattr(pointer, 'bias') - elif scope_names[0] == 'output_weights': - pointer = getattr(pointer, 'weight') - elif scope_names[0] == 'squad': - pointer = getattr(pointer, 'classifier') - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == '_embeddings': - pointer = getattr(pointer, 'weight') - elif m_name == 'kernel': - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError( - f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' - ) - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f'Initialize PyTorch weight {name}') - pointer.data = torch.from_numpy(array) - return model - - -class BertEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config): - super().__init__() - self.word_embeddings = nn.Embedding( - config.vocab_size, - config.hidden_size, - padding_idx=config.pad_token_id) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, - config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, - config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model - # variable name and be able to load any TensorFlow checkpoint file - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - # position_ids (1, len position emb) is contiguous in memory and - # exported when serialized - self.position_embedding_type = getattr(config, - 'position_embedding_type', - 'absolute') - self.register_buffer( - 'position_ids', - torch.arange(config.max_position_embeddings).expand((1, -1))) - if version.parse(torch.__version__) > version.parse('1.6.0'): - self.register_buffer( - 'token_type_ids', - torch.zeros(self.position_ids.size(), dtype=torch.long), - persistent=False, - ) - - def forward(self, - input_ids=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None, - past_key_values_length=0): - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - - if 
position_ids is None: - position_ids = self.position_ids[:, - past_key_values_length:seq_length - + past_key_values_length] - - # Setting the token_type_ids to the registered buffer in constructor - # where it is all zeros, which usually occurs when its auto-generated, - # registered buffer helps users when tracing the model without passing - # token_type_ids, solves issue #5664 - if token_type_ids is None: - if hasattr(self, 'token_type_ids'): - buffered_token_type_ids = self.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand( - input_shape[0], seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros( - input_shape, - dtype=torch.long, - device=self.position_ids.device) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == 'absolute': - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BertSelfAttention(nn.Module): - - def __init__(self, config, position_embedding_type=None): - super().__init__() - if config.hidden_size % config.num_attention_heads != 0 and not hasattr( - config, 'embedding_size'): - raise ValueError( - f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' - f'heads ({config.num_attention_heads})') - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size - / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, 'position_embedding_type', 'absolute') - if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding( - 2 * config.max_position_embeddings - 1, - self.attention_head_size) - - self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, - self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - mixed_query_layer = self.query(hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. 
- is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores( - self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores( - self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all - # cross attention key/value_states. Further calls to cross_attention - # layer can then reuse all cross-attention key/value_states (first - # "if" case) if uni-directional self-attention (decoder) save - # Tuple(torch.Tensor, torch.Tensor) of all previous decoder - # key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected - # key/value_states (third "elif" case) if encoder bi-directional - # self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, - key_layer.transpose(-1, -2)) - - if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange( - seq_length, dtype=torch.long, - device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange( - seq_length, dtype=torch.long, - device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding( - distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to( - dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == 'relative_key': - relative_position_scores = torch.einsum( - 'bhld,lrd->bhlr', query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == 'relative_key_query': - relative_position_scores_query = torch.einsum( - 'bhld,lrd->bhlr', query_layer, positional_embedding) - relative_position_scores_key = torch.einsum( - 'bhrd,lrd->bhlr', key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt( - self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + ( - self.all_head_size, ) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = (context_layer, - attention_probs) if output_attentions else (context_layer, ) - - if self.is_decoder: - outputs = outputs + (past_key_value, ) - return outputs - - -class BertSelfOutput(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertAttention(nn.Module): - - def __init__(self, config, position_embedding_type=None): - super().__init__() - self.self = BertSelfAttention( - config, position_embedding_type=position_embedding_type) - self.output = BertSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.self.num_attention_heads, - self.self.attention_head_size, self.pruned_heads) - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len( - heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - self_outputs = self.self( - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output, - ) + self_outputs[1:] # add attentions if we output them - return outputs - - -class BertIntermediate(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -class BertOutput(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states 
= self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertLayer(nn.Module): - - def __init__(self, config): - super().__init__() - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.attention = BertAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError( - f'{self} should be used as a decoder model if cross attention is added' - ) - self.crossattention = BertAttention( - config, position_embedding_type='absolute') - self.intermediate = BertIntermediate(config) - self.output = BertOutput(config) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[: - 2] if past_key_value is not None else None - self_attention_outputs = self.attention( - hidden_states, - attention_mask, - head_mask, - output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[ - 1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, 'crossattention'): - raise ValueError( - f'If `encoder_hidden_states` are passed, {self} has to be instantiated ' - f'with cross-attention layers by setting `config.add_cross_attention=True`' - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[ - -2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[ - 1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - layer_output = apply_chunking_to_forward(self.feed_forward_chunk, - self.chunk_size_feed_forward, - self.seq_len_dim, - attention_output) - outputs = (layer_output, ) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value, ) - - return outputs - - def feed_forward_chunk(self, attention_output): - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output - - -class BertEncoder(nn.Module): - - def __init__(self, config): - super().__init__() - self.config = config - self.layer = nn.ModuleList( - [BertLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - 
attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = ( - ) if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states, ) - - layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[ - i] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - if use_cache: - logger.warning( - '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' - ) - use_cache = False - - def create_custom_forward(module): - - def custom_forward(*inputs): - return module(*inputs, past_key_value, - output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1], ) - if output_attentions: - all_self_attentions = all_self_attentions + ( - layer_outputs[1], ) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + ( - layer_outputs[2], ) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states, ) - - if not return_dict: - return tuple(v for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - -class BertPooler(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. 
- first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertLMPredictionHead(nn.Module): - - def __init__(self, config): - super().__init__() - self.transform = BertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear( - config.hidden_size, config.vocab_size, bias=False) - - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - - def __init__(self, config): - super().__init__() - self.predictions = BertLMPredictionHead(config) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertOnlyNSPHead(nn.Module): - - def __init__(self, config): - super().__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output): - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -class BertPreTrainingHeads(nn.Module): - - def __init__(self, config): - super().__init__() - self.predictions = BertLMPredictionHead(config) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class BertPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface - for downloading and loading pretrained models. 
- """ - - config_class = BertConfig - load_tf_weights = load_tf_weights_in_bert - base_model_prefix = 'bert' - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_missing = [r'position_ids'] - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_( - mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_( - mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, BertEncoder): - module.gradient_checkpointing = value - - -@dataclass -class BertForPreTrainingOutput(ModelOutput): - """ - Output type of [`BertForPreTraining`]. - - Args: - loss (*optional*, returned when `labels` is provided, - `torch.FloatTensor` of shape `(1,)`): - Total loss as the sum of the masked language modeling loss and the - next sequence prediction (classification) loss. - prediction_logits (`torch.FloatTensor` of shape `(batch_size, - sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each - vocabulary token before SoftMax). - seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, - 2)`): - Prediction scores of the next sequence prediction (classification) - head (scores of True/False continuation before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when - `output_hidden_states=True` is passed or when - `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + - one for the output of each layer) of shape `(batch_size, - sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the - initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when - `output_attentions=True` is passed or when - `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape - `(batch_size, num_heads, sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the - weighted average in the self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - prediction_logits: torch.FloatTensor = None - seq_relationship_logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -BERT_START_DOCSTRING = r""" - - This model inherits from [`PreTrainedModel`]. Check the superclass - documentation for the generic methods the library implements for all its - model (such as downloading or saving, resizing the input embeddings, pruning - heads etc.) - - This model is also a PyTorch - [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) - subclass. Use it as a regular PyTorch Module and refer to the PyTorch - documentation for all matter related to general usage and behavior. - - Parameters: - config ([`BertConfig`]): Model configuration class with all the - parameters of the model. 
- Initializing with a config file does not load the weights associated - with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model - weights. -""" - -BERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`BertTokenizer`]. See - [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] - for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask - values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the - inputs. Indices are selected in `[0, 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position - embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, - num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask - values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, - *optional*): - Optionally, instead of passing `input_ids` you can choose to - directly pass an embedded representation. This is useful if you want - more control over how to convert `input_ids` indices into associated - vectors than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention - layers. See `attentions` under returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See - `hidden_states` under returned tensors for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~file_utils.ModelOutput`] instead of a - plain tuple. -""" - - -@add_start_docstrings( - 'The bare Bert Model transformer outputting raw hidden-states without any specific head on top.', - BERT_START_DOCSTRING, -) -class BertModel(BertPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well as a - decoder, in which case a layer of cross-attention is added between the - self-attention layers, following the architecture described in [Attention is - all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam - Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz - Kaiser and Illia Polosukhin. - - To behave as an decoder the model needs to be initialized with the - `is_decoder` argument of the configuration set to `True`. To be used in a - Seq2Seq model, the model needs to initialized with both `is_decoder` - argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` - is then expected as an input to the forward pass. 
- """ - - def __init__(self, config, add_pooling_layer=True): - super().__init__(config) - self.embeddings = BertEmbeddings(config) - self.encoder = BertEncoder(config) - - self.pooler = BertPooler(config) if add_pooling_layer else None - - # Initialize weights and apply final processing - self.post_init() - - @classmethod - def _instantiate(cls, model_dir=None, add_pooling_layer=True, **config): - config = BertConfig(**config) - model = cls(config, add_pooling_layer) - return model - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - @add_start_docstrings_to_model_forward( - BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs): - r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, - sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the - encoder. Used in the cross-attention if the model is configured as a - decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, - sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of - the encoder input. This mask is used in the cross-attention if the - model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length - `config.n_layers` with each tuple having 4 tensors of shape - `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention - blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only - the last `decoder_input_ids` (those that don't have their past key - value states given to this model) of shape `(batch_size, 1)` instead - of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned - and can be used to speed up decoding (see `past_key_values`). 
- """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - 'You cannot specify both input_ids and inputs_embeds at the same time' - ) - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError( - 'You have to specify either input_ids or inputs_embeds') - - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[ - 2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones( - ((batch_size, seq_length + past_key_values_length)), - device=device) - - if token_type_ids is None: - if hasattr(self.embeddings, 'token_type_ids'): - buffered_token_type_ids = self.embeddings.token_type_ids[:, : - seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand( - batch_size, seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros( - input_shape, dtype=torch.long, device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( - attention_mask, input_shape, device) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( - ) - encoder_hidden_shape = (encoder_batch_size, - encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask( - encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, - self.config.num_hidden_layers) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler( - sequence_output) if self.pooler is not None else None - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - def extract_sequence_outputs(self, outputs): - return outputs['last_hidden_state'] - - def extract_pooled_outputs(self, outputs): - return outputs['pooler_output'] - - -@add_start_docstrings( - """ - Bert Model with two heads on top as done during the pretraining: a `masked - language modeling` head and a `next sentence prediction (classification)` - head. 
- """, - BERT_START_DOCSTRING, -) -class BertForPreTraining(BertPreTrainedModel): - - def __init__(self, config): - super().__init__(config) - - self.bert = BertModel(config) - self.cls = BertPreTrainingHeads(config) - - # Initialize weights and apply final processing - self.post_init() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward( - BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @replace_return_docstrings( - output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - next_sentence_label=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, - *optional*): - Labels for computing the masked language modeling loss. Indices - should be in `[-100, 0, ..., config.vocab_size]` (see - `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with - labels in `[0, ..., config.vocab_size]` - next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, - *optional*): - Labels for computing the next sequence prediction - (classification) loss. Input should be a sequence pair (see - `input_ids` docstring) Indices should be in `[0, 1]`: - - - 0 indicates sequence B is a continuation of sequence A, - - 1 indicates sequence B is a random sequence. - kwargs (`Dict[str, any]`, optional, defaults to *{}*): - Used to hide legacy arguments that have been deprecated. 
- - Returns: - - Example: - - ```python >>> from transformers import BertTokenizer, BertForPreTraining - >>> import torch - - >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - >>> model = BertForPreTraining.from_pretrained('bert-base-uncased') - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.prediction_logits - >>> seq_relationship_logits = outputs.seq_relationship_logits - ``` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls( - sequence_output, pooled_output) - - total_loss = None - if labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1)) - next_sentence_loss = loss_fct( - seq_relationship_score.view(-1, 2), - next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss, ) - + output) if total_loss is not None else output - - return BertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, - BERT_START_DOCSTRING) -class BertLMHeadModel(BertPreTrainedModel): - - _keys_to_ignore_on_load_unexpected = [r'pooler'] - _keys_to_ignore_on_load_missing = [ - r'position_ids', r'predictions.decoder.bias' - ] - - def __init__(self, config): - super().__init__(config) - - if not config.is_decoder: - logger.warning( - 'If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`' - ) - - self.bert = BertModel(config, add_pooling_layer=False) - self.cls = BertOnlyMLMHead(config) - - # Initialize weights and apply final processing - self.post_init() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward( - BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @replace_return_docstrings( - output_type=CausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, - sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the - encoder. Used in the cross-attention if the model is configured - as a decoder. 
- encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, - sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices - of the encoder input. This mask is used in the cross-attention - if the model is configured as a decoder. Mask values selected in - `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, - *optional*): - Labels for computing the left-to-right language modeling loss - (next word prediction). Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with - indices set to `-100` are ignored (masked), the loss is only - computed for the tokens with labels n `[0, ..., - config.vocab_size]` - past_key_values (`tuple(tuple(torch.FloatTensor))` of length - `config.n_layers` with each tuple having 4 tensors of shape - `(batch_size, num_heads, sequence_length - 1, - embed_size_per_head)`): - Contains precomputed key and value hidden states of the - attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input - only the last `decoder_input_ids` (those that don't have their - past key value states given to this model) of shape - `(batch_size, 1)` instead of all `decoder_input_ids` of shape - `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are - returned and can be used to speed up decoding (see - `past_key_values`). - - Returns: - - Example: - - ```python >>> from transformers import BertTokenizer, BertLMHeadModel, - BertConfig >>> import torch - - >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') - >>> config = BertConfig.from_pretrained("bert-base-cased") - >>> config.is_decoder = True - >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.logits - ``` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, : - -1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct( - shifted_prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1)) - - if not return_dict: - output = (prediction_scores, ) + outputs[2:] - return ((lm_loss, ) + output) if lm_loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def 
prepare_inputs_for_generation(self, - input_ids, - past=None, - attention_mask=None, - **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # cut decoder_input_ids if past is used - if past is not None: - input_ids = input_ids[:, -1:] - - return { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'past_key_values': past - } - - def _reorder_cache(self, past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += (tuple( - past_state.index_select(0, beam_idx) - for past_state in layer_past), ) - return reordered_past - - -@add_start_docstrings( - """Bert Model with a `language modeling` head on top. """, - BERT_START_DOCSTRING) -class BertForMaskedLM(BertPreTrainedModel): - - _keys_to_ignore_on_load_unexpected = [r'pooler'] - _keys_to_ignore_on_load_missing = [ - r'position_ids', r'predictions.decoder.bias' - ] - - def __init__(self, config): - super().__init__(config) - - if config.is_decoder: - logger.warning( - 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for ' - 'bi-directional self-attention.') - - self.bert = BertModel(config, add_pooling_layer=False) - self.cls = BertOnlyMLMHead(config) - - # Initialize weights and apply final processing - self.post_init() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward( - BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, - *optional*): - Labels for computing the masked language modeling loss. 
Indices - should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` - docstring) Tokens with indices set to `-100` are ignored (masked), - the loss is only computed for the tokens with labels in `[0, ..., - config.vocab_size]` - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - masked_lm_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1)) - - if not return_dict: - output = (prediction_scores, ) + outputs[2:] - return ((masked_lm_loss, ) - + output) if masked_lm_loss is not None else output - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation(self, - input_ids, - attention_mask=None, - **model_kwargs): - input_shape = input_ids.shape - effective_batch_size = input_shape[0] - - # add a dummy token - if self.config.pad_token_id is None: - raise ValueError('The PAD token should be defined for generation') - - padding_mask = attention_mask.new_zeros((attention_mask.shape[0], 1)) - attention_mask = torch.cat([attention_mask, padding_mask], dim=-1) - dummy_token = torch.full((effective_batch_size, 1), - self.config.pad_token_id, - dtype=torch.long, - device=input_ids.device) - input_ids = torch.cat([input_ids, dummy_token], dim=1) - - return {'input_ids': input_ids, 'attention_mask': attention_mask} - - -@add_start_docstrings( - """Bert Model with a `next sentence prediction (classification)` head on top. """, - BERT_START_DOCSTRING, -) -class BertForNextSentencePrediction(BertPreTrainedModel): - - def __init__(self, config): - super().__init__(config) - - self.bert = BertModel(config) - self.cls = BertOnlyNSPHead(config) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward( - BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @replace_return_docstrings( - output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the next sequence prediction (classification) - loss. Input should be a sequence pair (see `input_ids` docstring). - Indices should be in `[0, 1]`: - - - 0 indicates sequence B is a continuation of sequence A, - - 1 indicates sequence B is a random sequence. 
- - Returns: - - Example: - - ```python >>> from transformers import BertTokenizer, - BertForNextSentencePrediction >>> import torch - - >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased') - - >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." - >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') - - >>> outputs = model(**encoding, labels=torch.LongTensor([1])) - >>> logits = outputs.logits - >>> assert logits[0, 0] < logits[0, 1] # next sentence was random - ``` - """ - - if 'next_sentence_label' in kwargs: - warnings.warn( - 'The `next_sentence_label` argument is deprecated, use `labels` instead.', - FutureWarning, - ) - labels = kwargs.pop('next_sentence_label') - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - - seq_relationship_scores = self.cls(pooled_output) - - next_sentence_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - next_sentence_loss = loss_fct( - seq_relationship_scores.view(-1, 2), labels.view(-1)) - - if not return_dict: - output = (seq_relationship_scores, ) + outputs[2:] - return ((next_sentence_loss, ) - + output) if next_sentence_loss is not None else output - - return NextSentencePredictorOutput( - loss=next_sentence_loss, - logits=seq_relationship_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Bert Model transformer with a sequence classification/regression head on top - (a linear layer on top of the pooled output) e.g. for GLUE tasks. - """, - BERT_START_DOCSTRING, -) -class BertForSequenceClassification(BertPreTrainedModel): - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.bert = BertModel(config) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None - else config.hidden_dropout_prob) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward( - BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. - Indices should be in `[0, ..., config.num_labels - 1]`. If - `config.num_labels == 1` a regression loss is computed (Mean-Square - loss), If `config.num_labels > 1` a classification loss is computed - (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = 'regression' - elif self.num_labels > 1 and (labels.dtype == torch.long - or labels.dtype == torch.int): - self.config.problem_type = 'single_label_classification' - else: - self.config.problem_type = 'multi_label_classification' - - if self.config.problem_type == 'regression': - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == 'single_label_classification': - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == 'multi_label_classification': - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - if not return_dict: - output = (logits, ) + outputs[2:] - return ((loss, ) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Bert Model with a multiple choice classification head on top (a linear layer - on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. - """, - BERT_START_DOCSTRING, -) -class BertForMultipleChoice(BertPreTrainedModel): - - def __init__(self, config): - super().__init__(config) - - self.bert = BertModel(config) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None - else config.hidden_dropout_prob) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, 1) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward( - BERT_INPUTS_DOCSTRING.format( - 'batch_size, num_choices, sequence_length')) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. - Indices should be in `[0, ..., num_choices-1]` where `num_choices` - is the size of the second dimension of the input tensors. 
(See - `input_ids` above) - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - num_choices = input_ids.shape[ - 1] if input_ids is not None else inputs_embeds.shape[1] - - input_ids = input_ids.view( - -1, input_ids.size(-1)) if input_ids is not None else None - attention_mask = attention_mask.view( - -1, - attention_mask.size(-1)) if attention_mask is not None else None - token_type_ids = token_type_ids.view( - -1, - token_type_ids.size(-1)) if token_type_ids is not None else None - position_ids = position_ids.view( - -1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = ( - inputs_embeds.view(-1, inputs_embeds.size(-2), - inputs_embeds.size(-1)) - if inputs_embeds is not None else None) - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, num_choices) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - - if not return_dict: - output = (reshaped_logits, ) + outputs[2:] - return ((loss, ) + output) if loss is not None else output - - return MultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Bert Model with a token classification head on top (a linear layer on top of - the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - BERT_START_DOCSTRING, -) -class BertForTokenClassification(BertPreTrainedModel): - - _keys_to_ignore_on_load_unexpected = [r'pooler'] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.bert = BertModel(config, add_pooling_layer=False) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None - else config.hidden_dropout_prob) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward( - BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, - *optional*): - Labels for computing the token classification loss. Indices should - be in `[0, ..., config.num_labels - 1]`. 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels) - active_labels = torch.where( - active_loss, labels.view(-1), - torch.tensor(loss_fct.ignore_index).type_as(labels)) - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct( - logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits, ) + outputs[2:] - return ((loss, ) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Bert Model with a span classification head on top for extractive - question-answering tasks like SQuAD (a linear layers on top of the - hidden-states output to compute `span start logits` and `span end logits`). - """, - BERT_START_DOCSTRING, -) -class BertForQuestionAnswering(BertPreTrainedModel): - - _keys_to_ignore_on_load_unexpected = [r'pooler'] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.bert = BertModel(config, add_pooling_layer=False) - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward( - BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - start_positions=None, - end_positions=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, - *optional*): - Labels for position (index) of the start of the labelled span for - computing the token classification loss. Positions are clamped to - the length of the sequence (`sequence_length`). Position outside of - the sequence are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for - computing the token classification loss. Positions are clamped to - the length of the sequence (`sequence_length`). Position outside of - the sequence are not taken into account for computing the loss. 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss, ) - + output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/modelscope/models/nlp/bert/sentence_embedding.py b/modelscope/models/nlp/bert/sentence_embedding.py new file mode 100644 index 00000000..f4c2620e --- /dev/null +++ b/modelscope/models/nlp/bert/sentence_embedding.py @@ -0,0 +1,113 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.metainfo import Models +from modelscope.models import Model +from modelscope.models.builder import MODELS +from modelscope.outputs import BackboneModelOutput +from modelscope.utils.constant import Tasks +from .backbone import BertModel, BertPreTrainedModel + + +@MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert) +class BertForSentenceEmbedding(BertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.config = config + setattr(self, self.base_model_prefix, + BertModel(config, add_pooling_layer=False)) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ) -> BackboneModelOutput: + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + Returns: + Returns `modelscope.outputs.AttentionTextClassificationModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base') + >>> print(model(**preprocessor('This is a test'))) + """ + return self.base_model.forward( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + model_dir = kwargs.get('model_dir') + model = super( + Model, + cls).from_pretrained(pretrained_model_name_or_path=model_dir) + model.model_dir = model_dir + return model diff --git a/modelscope/models/nlp/bert/text_classification.py b/modelscope/models/nlp/bert/text_classification.py new file mode 100644 index 00000000..b1d18d0f --- /dev/null +++ b/modelscope/models/nlp/bert/text_classification.py @@ -0,0 +1,208 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. 
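# Illustrative aside for the BertForSentenceEmbedding head added above. It follows the
# docstring's own example (same checkpoint id); since the head returns raw backbone
# outputs with no pooler, taking the [CLS] vector here is an assumed pooling choice,
# not something the code above prescribes.
from modelscope.models import Model
from modelscope.preprocessors import Preprocessor

model_id = 'damo/nlp_corom_sentence-embedding_chinese-base'
model = Model.from_pretrained(model_id)
preprocessor = Preprocessor.from_pretrained(model_id)

outputs = model(**preprocessor('This is a test'))
sentence_vector = outputs['last_hidden_state'][:, 0]   # [1, hidden_size] CLS embedding
print(sentence_vector.shape)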
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from modelscope.metainfo import Models
+from modelscope.models.builder import MODELS
+from modelscope.outputs import AttentionTextClassificationModelOutput
+from modelscope.utils import logger as logging
+from modelscope.utils.constant import Tasks
+from .backbone import BertModel, BertPreTrainedModel
+
+logger = logging.get_logger(__name__)
+
+
+@MODELS.register_module(Tasks.text_classification, module_name=Models.bert)
+@MODELS.register_module(Tasks.nli, module_name=Models.bert)
+@MODELS.register_module(
+    Tasks.sentiment_classification, module_name=Models.bert)
+@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.bert)
+@MODELS.register_module(
+    Tasks.zero_shot_classification, module_name=Models.bert)
+class BertForSequenceClassification(BertPreTrainedModel):
+    r"""Bert Model transformer with a sequence classification/regression head on top
+    (a linear layer on top of the pooled output) e.g. for GLUE tasks.
+
+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
+    pruning heads etc.)
+
+    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
+    general usage and behavior.
+
+    Preprocessor:
+        This is the text classification model of Bert, the preprocessor of this model
+        is `modelscope.preprocessors.SequenceClassificationPreprocessor`.
+
+    Trainer:
+        This model is a normal PyTorch model, and can be trained by various trainers, like EpochBasedTrainer,
+        NlpEpochBasedTrainer, or trainers from other frameworks.
+        The preferred trainer in ModelScope is NlpEpochBasedTrainer.
+
+    Parameters:
+        config (:class:`BertConfig`): Model configuration class with
+            all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+            weights.
+ """ + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + setattr(self, self.base_model_prefix, BertModel(config)) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None + else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + Returns `modelscope.outputs.AttentionTextClassificationModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') + >>> print(model(**preprocessor(('This is a test', 'This is also a test')))) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.base_model.forward( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long + or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits, ) + outputs[2:] + return ((loss, ) + output) if loss is not None else output + + return AttentionTextClassificationModelOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/modelscope/models/nlp/bert/text_ranking.py b/modelscope/models/nlp/bert/text_ranking.py new file mode 100644 index 00000000..d6bbf277 --- /dev/null +++ b/modelscope/models/nlp/bert/text_ranking.py @@ -0,0 +1,92 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
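Aside (not part of the patch): the `problem_type` auto-detection in `BertForSequenceClassification.forward` above chooses the loss function from `num_labels` and the label dtype. A minimal standalone sketch of that decision rule:

import torch

def pick_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    # Mirrors the branch in BertForSequenceClassification.forward above.
    if num_labels == 1:
        return 'regression'                       # MSELoss
    if num_labels > 1 and labels.dtype in (torch.long, torch.int):
        return 'single_label_classification'      # CrossEntropyLoss
    return 'multi_label_classification'           # BCEWithLogitsLoss

print(pick_problem_type(1, torch.tensor([0.7])))            # regression
print(pick_problem_type(3, torch.tensor([2])))              # single_label_classification
print(pick_problem_type(3, torch.tensor([[0., 1., 1.]])))   # multi_label_classification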
+ +import torch +import torch.utils.checkpoint + +from modelscope.metainfo import Models +from modelscope.models import Model +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionTextClassificationModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .backbone import BertModel +from .text_classification import BertForSequenceClassification + +logger = logging.get_logger(__name__) + + +@MODELS.register_module(Tasks.text_ranking, module_name=Models.bert) +class BertForTextRanking(BertForSequenceClassification): + + def __init__(self, config, *args, **kwargs): + super().__init__(config) + neg_sample = kwargs.get('neg_sample', 8) + self.neg_sample = neg_sample + setattr(self, self.base_model_prefix, + BertModel(self.config, add_pooling_layer=True)) + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs) -> AttentionTextClassificationModelOutput: + outputs = self.base_model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + # backbone model should return pooled_output as its second output + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + if self.base_model.training: + scores = logits.view(-1, self.neg_sample + 1) + batch_size = scores.size(0) + loss_fct = torch.nn.CrossEntropyLoss() + target_label = torch.zeros( + batch_size, dtype=torch.long, device=scores.device) + loss = loss_fct(scores, target_label) + return AttentionTextClassificationModelOutput( + loss=loss, + logits=logits, + ) + return AttentionTextClassificationModelOutput(logits=logits, ) + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels not supplied. + If num_labels is not found, the model will use the default setting (1 classes). + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + num_labels = kwargs.get('num_labels', 1) + neg_sample = kwargs.get('neg_sample', 4) + model_args = {} if num_labels is None else {'num_labels': num_labels} + if neg_sample is not None: + model_args['neg_sample'] = neg_sample + + model_dir = kwargs.get('model_dir') + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_args) + model.model_dir = model_dir + return model diff --git a/modelscope/models/nlp/bert/token_classification.py b/modelscope/models/nlp/bert/token_classification.py new file mode 100644 index 00000000..5dc6b0ce --- /dev/null +++ b/modelscope/models/nlp/bert/token_classification.py @@ -0,0 +1,225 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. 
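Aside (not part of the patch): `BertForTextRanking.forward` above assumes each training batch is laid out as groups of one positive passage followed by `neg_sample` negatives, and applies a listwise cross-entropy whose target is always index 0. A tiny sketch of that grouping, with arbitrary numbers:

import torch

neg_sample = 3                                        # one positive followed by 3 negatives per query
logits = torch.randn(2 * (neg_sample + 1), 1)         # flat per-passage scores for 2 queries
scores = logits.view(-1, neg_sample + 1)              # -> (2, 4): one score row per query
target = torch.zeros(scores.size(0), dtype=torch.long)   # the positive passage sits at index 0
loss = torch.nn.CrossEntropyLoss()(scores, target)
print(scores.shape, loss.item())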
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+
+from modelscope.metainfo import Models
+from modelscope.models.builder import MODELS
+from modelscope.outputs import TokenClassifierOutput
+from modelscope.utils import logger as logging
+from modelscope.utils.constant import Tasks
+from .backbone import BertModel, BertPreTrainedModel
+
+logger = logging.get_logger(__name__)
+
+
+@MODELS.register_module(Tasks.token_classification, module_name=Models.bert)
+@MODELS.register_module(Tasks.part_of_speech, module_name=Models.bert)
+@MODELS.register_module(Tasks.word_segmentation, module_name=Models.bert)
+class BertForTokenClassification(BertPreTrainedModel):
+    r"""Bert Model with a token classification head on top (a linear layer on top of
+    the hidden-states output), e.g. for Named-Entity-Recognition (NER) and word segmentation tasks.
+
+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the
+    generic methods the library implements for all its models (such as downloading or saving, resizing the input
+    embeddings, pruning heads etc.)
+
+    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
+    general usage and behavior.
+
+    Preprocessor:
+        This is the token classification model of Bert; the preprocessor of this model
+        is `modelscope.preprocessors.TokenClassificationPreprocessor`.
+
+    Trainer:
+        This model is a normal PyTorch model and can be trained by various trainers, like EpochBasedTrainer,
+        NlpEpochBasedTrainer, or trainers from other frameworks.
+        The preferred trainer in ModelScope is NlpEpochBasedTrainer.
+
+    Parameters:
+        config (:class:`BertConfig`): Model configuration class with
+            all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the
+            model weights.
+ """ + _keys_to_ignore_on_load_unexpected = [r'pooler'] + + def __init__(self, config, **kwargs): + super().__init__(config) + self.num_labels = config.num_labels + + setattr(self, self.base_model_prefix, + BertModel(config, add_pooling_layer=False)) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None + else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + offset_mapping=None, + label_mask=None, + ): + r""" + Args: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, + sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using + :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and + :meth:`transformers.PreTrainedTokenizer.__call__` for details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask + values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the + inputs. Indices are selected in ``[0, 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position + embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or + :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask + values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to + directly pass an embedded representation. This is useful if you want + more control over how to convert :obj:`input_ids` indices into + associated vectors than the model's internal embedding lookup + matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention + layers. See ``attentions`` under returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See + ``hidden_states`` under returned tensors for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` + instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, + `optional`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. 
If + :obj:`config.num_labels == 1` a regression loss is computed + (Mean-Square loss), If :obj:`config.num_labels > 1` a classification + loss is computed (Cross-Entropy). + offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the sentence. + Selected in the range ``[0, sequence_length - 1]``. + label_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask + values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + Returns `modelscope.outputs.TokenClassifierOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_bert_word-segmentation_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_bert_word-segmentation_chinese-base') + >>> print(model(**preprocessor(('This is a test', 'This is also a test')))) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), + torch.tensor(loss_fct.ignore_index).type_as(labels)) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, ) + outputs[2:] + return ((loss, ) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + offset_mapping=offset_mapping, + ) diff --git a/modelscope/models/nlp/bloom/backbone.py b/modelscope/models/nlp/bloom/backbone.py new file mode 100644 index 00000000..b6bd315e --- /dev/null +++ b/modelscope/models/nlp/bloom/backbone.py @@ -0,0 +1,15 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from transformers import BloomConfig +from transformers import BloomModel as BloomModelTransform + +from modelscope.metainfo import Models +from modelscope.models.builder import BACKBONES +from modelscope.utils.constant import Fields + + +@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.bloom) +class BloomModel(BloomModelTransform): + + def __init__(self, **kwargs): + config = BloomConfig(**kwargs) + super().__init__(config) diff --git a/modelscope/models/nlp/csanmt/__init__.py b/modelscope/models/nlp/csanmt/__init__.py new file mode 100644 index 00000000..85531617 --- /dev/null +++ b/modelscope/models/nlp/csanmt/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
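Aside (not part of the patch): the loss masking in `BertForTokenClassification.forward` above only scores positions where `attention_mask == 1`, by rewriting padded labels to the loss's `ignore_index`. The same trick in isolation:

import torch
from torch.nn import CrossEntropyLoss

num_labels = 3
logits = torch.randn(2, 4, num_labels)                 # (batch, seq_len, num_labels)
labels = torch.tensor([[1, 2, 0, 0], [2, 1, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

loss_fct = CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)
# Padded positions get ignore_index (-100) so they do not contribute to the loss.
active_labels = torch.where(active_loss, labels.view(-1),
                            torch.tensor(loss_fct.ignore_index).type_as(labels))
loss = loss_fct(active_logits, active_labels)
print(loss.item())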
+from .translation import CsanmtForTranslation diff --git a/modelscope/models/nlp/csanmt_for_translation.py b/modelscope/models/nlp/csanmt/translation.py similarity index 100% rename from modelscope/models/nlp/csanmt_for_translation.py rename to modelscope/models/nlp/csanmt/translation.py diff --git a/modelscope/models/nlp/deberta_v2/__init__.py b/modelscope/models/nlp/deberta_v2/__init__.py index 830210ed..08b184e5 100644 --- a/modelscope/models/nlp/deberta_v2/__init__.py +++ b/modelscope/models/nlp/deberta_v2/__init__.py @@ -22,38 +22,28 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .configuration_deberta_v2 import DebertaV2Config - from .tokenization_deberta_v2 import DebertaV2Tokenizer - from .tokenization_deberta_v2_fast import DebertaV2TokenizerFast - - from .modeling_deberta_v2 import ( - DebertaV2ForMaskedLM, - DebertaV2ForMultipleChoice, - DebertaV2ForQuestionAnswering, - DebertaV2ForSequenceClassification, - DebertaV2ForTokenClassification, + from .configuration import DebertaV2Config + from .tokenization import DebertaV2Tokenizer + from .tokenization_fast import DebertaV2TokenizerFast + from .backbone import ( DebertaV2Model, DebertaV2PreTrainedModel, ) + from .fill_mask import DebertaV2ForMaskedLM else: _import_structure = { - 'configuration_deberta_v2': - ['DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP', 'DebertaV2Config'], - 'tokenization_deberta_v2': ['DebertaV2Tokenizer'] + 'configuration': ['DebertaV2Config'], + 'tokenization': ['DebertaV2Tokenizer'], + 'tokenization_fast': ['DebertaV2TokenizerFast'], + 'backbone': [ + 'DebertaV2Model', + 'DebertaV2PreTrainedModel', + ], + 'fill_mask': [ + 'DebertaV2ForMaskedLM', + ] } - _import_structure['tokenization_deberta_v2_fast'] = [ - 'DebertaV2TokenizerFast' - ] - _import_structure['modeling_deberta_v2'] = [ - 'DebertaV2ForMaskedLM', - 'DebertaV2ForMultipleChoice', - 'DebertaV2ForQuestionAnswering', - 'DebertaV2ForSequenceClassification', - 'DebertaV2ForTokenClassification', - 'DebertaV2Model', - 'DebertaV2PreTrainedModel', - ] import sys sys.modules[__name__] = LazyImportModule( diff --git a/modelscope/models/nlp/deberta_v2/modeling_deberta_v2.py b/modelscope/models/nlp/deberta_v2/backbone.py similarity index 64% rename from modelscope/models/nlp/deberta_v2/modeling_deberta_v2.py rename to modelscope/models/nlp/deberta_v2/backbone.py index 1c6b9071..cca38133 100644 --- a/modelscope/models/nlp/deberta_v2/modeling_deberta_v2.py +++ b/modelscope/models/nlp/deberta_v2/backbone.py @@ -20,28 +20,22 @@ from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss +from torch.nn import LayerNorm from transformers.activations import ACT2FN -from transformers.file_utils import (add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward) -from transformers.modeling_outputs import (BaseModelOutput, MaskedLMOutput, - MultipleChoiceModelOutput, - QuestionAnsweringModelOutput, - SequenceClassifierOutput, - TokenClassifierOutput) +from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import softmax_backward_data +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionBackboneModelOutput from modelscope.utils import 
logger as logging -from .configuration_deberta_v2 import DebertaV2Config +from modelscope.utils.constant import Tasks +from .configuration import DebertaV2Config logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = 'DebertaV2Config' -_TOKENIZER_FOR_DOC = 'DebertaV2Tokenizer' -_CHECKPOINT_FOR_DOC = 'nlp_debertav2_fill-mask_chinese-lite' - # Copied from transformers.models.deberta.modeling_deberta.ContextPooler class ContextPooler(nn.Module): @@ -1006,7 +1000,7 @@ class DebertaV2Embeddings(nn.Module): # Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 -class DebertaV2PreTrainedModel(PreTrainedModel): +class DebertaV2PreTrainedModel(TorchModel, PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. @@ -1018,6 +1012,10 @@ class DebertaV2PreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_unexpected = ['position_embeddings'] supports_gradient_checkpointing = True + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): @@ -1037,8 +1035,24 @@ class DebertaV2PreTrainedModel(PreTrainedModel): if isinstance(module, DebertaV2Encoder): module.gradient_checkpointing = value + @classmethod + def _instantiate(cls, **kwargs): + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + ponet_config = DebertaV2Config(**kwargs) + model = cls(ponet_config) + else: + model = super( + Model, + cls).from_pretrained(pretrained_model_name_or_path=model_dir) + return model + + +@MODELS.register_module(Tasks.backbone, module_name=Models.deberta_v2) +# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 +class DebertaV2Model(DebertaV2PreTrainedModel): + """The bare DeBERTa_v2 Model transformer outputting raw hidden-states without any specific head on top. -DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two @@ -1048,65 +1062,13 @@ DEBERTA_START_DOCSTRING = r""" Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. - Parameters: - config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model. + config (`DebertaV2Config`): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -DEBERTA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`DebertaV2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert *input_ids* indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - 'The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.', - DEBERTA_START_DOCSTRING, -) -# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 -class DebertaV2Model(DebertaV2PreTrainedModel): + configuration. + """ - def __init__(self, config): + def __init__(self, config, **kwargs): super().__init__(config) self.embeddings = DebertaV2Embeddings(config) @@ -1130,14 +1092,6 @@ class DebertaV2Model(DebertaV2PreTrainedModel): raise NotImplementedError( 'The prune function is not implemented in DeBERTa model.') - @add_start_docstrings_to_model_forward( - DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -1148,7 +1102,53 @@ class DebertaV2Model(DebertaV2PreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: + ) -> Union[Tuple, AttentionBackboneModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`): + Indices of input sequence tokens in the vocabulary. + + attention_mask (`torch.FloatTensor` of shape `('batch_size, sequence_length')`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. 
+ + position_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + inputs_embeds (`torch.FloatTensor` of shape `('batch_size, sequence_length', hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a dataclass instead of a plain tuple. + + Returns: + Returns `modelscope.outputs.AttentionBackboneModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite', task='backbone') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite') + >>> print(model(**preprocessor('这是个测试'))) + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else @@ -1216,574 +1216,9 @@ class DebertaV2Model(DebertaV2PreTrainedModel): return (sequence_output, ) + encoder_outputs[ (1 if output_hidden_states else 2):] - return BaseModelOutput( + return AttentionBackboneModelOutput( last_hidden_state=sequence_output, hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, attentions=encoder_outputs.attentions, ) - - -@add_start_docstrings( - """DeBERTa Model with a `language modeling` head on top.""", - DEBERTA_START_DOCSTRING) -# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 -class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r'pooler'] - _keys_to_ignore_on_load_missing = [ - r'position_ids', r'predictions.decoder.bias' - ] - - def __init__(self, config): - super().__init__(config) - - self.deberta = DebertaV2Model(config) - self.cls = DebertaV2OnlyMLMHead(config) - - # Initialize weights and apply final processing - self.post_init() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward( - DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=MaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: 
Optional[bool] = None, - ) -> Union[Tuple, MaskedLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.deberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - masked_lm_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1)) - - if not return_dict: - output = (prediction_scores, ) + outputs[1:] - return ((masked_lm_loss, ) - + output) if masked_lm_loss is not None else output - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta -class DebertaV2PredictionHeadTransform(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta -class DebertaV2LMPredictionHead(nn.Module): - - def __init__(self, config): - super().__init__() - self.transform = DebertaV2PredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear( - config.hidden_size, config.vocab_size, bias=False) - - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta -class DebertaV2OnlyMLMHead(nn.Module): - - def __init__(self, config): - super().__init__() - self.predictions = DebertaV2LMPredictionHead(config) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -@add_start_docstrings( - """ - DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. 
- """, - DEBERTA_START_DOCSTRING, -) -# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 -class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): - - def __init__(self, config): - super().__init__(config) - - num_labels = getattr(config, 'num_labels', 2) - self.num_labels = num_labels - - self.deberta = DebertaV2Model(config) - self.pooler = ContextPooler(config) - output_dim = self.pooler.output_dim - - self.classifier = nn.Linear(output_dim, num_labels) - drop_out = getattr(config, 'cls_dropout', None) - drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out - self.dropout = StableDropout(drop_out) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.deberta.get_input_embeddings() - - def set_input_embeddings(self, new_embeddings): - self.deberta.set_input_embeddings(new_embeddings) - - @add_start_docstrings_to_model_forward( - DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=SequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.deberta( - input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - encoder_layer = outputs[0] - pooled_output = self.pooler(encoder_layer) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - # regression task - loss_fn = nn.MSELoss() - logits = logits.view(-1).to(labels.dtype) - loss = loss_fn(logits, labels.view(-1)) - elif labels.dim() == 1 or labels.size(-1) == 1: - label_index = (labels >= 0).nonzero() - labels = labels.long() - if label_index.size(0) > 0: - labeled_logits = torch.gather( - logits, 0, - label_index.expand( - label_index.size(0), logits.size(1))) - labels = torch.gather(labels, 0, label_index.view(-1)) - loss_fct = CrossEntropyLoss() - loss = loss_fct( - labeled_logits.view(-1, self.num_labels).float(), - labels.view(-1)) - else: - loss = torch.tensor(0).to(logits) - else: - log_softmax = nn.LogSoftmax(-1) - loss = -((log_softmax(logits) * labels).sum(-1)).mean() - elif self.config.problem_type == 'regression': - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == 'single_label_classification': - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == 'multi_label_classification': - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - if not return_dict: - output = (logits, ) + outputs[1:] - return ((loss, ) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions) - - -@add_start_docstrings( - """ - DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. 
- """, - DEBERTA_START_DOCSTRING, -) -# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 -class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r'pooler'] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.deberta = DebertaV2Model(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward( - DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.deberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits, ) + outputs[1:] - return ((loss, ) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions) - - -@add_start_docstrings( - """ - DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
- """, - DEBERTA_START_DOCSTRING, -) -# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2 -class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r'pooler'] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.deberta = DebertaV2Model(config) - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward( - DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=QuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - start_positions: Optional[torch.Tensor] = None, - end_positions: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.deberta( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + outputs[1:] - return ((total_loss, ) - + output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. - """, - DEBERTA_START_DOCSTRING, -) -class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): - - def __init__(self, config): - super().__init__(config) - - num_labels = getattr(config, 'num_labels', 2) - self.num_labels = num_labels - - self.deberta = DebertaV2Model(config) - self.pooler = ContextPooler(config) - output_dim = self.pooler.output_dim - - self.classifier = nn.Linear(output_dim, 1) - drop_out = getattr(config, 'cls_dropout', None) - drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out - self.dropout = StableDropout(drop_out) - - self.init_weights() - - def get_input_embeddings(self): - return self.deberta.get_input_embeddings() - - def set_input_embeddings(self, new_embeddings): - self.deberta.set_input_embeddings(new_embeddings) - - @add_start_docstrings_to_model_forward( - DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=MultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., - num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See - `input_ids` above) - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - num_choices = input_ids.shape[ - 1] if input_ids is not None else inputs_embeds.shape[1] - - flat_input_ids = input_ids.view( - -1, input_ids.size(-1)) if input_ids is not None else None - flat_position_ids = position_ids.view( - -1, position_ids.size(-1)) if position_ids is not None else None - flat_token_type_ids = token_type_ids.view( - -1, - token_type_ids.size(-1)) if token_type_ids is not None else None - flat_attention_mask = attention_mask.view( - -1, - attention_mask.size(-1)) if attention_mask is not None else None - flat_inputs_embeds = ( - inputs_embeds.view(-1, inputs_embeds.size(-2), - inputs_embeds.size(-1)) - if inputs_embeds is not None else None) - - outputs = self.deberta( - flat_input_ids, - position_ids=flat_position_ids, - token_type_ids=flat_token_type_ids, - attention_mask=flat_attention_mask, - inputs_embeds=flat_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - encoder_layer = outputs[0] - pooled_output = self.pooler(encoder_layer) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, num_choices) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - - if not return_dict: - output = (reshaped_logits, ) + outputs[1:] - return ((loss, ) + output) if loss is not None else output - - return MultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/modelscope/models/nlp/deberta_v2/configuration_deberta_v2.py b/modelscope/models/nlp/deberta_v2/configuration.py similarity index 98% rename from modelscope/models/nlp/deberta_v2/configuration_deberta_v2.py rename to modelscope/models/nlp/deberta_v2/configuration.py index 65e8f0b7..7921ca2f 100644 --- a/modelscope/models/nlp/deberta_v2/configuration_deberta_v2.py +++ b/modelscope/models/nlp/deberta_v2/configuration.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ DeBERTa-v2 model configuration, mainly copied from :class:`~transformers.DeBERTaV2Config""" -from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Mapping, Optional, Union from transformers import PretrainedConfig diff --git a/modelscope/models/nlp/deberta_v2/fill_mask.py b/modelscope/models/nlp/deberta_v2/fill_mask.py new file mode 100644 index 00000000..ed127d4c --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/fill_mask.py @@ -0,0 +1,230 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 Microsoft and the Hugging Face Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
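Aside (not part of the patch): the masked-LM head added below in this file computes its loss with `CrossEntropyLoss`, so label index -100 marks positions that must not contribute. A minimal sketch of that labeling convention; the token ids and the [MASK] id here are hypothetical:

import torch

original = torch.tensor([[101, 2769, 4263, 102]])     # hypothetical token ids
labels = torch.full_like(original, -100)              # -100 positions are ignored by the loss
mask_pos = 2
labels[0, mask_pos] = original[0, mask_pos]           # supervise only the masked slot
input_ids = original.clone()
input_ids[0, mask_pos] = 103                          # hypothetical [MASK] id fed to the model

vocab_size = 21128
logits = torch.randn(1, 4, vocab_size)                # stand-in for the MLM head output
loss = torch.nn.CrossEntropyLoss()(logits.view(-1, vocab_size), labels.view(-1))
print(loss.item())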
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.activations import ACT2FN
+
+from modelscope.metainfo import Models
+from modelscope.models.builder import MODELS
+from modelscope.outputs import AttentionFillMaskModelOutput
+from modelscope.utils.constant import Tasks
+from .backbone import DebertaV2Model, DebertaV2PreTrainedModel
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2
+@MODELS.register_module(Tasks.fill_mask, module_name=Models.deberta_v2)
+class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
+    r"""DeBERTa_v2 Model with a `language modeling` head on top.
+
+    The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
+    Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built
+    on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and an enhanced mask decoder. With those
+    two improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB pretraining data.
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Preprocessor:
+        This is the fill_mask model of Deberta_v2, the preprocessor of this model
+        is `modelscope.preprocessors.NLPPreprocessor`.
+
+    Parameters:
+        config (`DebertaV2Config`): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration.
+    """
+
+    _keys_to_ignore_on_load_unexpected = [r'pooler']
+    _keys_to_ignore_on_load_missing = [
+        r'position_ids', r'predictions.decoder.bias'
+    ]
+
+    def __init__(self, config, **kwargs):
+        super().__init__(config)
+
+        self.deberta = DebertaV2Model(config)
+        self.cls = DebertaV2OnlyMLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, AttentionFillMaskModelOutput]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`):
+                Indices of input sequence tokens in the vocabulary.
+
+            attention_mask (`torch.FloatTensor` of shape `('batch_size, sequence_length')`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+            token_type_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`, *optional*):
+                Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+                1]`:
+
+                - 0 corresponds to a *sentence A* token,
+                - 1 corresponds to a *sentence B* token.
+ + position_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + inputs_embeds (`torch.FloatTensor` of shape `('batch_size, sequence_length', hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert *input_ids* indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a dataclass instead of a plain tuple. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + Returns `modelscope.outputs.AttentionFillMaskModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite') + >>> # Call the model, return some tensors + >>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。'))) + >>> # Call the pipeline + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) + >>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。')) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + if not return_dict: + output = (prediction_scores, ) + outputs[1:] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + return AttentionFillMaskModelOutput( + loss=masked_lm_loss, + logits=prediction_scores, + input_ids=input_ids, + attentions=outputs.attentions, + hidden_states=outputs.hidden_states) + + +# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta +class DebertaV2PredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = 
nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta +class DebertaV2LMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = DebertaV2PredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta +class DebertaV2OnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = DebertaV2LMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores diff --git a/modelscope/models/nlp/deberta_v2/tokenization_deberta_v2.py b/modelscope/models/nlp/deberta_v2/tokenization.py similarity index 100% rename from modelscope/models/nlp/deberta_v2/tokenization_deberta_v2.py rename to modelscope/models/nlp/deberta_v2/tokenization.py diff --git a/modelscope/models/nlp/deberta_v2/tokenization_deberta_v2_fast.py b/modelscope/models/nlp/deberta_v2/tokenization_fast.py similarity index 99% rename from modelscope/models/nlp/deberta_v2/tokenization_deberta_v2_fast.py rename to modelscope/models/nlp/deberta_v2/tokenization_fast.py index a1fcecf4..913ea5bd 100644 --- a/modelscope/models/nlp/deberta_v2/tokenization_deberta_v2_fast.py +++ b/modelscope/models/nlp/deberta_v2/tokenization_fast.py @@ -24,7 +24,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast from modelscope.utils import logger as logging if is_sentencepiece_available(): - from .tokenization_deberta_v2 import DebertaV2Tokenizer + from .tokenization import DebertaV2Tokenizer else: DebertaV2Tokenizer = None diff --git a/modelscope/models/nlp/gpt3/__init__.py b/modelscope/models/nlp/gpt3/__init__.py index 076a0c6b..051cc8f2 100644 --- a/modelscope/models/nlp/gpt3/__init__.py +++ b/modelscope/models/nlp/gpt3/__init__.py @@ -4,14 +4,16 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .configuration_gpt3 import GPT3Config - from .modeling_gpt3 import GPT3Model - from .gpt3_for_text_generation import GPT3ForTextGeneration + from .configuration import GPT3Config + from .backbone import GPT3Model + from .text_generation import GPT3ForTextGeneration + from .tokenizer import JiebaBPETokenizer else: _import_structure = { - 'configuration_gpt3': ['GPT3Config'], - 'modeling_gpt3': ['GPT3Model'], - 'gpt3_for_text_generation': ['GPT3ForTextGeneration'], + 'configuration': ['GPT3Config'], + 'backbone': ['GPT3Model'], + 'text_generation': ['GPT3ForTextGeneration'], + 'tokenizer': ['JiebaBPETokenizer'], } import sys diff --git a/modelscope/models/nlp/gpt3/modeling_gpt3.py b/modelscope/models/nlp/gpt3/backbone.py 
similarity index 87% rename from modelscope/models/nlp/gpt3/modeling_gpt3.py rename to modelscope/models/nlp/gpt3/backbone.py index 498d15de..587c7a9d 100644 --- a/modelscope/models/nlp/gpt3/modeling_gpt3.py +++ b/modelscope/models/nlp/gpt3/backbone.py @@ -19,16 +19,15 @@ from typing import Optional, Union import addict import torch -from torch.nn import (CrossEntropyLoss, Dropout, Embedding, LayerNorm, Linear, - Module, Softmax) +from torch import nn from torch.nn import functional as F from transformers.modeling_utils import PreTrainedModel from modelscope.utils.constant import ModelFile -from .configuration_gpt3 import GPT3Config +from .configuration import GPT3Config -class GPT3SelfAttention(Module): +class GPT3SelfAttention(nn.Module): """Parallel self-attention layer abstract class. Self-attention layer takes input with size [s, b, h] @@ -44,13 +43,15 @@ class GPT3SelfAttention(Module): self.hidden_size_per_attention_head = \ self.hidden_size // self.num_attention_heads - self.query_key_value = Linear(self.hidden_size, 3 * self.hidden_size) - self.softmax = Softmax(dim=-1) - self.attention_dropout = Dropout(config.attention_probs_dropout_prob) + self.query_key_value = nn.Linear(self.hidden_size, + 3 * self.hidden_size) + self.softmax = nn.Softmax(dim=-1) + self.attention_dropout = nn.Dropout( + config.attention_probs_dropout_prob) # Output. - self.dense = Linear(self.hidden_size, self.hidden_size) - self.output_dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout_prob) def _transpose_for_scores(self, tensor): """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with @@ -133,7 +134,7 @@ class GPT3SelfAttention(Module): return output -class GPT3MLP(Module): +class GPT3MLP(nn.Module): """MLP. MLP will take the input with h hidden state, project it to 4*h @@ -146,12 +147,12 @@ class GPT3MLP(Module): hidden_size = config.hidden_size # Project to 4h. - self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size) + self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size) self.activation_func = F.gelu # Project back to h. - self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size) + self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size) - self.dropout = Dropout(config.hidden_dropout_prob) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states): @@ -164,7 +165,7 @@ class GPT3MLP(Module): return output -class GPT3TransformerLayer(Module): +class GPT3TransformerLayer(nn.Module): """A single transformer layer. Transformer layer takes input with size [s, b, h] and returns an @@ -175,14 +176,14 @@ class GPT3TransformerLayer(Module): super().__init__() # Layernorm on the input data. - self.input_layernorm = LayerNorm( + self.input_layernorm = nn.LayerNorm( config.hidden_size, eps=config.layernorm_epsilon) # Self attention. self.attention = GPT3SelfAttention(config) # Layernorm on the attention output - self.post_attention_layernorm = LayerNorm( + self.post_attention_layernorm = nn.LayerNorm( config.hidden_size, eps=config.layernorm_epsilon) # MLP @@ -208,7 +209,7 @@ class GPT3TransformerLayer(Module): return output -class GPT3Transformer(Module): +class GPT3Transformer(nn.Module): """Transformer class.""" def __init__(self, config): @@ -223,7 +224,7 @@ class GPT3Transformer(Module): [GPT3TransformerLayer(config) for _ in range(self.num_layers)]) # Final layer norm before output. 
- self.final_layernorm = LayerNorm( + self.final_layernorm = nn.LayerNorm( config.hidden_size, eps=config.layernorm_epsilon) def _get_layer(self, layer_number): @@ -242,7 +243,7 @@ class GPT3Transformer(Module): return hidden_states -class GPT3TransformerLanguageModel(Module): +class GPT3TransformerLanguageModel(nn.Module): """Transformer language model. Arguments: @@ -259,10 +260,11 @@ class GPT3TransformerLanguageModel(Module): super().__init__() # Embeddings. - self.word_embeddings = Embedding(config.vocab_size, config.hidden_size) - self.position_embeddings = Embedding(config.max_position_embeddings, - config.hidden_size) - self.embedding_dropout = Dropout(config.hidden_dropout_prob) + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob) # Transformer. self.transformer = GPT3Transformer(config) @@ -286,19 +288,19 @@ class GPT3Model(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, Linear): + if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_( mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, Embedding): + elif isinstance(module, nn.Embedding): module.weight.data.normal_( mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, LayerNorm): + elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -325,7 +327,7 @@ class GPT3Model(PreTrainedModel): logits = self.language_model(input_ids, attention_mask, position_ids) loss = None if labels is not None: - loss_fct = CrossEntropyLoss() + loss_fct = nn.CrossEntropyLoss() loss = loss_fct( logits.view(-1, self.config.vocab_size), labels.view(-1)) return addict.Dict(loss=loss, logits=logits) @@ -346,3 +348,6 @@ class GPT3Model(PreTrainedModel): } model.load_state_dict(state_dict) return model + + def prepare_inputs_for_generation(self, input_ids, *args, **kwargs): + return {'input_ids': input_ids} diff --git a/modelscope/models/nlp/gpt3/configuration.py b/modelscope/models/nlp/gpt3/configuration.py new file mode 100644 index 00000000..66e8b836 --- /dev/null +++ b/modelscope/models/nlp/gpt3/configuration.py @@ -0,0 +1,111 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
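The prepare_inputs_for_generation hook added above simply echoes input_ids back. A rough sketch of a hand-rolled greedy loop that consumes such a hook (a hypothetical helper, assuming GPT3Model.forward accepts input_ids alone and returns an addict.Dict carrying logits, as in the diff):

import torch

def greedy_generate(model, input_ids, max_new_tokens=16, eod_id=7):
    """Hypothetical helper; eod_id mirrors the config's eod_id default."""
    for _ in range(max_new_tokens):
        inputs = model.prepare_inputs_for_generation(input_ids)
        logits = model(**inputs).logits                 # [batch, seq, vocab]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if (next_token == eod_id).all():
            break
    return input_ids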
+ +import torch +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class GPT3Config(PretrainedConfig): + + model_type = 'gpt3' + + def __init__( + self, + vocab_size=25600, + hidden_size=768, + ffn_hidden_size=None, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=2048, + type_vocab_size=2, + layernorm_epsilon=1e-12, + bias_gelu_fusion=True, + fp32_residual_connection=False, + sequence_parallel=False, + fp16=False, + bf16=False, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=False, + kv_channels=None, + masked_softmax_fusion=True, + attention_dropout=0.1, + bias_dropout_fusion=True, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.1, + init_method_std=0.02, + # generate + eod_id=7, + tokens_to_generate=100, + top_k=0, + top_p=0.9, + **kwargs): + super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = 4 * hidden_size \ + if ffn_hidden_size is None else ffn_hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.layernorm_epsilon = layernorm_epsilon + self.bias_gelu_fusion = bias_gelu_fusion + self.fp32_residual_connection = fp32_residual_connection + self.sequence_parallel = sequence_parallel + self.fp16 = fp16 + self.bf16 = bf16 + assert not (fp16 and bf16) + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + if kv_channels is None: + assert hidden_size % num_attention_heads == 0 + self.kv_channels = hidden_size // num_attention_heads + self.masked_softmax_fusion = masked_softmax_fusion + self.attention_dropout = attention_dropout + self.bias_dropout_fusion = bias_dropout_fusion + self.apply_residual_connection_post_layernorm = \ + apply_residual_connection_post_layernorm + self.hidden_dropout = hidden_dropout + self.init_method_std = init_method_std + self.eod_id = eod_id + self.tokens_to_generate = tokens_to_generate + self.top_k = top_k + self.top_p = top_p + + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + self.no_persist_layer_norm = \ + TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11) + + @property + def params_dtype(self): + if self.fp16: + return torch.half + elif self.bf16: + return torch.bfloat16 + else: + return torch.float diff --git a/modelscope/models/nlp/gpt3/configuration_gpt3.py b/modelscope/models/nlp/gpt3/configuration_gpt3.py deleted file mode 100644 index d5a054fd..00000000 --- a/modelscope/models/nlp/gpt3/configuration_gpt3.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
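A short sketch of the derived defaults in the new GPT3Config above (constructor arguments are illustrative; the import path follows the gpt3/__init__.py shown earlier in this diff):

import torch
from modelscope.models.nlp.gpt3 import GPT3Config

config = GPT3Config(vocab_size=25600, hidden_size=768, num_attention_heads=12)

# ffn_hidden_size falls back to 4 * hidden_size when not supplied.
assert config.ffn_hidden_size == 4 * 768
# kv_channels defaults to hidden_size // num_attention_heads.
assert config.kv_channels == 768 // 12
# params_dtype follows the fp16/bf16 flags (both default to False).
assert config.params_dtype == torch.float
assert GPT3Config(fp16=True).params_dtype == torch.half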
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class GPT3Config(PretrainedConfig): - - model_type = 'gpt' - - def __init__(self, - vocab_size=25600, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act='gelu', - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=2048, - type_vocab_size=2, - layernorm_epsilon=1e-12, - **kwargs): - super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) - - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.layernorm_epsilon = layernorm_epsilon diff --git a/modelscope/models/nlp/gpt3/distributed_gpt3.py b/modelscope/models/nlp/gpt3/distributed_gpt3.py new file mode 100644 index 00000000..a0091259 --- /dev/null +++ b/modelscope/models/nlp/gpt3/distributed_gpt3.py @@ -0,0 +1,1057 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from megatron import mpu +from megatron.global_vars import get_global_memory_buffer, set_global_variables +from megatron.model import (AttnMaskType, Float16Module, LayerNorm, + bias_gelu_impl) +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from torch import nn +from torch.nn import functional as F +from transformers.modeling_utils import PreTrainedModel + +from modelscope.models import TorchModel +from modelscope.models.nlp.gpt3 import GPT3Config +from modelscope.utils.nlp.distributed import initialize_distributed +from modelscope.utils.nlp.load_checkpoint import pre_load +from modelscope.utils.torch_utils import set_random_seed_mpu + + +class GPT3ParallelMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config, init_method, output_layer_init_method): + super().__init__() + + # Project to 4h. 
+ self.dense_h_to_4h = mpu.ColumnParallelLinearV3( + config, + config.hidden_size, + config.ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True) + + self.bias_gelu_fusion = config.bias_gelu_fusion + self.activation_func = F.gelu + + # Project back to h. + self.dense_4h_to_h = mpu.RowParallelLinearV3( + config, + config.ffn_hidden_size, + config.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h( + hidden_states) + + if self.bias_gelu_fusion: + intermediate_parallel = \ + bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + intermediate_parallel = \ + self.activation_func(intermediate_parallel + bias_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + + +class GPT3Embedding(nn.Module): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, config, init_method): + super().__init__() + + self.hidden_size = config.hidden_size + self.init_method = init_method + + # Word embeddings (parallel). + self.word_embeddings = mpu.VocabParallelEmbedding( + config.vocab_size, self.hidden_size, init_method=self.init_method) + + # Position embedding (serial). + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + self.hidden_size) + # Initialize the position embeddings. + self.init_method(self.position_embeddings.weight) + + self.fp32_residual_connection = config.fp32_residual_connection + self.sequence_parallel = config.sequence_parallel + # Embeddings dropout + self.embedding_dropout = nn.Dropout(config.hidden_dropout) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + + def forward(self, input_ids, position_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + + # Dropout. 
+ if self.sequence_parallel: + embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) + with mpu.get_cuda_rng_tracker().fork(): + embeddings = self.embedding_dropout(embeddings) + else: + embeddings = self.embedding_dropout(embeddings) + return embeddings + + +class NoopTransformerLayer(nn.Module): + + def __init__(self, layer_number): + super().__init__() + self.layer_number = layer_number + + def forward(self, + hidden_states, + attention_mask, + encoder_output=None, + enc_dec_attn_mask=None, + inference_params=None): + return hidden_states.clone() + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +class GPT3CoreAttention(nn.Module): + + def __init__(self, + config, + layer_number, + attn_mask_type=AttnMaskType.padding): + super().__init__() + self.fp16 = config.fp16 + self.bf16 = config.bf16 + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.sequence_parallel = config.sequence_parallel + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(projection_size, + world_size) + self.hidden_size_per_attention_head = mpu.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = mpu.divide( + config.num_attention_heads, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + self.fp16, self.bf16, self.attn_mask_type, + config.masked_softmax_fusion, attention_mask_func, + self.attention_softmax_in_fp32, coeff) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), + query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = get_global_memory_buffer().get_tensor( + (output_size[0] * output_size[1], output_size[2], output_size[3]), + query_layer.dtype, 'mpu') + + # Raw attention scores. 
[b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.sequence_parallel: + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), + query_layer.size(0), value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class GPT3ParallelAttention(nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config, init_method, output_layer_init_method, + layer_number): + super().__init__() + self.layer_number = max(1, layer_number) + self.params_dtype = config.params_dtype + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_attention_head = mpu.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = mpu.divide( + config.num_attention_heads, world_size) + + # Strided linear layer. + self.query_key_value = mpu.ColumnParallelLinearV3( + config, + config.hidden_size, + 3 * projection_size, + gather_output=False, + init_method=init_method) + + self.core_attention = GPT3CoreAttention(config, self.layer_number) + + # Output. 
+ self.dense = mpu.RowParallelLinearV3( + config, + projection_size, + config.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + + def _allocate_memory(self, inference_max_sequence_len, batch_size): + return torch.empty( + inference_max_sequence_len, + batch_size, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + dtype=self.params_dtype, + device=torch.cuda.current_device()) + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + if inference_params: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_len + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size) + inference_value_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size) + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, inference_value_memory) + else: + inference_key_memory, inference_value_memory = \ + inference_params.key_value_memory_dict[self.layer_number] + + # ===================== + # Query, Key, and Value + # ===================== + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, + value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) + + # ================================== + # Adjust key and value for inference + # ================================== + + if inference_params: + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = key_layer + inference_value_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[:sequence_end, + batch_start:batch_end, ...] + value_layer = inference_value_memory[:sequence_end, + batch_start:batch_end, ...] + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + + # ================= + # Output. 
[sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + return output, bias + + +class nullcontext: + + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, *excinfo): + pass + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + + return _bias_dropout_add + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) + + +class GPT3ParallelTransformerLayer(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config, init_method, output_layer_init_method, + layer_number): + + super().__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm \ + = config.apply_residual_connection_post_layernorm + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + + # Layernorm on the input data. + self.input_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + # Self attention. + self.self_attention = GPT3ParallelAttention(config, init_method, + output_layer_init_method, + layer_number) + self.hidden_dropout = config.hidden_dropout + self.bias_dropout_fusion = config.bias_dropout_fusion + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + # MLP + self.mlp = GPT3ParallelMLP(config, init_method, + output_layer_init_method) + + # Set bias+dropout+add fusion grad_enable execution handler. + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 + and TORCH_MINOR >= 10) + self.bias_dropout_add_exec_handler = \ + nullcontext if use_nvfuser else torch.enable_grad + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, attention_bias = \ + self.self_attention( + layernorm_output, + attention_mask, + inference_params=inference_params) + # Residual connection. 
+ if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + with self.bias_dropout_add_exec_handler(): + layernorm_input = bias_dropout_add_func( + attention_output, attention_bias.expand_as(residual), residual, + self.hidden_dropout) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output, mlp_bias = self.mlp(layernorm_output) + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func(mlp_output, + mlp_bias.expand_as(residual), + residual, self.hidden_dropout) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = mpu.make_viewless_tensor( + inp=output, requires_grad=output.requires_grad, keep_graph=True) + + return output + + +class GPT3ParallelTransformer(nn.Module): + """Transformer class.""" + + def __init__(self, + config, + init_method, + output_layer_init_method, + post_layer_norm=True, + pre_process=True, + post_process=True): + super().__init__() + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = post_layer_norm + self.pre_process = pre_process + self.post_process = post_process + self.input_tensor = None + + self.sequence_parallel = config.sequence_parallel + + # Number of layers. + self.num_layers = config.num_hidden_layers + + # Transformer layers. + def build_layer(layer_number): + return GPT3ParallelTransformerLayer(config, init_method, + output_layer_init_method, + layer_number) + + if self.num_layers == 0: + self.num_layers = 1 + self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)]) + else: + self.layers = torch.nn.ModuleList( + [build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. + self.final_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [s, b, h] + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. 
+ # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = mpu.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + if self.sequence_parallel: + rng_context = mpu.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + with rng_context: + # Forward pass. + for index in range(self.num_layers): + layer = self._get_layer(index) + hidden_states = layer( + hidden_states, + attention_mask, + inference_params=inference_params) + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class GPT3TransformerLanguageModel(nn.Module): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, config, init_method, output_layer_init_method): + super().__init__() + + self.hidden_size = config.hidden_size + self.init_method = init_method + self.encoder_hidden_state = None + + # Embeddings. + self.embedding = GPT3Embedding(config, self.init_method) + + # Transformer. + self.encoder = GPT3ParallelTransformer( + config, + self.init_method, + output_layer_init_method, + ) + + def forward(self, + enc_input_ids, + enc_position_ids, + enc_attn_mask, + inference_params=None, + enc_hidden_states=None): + + # Encoder embedding. + encoder_input = self.embedding(enc_input_ids, enc_position_ids) + + # Run encoder. 
+ if enc_hidden_states is None: + if self.encoder is not None: + encoder_output = self.encoder( + encoder_input, + enc_attn_mask, + inference_params=inference_params) + else: + encoder_output = self.encoder_hidden_state + else: + encoder_output = enc_hidden_states.to(encoder_input.dtype) + + return encoder_output + + +def init_method_normal(sigma): + """Init method based on N(0, sigma).""" + + def init_(tensor): + return nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + + +def scaled_init_method_normal(sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +class GPT3Model(PreTrainedModel): + + config_class = GPT3Config + + def __init__(self, config, parallel_output=False): + super().__init__(config) + + self.parallel_output = parallel_output + + self.language_model = GPT3TransformerLanguageModel( + config, init_method_normal(config.init_method_std), + scaled_init_method_normal(config.init_method_std, + config.num_hidden_layers)) + + def word_embeddings_weight(self): + return self.language_model.embedding.word_embeddings.weight + + @staticmethod + def build_attention_mask_and_position_ids(tokens): + seq_length = tokens.size(1) + attention_mask = torch.tril( + torch.ones((1, 1, seq_length, seq_length), + dtype=torch.long, + device=tokens.device)) + attention_mask = (attention_mask < 0.5) + + position_ids = torch.arange( + seq_length, dtype=torch.long, device=tokens.device) + position_ids = position_ids.unsqueeze(0).expand_as(tokens) + + return attention_mask, position_ids + + def forward(self, + input_ids, + attention_mask=None, + position_ids=None, + inference_params=None, + **kwargs): + if attention_mask is None and position_ids is None: + attention_mask, position_ids = \ + self.build_attention_mask_and_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + inference_params=inference_params) + + logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply( + lm_output, self.word_embeddings_weight(), None, False, True, + self.config.sequence_parallel) + # Gather if needed. + + output = logits_parallel + if not self.parallel_output: + output = mpu.gather_from_model_parallel_region(logits_parallel) + return output.transpose(0, 1).contiguous() + + +def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf.""" + + filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(filter_, float('-Inf')) + + +def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf.""" + + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Filteration based on the cumulative sum. + filter_ = cumulative_probs > top_p + # This shift by 1 is weird and I cannot justify it. This existed + # in the original implementation: + # https://github.com/ari-holtzman/degen/blob/master/gen.py + # and I guess it is needed so keeping it for now. + filter_[:, 1:] = filter_[:, :-1].clone() + # Make sure we at least have one token to select from. 
+ filter_[..., 0] = 0 + + # Fill in the filtered part + filter_ = filter_.scatter(1, sorted_indices, filter_) + logits.masked_fill_(filter_, float('-Inf')) + + +def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None): + """ Sample and generate a token. + Note: logits has the dimension [b, v] where b is the batch size + and v is the vocabulary size. + If vocab_size is provided, we will make sure the sample that is + generated is in [0, vocab-size). This will avoid out of vocabulary + generations due to padding. + """ + + # Check logits for consistency. + assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.' + assert logits.type() == 'torch.cuda.FloatTensor', \ + 'input logits should be floats.' + + # Greedy is just simple argmax. + if top_k == 1: + assert top_p == 0.0, 'cannot set both greedy and top-p samplings.' + samples = torch.argmax(logits, dim=-1) + + # Top-k or top-p sampling. + else: + # Clone so we do not modify the inputs, + logits = logits.clone() + # Apply temperature in place. + if temperature != 1.0: + logits.div_(temperature) + + if top_k > 1: + assert top_p == 0.0, 'cannot set both top-k and top-p samplings.' + assert top_k <= logits.size(1), 'top-k is larger than logit size.' + if vocab_size: + assert top_k < vocab_size, 'top-k is larger than vocab size.' + modify_logits_for_top_k_filtering(logits, top_k) + + elif top_p > 0.0: + assert top_p <= 1.0, 'top-p should be in (0, 1].' + modify_logits_for_top_p_filtering(logits, top_p) + + # After filtering, we need to recalculate the distribution. + probs = logits.softmax(dim=-1) + samples = torch.multinomial(probs, num_samples=1).view(-1) + + # If vocab size is provided, make sure the samples are in + # in the range [0, vocab-size). + if vocab_size: + samples = torch.clamp(samples, min=0, max=(vocab_size - 1)) + + return samples + + +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + def __init__(self, max_batch_size, max_sequence_len): + """Note that offsets are set to zero and we always set the + flag to allocate memory. After the first call, make sure to + set this flag to False.""" + self.max_sequence_len = max_sequence_len + self.max_batch_size = max_batch_size + self.sequence_len_offset = 0 + self.batch_size_offset = 0 + self.key_value_memory_dict = {} + + def swap_key_value_dict(self, batch_idx): + 'swap between batches' + if len(self.key_value_memory_dict) == 0: + raise ValueError('should not swap when dict in empty') + + for layer_number in self.key_value_memory_dict.keys(): + inference_key_memory, inference_value_memory = self.key_value_memory_dict[ + layer_number] + assert len(batch_idx) == inference_key_memory.shape[ + 1] # make sure batch size is the same + new_inference_key_memory = inference_key_memory[:, batch_idx] + new_inference_value_memory = inference_value_memory[:, batch_idx] + self.key_value_memory_dict[layer_number] = ( + new_inference_key_memory, new_inference_value_memory) + + +class DistributedGPT3(TorchModel): + + def __init__(self, + model_dir, + rank, + path_load_tag='model', + *args, + **kwargs): + super().__init__(model_dir, *args, **kwargs) + initialize_distributed(rank, mpu, kwargs['world_size'], + kwargs['model_parallel_size'], + kwargs['master_ip'], kwargs['master_port']) + seed = 0 if 'seed' not in kwargs else kwargs['seed'] + set_random_seed_mpu(seed) + set_global_variables() + + self.config = GPT3Config.from_pretrained(model_dir) + # Build model. 
+ model = GPT3Model(self.config) + + for param in model.parameters(): + mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + # GPU allocation. + model.cuda(torch.cuda.current_device()) + + # Fp16 conversion. + if self.config.fp16 or self.config.bf16: + model = Float16Module(model, self.config) + + self.dist_model = model + load_model = pre_load(mpu, model_dir, tag=path_load_tag) + self.dist_model.load_state_dict(load_model) + + self.inference_params = None + + def forward_step(self, tokens, attention_mask, position_ids): + logits = self.dist_model( + tokens, + attention_mask, + position_ids, + inference_params=self.inference_params) + self.inference_params.sequence_len_offset += tokens.size(1) + return logits + + def generate(self, + tokens, + temperature=1.0, + use_eod_token_for_early_termination=True, + stop_on_double_eol=False, + stop_on_eol=False): + lengths = torch.tensor([tokens.size(1)], device=tokens.device) + pads = torch.ones( + 1, self.config.tokens_to_generate, + device=tokens.device).long() * self.config.eod_id + tokens = torch.cat((tokens, pads), dim=-1) + + batch_size = tokens.size(0) + min_prompt_length = lengths.min().item() + max_sequence_length = tokens.size(1) + max_sequence_length = min(max_sequence_length, + self.config.max_position_embeddings) + + # If the context is too big, this happens + if min_prompt_length >= max_sequence_length: + raise ValueError('context length + tokens_to_generate too large') + + # Initialize inference parameters. + self.inference_params = InferenceParams(batch_size, + max_sequence_length) + + # Added termination_id to support the case that we want to terminate the + # generation once that id is generated. + termination_id = self.config.eod_id + + # Whether we have reached a termination id. + is_generation_done = torch.zeros( + batch_size, dtype=torch.uint8, device=torch.cuda.current_device()) + + # ============= + # Run infernece + # ============= + + with torch.no_grad(): + attention_mask, position_ids = \ + GPT3Model.build_attention_mask_and_position_ids(tokens) + prev_context_length = 0 + for context_length in range(min_prompt_length, + max_sequence_length): + + # Pick the slice that we need to pass through the network. + tokens2use = tokens[:, prev_context_length:context_length] + positions2use = position_ids[:, prev_context_length: + context_length] + attention_mask2use = attention_mask[ + ..., prev_context_length:context_length, :context_length] + + # logits will be meanigful only in the last pipeline stage. + logits = self.forward_step(tokens2use, attention_mask2use, + positions2use) + + # Sample. + last_token_logits = logits[:, -1, :] + new_sample = sample( + last_token_logits, + top_k=self.config.top_k, + top_p=self.config.top_p, + temperature=temperature, + vocab_size=self.config.vocab_size) + + # If a prompt length is smaller or equal th current context + # length, it means we have started generating tokens + started = lengths <= context_length + # Update the tokens. + tokens[started, context_length] = new_sample[started] + + # Update the context length for the next token generation. 
+ prev_context_length = context_length + + # instead tokenization should be in the inference loop so stop sequences can be used + if stop_on_double_eol: + hit_double_eol = (new_sample + == 628).byte() & started.byte() + hit_two_eols = (new_sample == 198).byte() & ( + tokens[:, context_length - 1] + == 198).byte() & started.byte() + done_token = hit_double_eol | hit_two_eols + elif stop_on_eol: + hit_double_eol = (new_sample + == 628).byte() & started.byte() + hit_eol = (new_sample == 198).byte() & started.byte() + done_token = hit_double_eol | hit_eol + else: + done_token = (new_sample == termination_id).byte() & \ + started.byte() + + is_generation_done = is_generation_done | done_token + done = torch.all(is_generation_done) + + if use_eod_token_for_early_termination and done: + break + + tokens = tokens[:, :(context_length + 1)] + return tokens diff --git a/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py b/modelscope/models/nlp/gpt3/text_generation.py similarity index 100% rename from modelscope/models/nlp/gpt3/gpt3_for_text_generation.py rename to modelscope/models/nlp/gpt3/text_generation.py diff --git a/modelscope/models/nlp/gpt3/tokenizer.py b/modelscope/models/nlp/gpt3/tokenizer.py new file mode 100644 index 00000000..5780ddbd --- /dev/null +++ b/modelscope/models/nlp/gpt3/tokenizer.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tokenizers import Tokenizer + + +class JiebaBPETokenizer: + """SentencePiece BPE tokenizer with Jieba integration""" + + def __init__(self, tokenizer_json_file): + self.name = 'Jieba BPE Tokenizer' + + self.tokenizer = Tokenizer.from_file(tokenizer_json_file) + self.eod_id = self.tokenizer.token_to_id('<|endoftext|>') + try: + import jieba + except ImportError: + raise ImportError( + 'You need to install rjieba to use JiebaTokenizer. 
' + 'See https://pypi.org/project/rjieba/ for installation.') + self.jieba = jieba + self.new_line = self.vocab['\n'] + self.sep_token = self.vocab[''] + + @property + def vocab_size(self): + return self.tokenizer.get_vocab_size(with_added_tokens=True) + + @property + def vocab(self): + return self.tokenizer.get_vocab(with_added_tokens=True) + + @property + def inv_vocab(self): + vocab = self.vocab + inv_vocab = dict() + for key, val in vocab.items(): + inv_vocab[val] = key + return inv_vocab + + def tokenize(self, text, is_code=False): + """ + """ + if not is_code: + seg_list = [x for x in self.jieba.cut(text)] + return self.tokenizer.encode( + seg_list, is_pretokenized=True, add_special_tokens=True).ids + else: + return self.tokenizer.encode( + text, is_pretokenized=False, add_special_tokens=True).ids + + def detokenize(self, token_ids): + text = self.tokenizer.decode(token_ids, skip_special_tokens=False) + return text + + @property + def eod(self): + return self.eod_id diff --git a/modelscope/models/nlp/gpt_neo/__init__.py b/modelscope/models/nlp/gpt_neo/__init__.py new file mode 100644 index 00000000..ef5fdee5 --- /dev/null +++ b/modelscope/models/nlp/gpt_neo/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .backbone import GPTNeoModel +else: + _import_structure = { + 'backbone': ['GPTNeoModel'], + } + import sys + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/gpt_neo/backbone.py b/modelscope/models/nlp/gpt_neo/backbone.py new file mode 100644 index 00000000..a809bcde --- /dev/null +++ b/modelscope/models/nlp/gpt_neo/backbone.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from transformers import GPTNeoConfig +from transformers import GPTNeoModel as GPTNeoModelTransform + +from modelscope.metainfo import Models +from modelscope.models.builder import BACKBONES +from modelscope.utils.constant import Tasks + + +@BACKBONES.register_module( + group_key=Tasks.backbone, module_name=Models.gpt_neo) +class GPTNeoModel(GPTNeoModelTransform): + + def __init__(self, **kwargs): + config = GPTNeoConfig(**kwargs) + super().__init__(config) diff --git a/modelscope/models/nlp/heads/infromation_extraction_head.py b/modelscope/models/nlp/heads/infromation_extraction_head.py index 6c3388f0..626f1b59 100644 --- a/modelscope/models/nlp/heads/infromation_extraction_head.py +++ b/modelscope/models/nlp/heads/infromation_extraction_head.py @@ -10,6 +10,8 @@ from modelscope.utils.constant import Tasks @HEADS.register_module( Tasks.information_extraction, module_name=Heads.information_extraction) +@HEADS.register_module( + Tasks.relation_extraction, module_name=Heads.information_extraction) class InformationExtractionHead(TorchHead): def __init__(self, **kwargs): diff --git a/modelscope/models/nlp/heads/text_generation_head.py b/modelscope/models/nlp/heads/text_generation_head.py new file mode 100644 index 00000000..606d5a1f --- /dev/null +++ b/modelscope/models/nlp/heads/text_generation_head.py @@ -0,0 +1,35 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
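Stepping back to the JiebaBPETokenizer added above, a minimal round-trip sketch (the tokenizer.json path is a placeholder for the file shipped in the model directory):

from modelscope.models.nlp.gpt3 import JiebaBPETokenizer

tokenizer = JiebaBPETokenizer('tokenizer.json')   # placeholder path

ids = tokenizer.tokenize('今天天气不错')           # jieba pre-segmentation, then BPE ids
text = tokenizer.detokenize(ids)                  # decode back to a string
print(len(ids), tokenizer.vocab_size, tokenizer.eod)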
+from typing import Dict + +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.metainfo import Heads +from modelscope.models.base import TorchHead +from modelscope.models.builder import HEADS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + + +@HEADS.register_module( + Tasks.text_generation, module_name=Heads.text_generation) +class TextGenerationHead(TorchHead): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + config = self.config + self.linear = nn.Linear( + config['hidden_size'], config['vocab_size'], bias=False) + + def get_output_embeddings(self): + return self.linear + + def forward(self, inputs=None): + logits = self.linear(inputs) + return {OutputKeys.LOGITS: logits} + + def compute_loss(self, outputs: Dict[str, torch.Tensor], + labels) -> Dict[str, torch.Tensor]: + logits = outputs[OutputKeys.LOGITS] + return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} diff --git a/modelscope/models/nlp/heads/token_classification_head.py b/modelscope/models/nlp/heads/token_classification_head.py index ace3deac..443f93df 100644 --- a/modelscope/models/nlp/heads/token_classification_head.py +++ b/modelscope/models/nlp/heads/token_classification_head.py @@ -14,6 +14,8 @@ from modelscope.utils.constant import Tasks @HEADS.register_module( Tasks.token_classification, module_name=Heads.token_classification) +@HEADS.register_module( + Tasks.part_of_speech, module_name=Heads.token_classification) class TokenClassificationHead(TorchHead): def __init__(self, **kwargs): @@ -35,9 +37,9 @@ class TokenClassificationHead(TorchHead): sequence_output = inputs sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) - return {OutputKeys.LOGITS: logits} + return logits def compute_loss(self, outputs: Dict[str, torch.Tensor], labels) -> Dict[str, torch.Tensor]: logits = outputs[OutputKeys.LOGITS] - return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} + return F.cross_entropy(logits, labels) diff --git a/modelscope/models/nlp/masked_language.py b/modelscope/models/nlp/masked_language.py deleted file mode 100644 index b7a890c1..00000000 --- a/modelscope/models/nlp/masked_language.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from modelscope.metainfo import Models -from modelscope.models.base import TorchModel -from modelscope.models.builder import MODELS -from modelscope.models.nlp.bert import \ - BertForMaskedLM as BertForMaskedLMTransformer -from modelscope.models.nlp.deberta_v2 import \ - DebertaV2ForMaskedLM as DebertaV2ForMaskedLMTransformer -from modelscope.models.nlp.structbert import SbertForMaskedLM -from modelscope.models.nlp.veco import \ - VecoForMaskedLM as VecoForMaskedLMTransformer -from modelscope.outputs import OutputKeys -from modelscope.utils.constant import Tasks - -__all__ = ['BertForMaskedLM', 'StructBertForMaskedLM', 'VecoForMaskedLM'] - - -@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) -class StructBertForMaskedLM(TorchModel, SbertForMaskedLM): - """Structbert for MLM model. - - Inherited from structbert.SbertForMaskedLM and TorchModel, so this class can be registered into Model sets. 
- """ - - def __init__(self, config, model_dir): - super(TorchModel, self).__init__(model_dir) - SbertForMaskedLM.__init__(self, config) - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - labels=None): - output = SbertForMaskedLM.forward( - self, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - labels=labels) - output[OutputKeys.INPUT_IDS] = input_ids - return output - - @classmethod - def _instantiate(cls, **kwargs): - model_dir = kwargs.get('model_dir') - return super(SbertForMaskedLM, StructBertForMaskedLM).from_pretrained( - pretrained_model_name_or_path=model_dir, model_dir=model_dir) - - -@MODELS.register_module(Tasks.fill_mask, module_name=Models.bert) -class BertForMaskedLM(TorchModel, BertForMaskedLMTransformer): - """Bert for MLM model. - - Inherited from transformers.BertForMaskedLM and TorchModel, so this class can be registered into Model sets. - """ - - def __init__(self, config, model_dir): - super(TorchModel, self).__init__(model_dir) - BertForMaskedLMTransformer.__init__(self, config) - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - labels=None): - output = BertForMaskedLMTransformer.forward( - self, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - labels=labels) - output[OutputKeys.INPUT_IDS] = input_ids - return output - - @classmethod - def _instantiate(cls, **kwargs): - model_dir = kwargs.get('model_dir') - return super(BertForMaskedLMTransformer, - BertForMaskedLM).from_pretrained( - pretrained_model_name_or_path=model_dir, - model_dir=model_dir) - - -@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco) -class VecoForMaskedLM(TorchModel, VecoForMaskedLMTransformer): - """Veco for MLM model. - - Inherited from veco.VecoForMaskedLM and TorchModel, so this class can be registered into Model sets. - """ - - def __init__(self, config, model_dir): - super(TorchModel, self).__init__(model_dir) - VecoForMaskedLMTransformer.__init__(self, config) - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - labels=None): - output = VecoForMaskedLMTransformer.forward( - self, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - labels=labels) - output[OutputKeys.INPUT_IDS] = input_ids - return output - - @classmethod - def _instantiate(cls, **kwargs): - model_dir = kwargs.get('model_dir') - return super(VecoForMaskedLMTransformer, - VecoForMaskedLM).from_pretrained( - pretrained_model_name_or_path=model_dir, - model_dir=model_dir) - - -@MODELS.register_module(Tasks.fill_mask, module_name=Models.deberta_v2) -class DebertaV2ForMaskedLM(TorchModel, DebertaV2ForMaskedLMTransformer): - """Deberta v2 for MLM model. - - Inherited from deberta_v2.DebertaV2ForMaskedLM and TorchModel, so this class can be registered into Model sets. 
- """ - - def __init__(self, config, model_dir): - super(TorchModel, self).__init__(model_dir) - DebertaV2ForMaskedLMTransformer.__init__(self, config) - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - labels=None): - output = DebertaV2ForMaskedLMTransformer.forward( - self, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - labels=labels) - output[OutputKeys.INPUT_IDS] = input_ids - return output - - @classmethod - def _instantiate(cls, **kwargs): - model_dir = kwargs.get('model_dir') - return super(DebertaV2ForMaskedLMTransformer, - DebertaV2ForMaskedLM).from_pretrained( - pretrained_model_name_or_path=model_dir, - model_dir=model_dir) diff --git a/modelscope/models/nlp/palm_v2/__init__.py b/modelscope/models/nlp/palm_v2/__init__.py index 3a9960ec..45ab6621 100644 --- a/modelscope/models/nlp/palm_v2/__init__.py +++ b/modelscope/models/nlp/palm_v2/__init__.py @@ -17,19 +17,19 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .configuration_palm import PalmConfig - from .modeling_palm import ( + from .configuration import PalmConfig + from .backbone import ( AbsSummarizer, PalmForConditionalGeneration, Translator, ) - from .palm_for_text_generation import PalmForTextGeneration + from .text_generation import PalmForTextGeneration else: _import_structure = { - 'configuration_palm': ['PalmConfig'], - 'modeling_palm': + 'configuration': ['PalmConfig'], + 'backbone': ['AbsSummarizer', 'PalmForConditionalGeneration', 'Translator'], - 'palm_for_text_generation': ['PalmForTextGeneration'], + 'text_generation': ['PalmForTextGeneration'], } import sys diff --git a/modelscope/models/nlp/palm_v2/modeling_palm.py b/modelscope/models/nlp/palm_v2/backbone.py similarity index 99% rename from modelscope/models/nlp/palm_v2/modeling_palm.py rename to modelscope/models/nlp/palm_v2/backbone.py index f395ebd4..3e0ff805 100644 --- a/modelscope/models/nlp/palm_v2/modeling_palm.py +++ b/modelscope/models/nlp/palm_v2/backbone.py @@ -35,7 +35,7 @@ from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from modelscope.utils import logger as logging -from .configuration_palm import PalmConfig +from .configuration import PalmConfig from .dureader_eval import compute_bleu_rouge, normalize CONFIG_NAME = 'config.json' diff --git a/modelscope/models/nlp/palm_v2/configuration_palm.py b/modelscope/models/nlp/palm_v2/configuration.py similarity index 100% rename from modelscope/models/nlp/palm_v2/configuration_palm.py rename to modelscope/models/nlp/palm_v2/configuration.py diff --git a/modelscope/models/nlp/palm_v2/palm_for_text_generation.py b/modelscope/models/nlp/palm_v2/text_generation.py similarity index 100% rename from modelscope/models/nlp/palm_v2/palm_for_text_generation.py rename to modelscope/models/nlp/palm_v2/text_generation.py diff --git a/modelscope/models/nlp/passage_ranking.py b/modelscope/models/nlp/passage_ranking.py deleted file mode 100644 index 2a06ce45..00000000 --- a/modelscope/models/nlp/passage_ranking.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
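The palm_v2 renames above follow the same pattern as the gpt3 package earlier in this diff: module files become configuration.py / backbone.py / text_generation.py while the package __init__ keeps the public names, so existing imports resolve unchanged through LazyImportModule. A quick sketch of the imports this preserves:

# Public names are unchanged; only the defining modules moved.
from modelscope.models.nlp.palm_v2 import (PalmConfig,
                                           PalmForConditionalGeneration,
                                           PalmForTextGeneration)
from modelscope.models.nlp.gpt3 import GPT3Config, GPT3Model

print(PalmConfig.__module__)                     # ...palm_v2.configuration
print(PalmForConditionalGeneration.__module__)   # ...palm_v2.backbone
print(GPT3Model.__module__)                      # ...gpt3.backbone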
- -from typing import Any, Dict - -import numpy as np -import torch - -from modelscope.metainfo import Models -from modelscope.models import TorchModel -from modelscope.models.builder import MODELS -from modelscope.models.nlp import SbertForSequenceClassification -from modelscope.models.nlp.structbert import SbertPreTrainedModel -from modelscope.outputs import OutputKeys -from modelscope.utils.constant import Tasks - -__all__ = ['PassageRanking'] - - -@MODELS.register_module(Tasks.passage_ranking, module_name=Models.bert) -class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel): - base_model_prefix: str = 'bert' - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_missing = [r'position_ids'] - - def __init__(self, config, model_dir, *args, **kwargs): - if hasattr(config, 'base_model_prefix'): - PassageRanking.base_model_prefix = config.base_model_prefix - super().__init__(config, model_dir) - self.train_batch_size = kwargs.get('train_batch_size', 4) - self.register_buffer( - 'target_label', - torch.zeros(self.train_batch_size, dtype=torch.long)) - - def build_base_model(self): - from .structbert import SbertModel - return SbertModel(self.config, add_pooling_layer=True) - - def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: - outputs = self.base_model.forward(**input) - - # backbone model should return pooled_output as its second output - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - if self.base_model.training: - scores = logits.view(self.train_batch_size, -1) - loss_fct = torch.nn.CrossEntropyLoss() - loss = loss_fct(scores, self.target_label) - return {OutputKeys.LOGITS: logits, OutputKeys.LOSS: loss} - return {OutputKeys.LOGITS: logits} - - def sigmoid(self, logits): - return np.exp(logits) / (1 + np.exp(logits)) - - def postprocess(self, inputs: Dict[str, np.ndarray], - **kwargs) -> Dict[str, np.ndarray]: - logits = inputs['logits'].squeeze(-1).detach().cpu().numpy() - logits = self.sigmoid(logits).tolist() - result = {OutputKeys.SCORES: logits} - return result - - @classmethod - def _instantiate(cls, **kwargs): - """Instantiate the model. - - @param kwargs: Input args. - model_dir: The model dir used to load the checkpoint and the label information. - num_labels: An optional arg to tell the model how many classes to initialize. - Method will call utils.parse_label_mapping if num_labels not supplied. - If num_labels is not found, the model will use the default setting (1 classes). 
- @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained - """ - - num_labels = kwargs.get('num_labels', 1) - model_args = {} if num_labels is None else {'num_labels': num_labels} - - return super(SbertPreTrainedModel, PassageRanking).from_pretrained( - pretrained_model_name_or_path=kwargs.get('model_dir'), - model_dir=kwargs.get('model_dir'), - **model_args) diff --git a/modelscope/models/nlp/plug/__init__.py b/modelscope/models/nlp/plug/__init__.py index dbc20751..589a636a 100644 --- a/modelscope/models/nlp/plug/__init__.py +++ b/modelscope/models/nlp/plug/__init__.py @@ -4,13 +4,13 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .configuration_plug import PlugNLGConfig - from .modeling_plug import PlugModel + from .configuration import PlugNLGConfig + from .backbone import PlugModel from .distributed_plug import DistributedPlug else: _import_structure = { - 'configuration_plug': ['PlugNLGConfig'], - 'modeling_plug': ['PlugModel'], + 'configuration': ['PlugNLGConfig'], + 'backbone': ['PlugModel'], 'distributed_plug': ['DistributedPlug'], } diff --git a/modelscope/models/nlp/plug/modeling_plug.py b/modelscope/models/nlp/plug/backbone.py similarity index 84% rename from modelscope/models/nlp/plug/modeling_plug.py rename to modelscope/models/nlp/plug/backbone.py index 9d2bb14f..7f3f12de 100644 --- a/modelscope/models/nlp/plug/modeling_plug.py +++ b/modelscope/models/nlp/plug/backbone.py @@ -28,7 +28,7 @@ from torch import nn from modelscope.utils.nlp.distributed import (normal_init_method, scaled_init_method) -from .configuration_plug import PlugNLGConfig, PlugNLUConfig +from .configuration import PlugNLGConfig, PlugNLUConfig logger = logging.getLogger(__name__) @@ -152,15 +152,7 @@ class BertSelfOutput(nn.Module): bias=True, input_is_parallel=True, stride=1, - init_method=init_method, - pruning_method=config.pruning_method if config.pruning_module in [ - 'all', 'encoder', 'encoder_self', 'encoder_selfvo', - 'encoder_selfo' - ] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank) + init_method=init_method) self.fp32_layernorm = config.fp32_layernorm if not config.pre_ln: self.LayerNorm = BertLayerNorm( @@ -173,12 +165,8 @@ class BertSelfOutput(nn.Module): self, hidden_states, input_tensor, - pruning_threshold=None, ): - hidden_states = self.dense( - hidden_states, - pruning_threshold=pruning_threshold, - ) + hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) ln_input = hidden_states + input_tensor if self.LayerNorm is not None: @@ -210,20 +198,13 @@ class BertAttention(nn.Module): output_parallel=True, init_method=normal_init_method( mean=0.0, std=config.initializer_range), - separate=config.attn_separate, - pruning_method=config.pruning_method, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - pruning_module=config.pruning_module, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank) + separate=config.attn_separate) self.output = BertSelfOutput(config) def forward( self, input_tensor, attention_mask, - pruning_threshold=None, ): if self.LayerNorm is not None: ln_input = input_tensor @@ -236,20 +217,16 @@ class BertAttention(nn.Module): self_output = self.self( ln_output, attention_mask, - pruning_threshold=pruning_threshold, ) else: self_output = 
self.self( input_tensor, attention_mask, - pruning_threshold=pruning_threshold, ) - output_pruning_threshold = pruning_threshold attention_output = self.output( self_output, input_tensor, - pruning_threshold=output_pruning_threshold, ) return attention_output @@ -265,25 +242,15 @@ class BertIntermediate(nn.Module): gather_output=False, stride=1, init_method=normal_init_method( - mean=0.0, std=config.initializer_range), - pruning_method=config.pruning_method if config.pruning_module - in ['all', 'encoder', 'encoder_ffn'] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank) + mean=0.0, std=config.initializer_range)) self.intermediate_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) else config.hidden_act def forward( self, hidden_states, - pruning_threshold=None, ): - hidden_states = self.dense( - hidden_states, - pruning_threshold=pruning_threshold, - ) + hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states @@ -306,13 +273,7 @@ class BertOutput(nn.Module): bias=True, input_is_parallel=True, stride=1, - init_method=init_method, - pruning_method=config.pruning_method if config.pruning_module - in ['all', 'encoder', 'encoder_ffn'] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank) + init_method=init_method) self.fp32_layernorm = config.fp32_layernorm if not config.pre_ln: self.LayerNorm = BertLayerNorm( @@ -325,12 +286,8 @@ class BertOutput(nn.Module): self, hidden_states, input_tensor, - pruning_threshold=None, ): - hidden_states = self.dense( - hidden_states, - pruning_threshold=pruning_threshold, - ) + hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) ln_input = hidden_states + input_tensor if self.LayerNorm is not None: @@ -359,14 +316,8 @@ class BertLayer(nn.Module): else: self.LayerNorm = None - def forward( - self, - hidden_states, - attention_mask, - pruning_threshold=None, - ): - attention_output = self.attention( - hidden_states, attention_mask, pruning_threshold=pruning_threshold) + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) if self.LayerNorm is not None: ln_input = attention_output previous_type = attention_output.type() @@ -375,15 +326,10 @@ class BertLayer(nn.Module): ln_output = self.LayerNorm(ln_input) if self.fp32_layernorm: ln_output = ln_output.type(previous_type) - intermediate_output = self.intermediate( - ln_output, pruning_threshold=pruning_threshold) + intermediate_output = self.intermediate(ln_output) else: - intermediate_output = self.intermediate( - attention_output, pruning_threshold=pruning_threshold) - layer_output = self.output( - intermediate_output, - attention_output, - pruning_threshold=pruning_threshold) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) return layer_output @@ -407,7 +353,6 @@ class BertEncoder(nn.Module): output_all_encoded_layers=True, checkpoint_activations=False, detach_index=-1, - pruning_threshold=None, ): all_encoder_layers = [] @@ -417,8 +362,7 @@ class BertEncoder(nn.Module): layers = self.layer[start:end] x_ = inputs[0] for layer in layers: - x_ = layer( - x_, inputs[1], 
pruning_threshold=pruning_threshold) + x_ = layer(x_, inputs[1]) return x_ return custom_forward @@ -654,7 +598,6 @@ class BertModel(PreTrainedBertModel): output_all_encoded_layers=True, checkpoint_activations=False, detach_index=-1, - pruning_threshold=None, ): if attention_mask is None: attention_mask = torch.ones_like(input_ids) @@ -683,8 +626,7 @@ class BertModel(PreTrainedBertModel): extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers, checkpoint_activations=checkpoint_activations, - detach_index=detach_index, - pruning_threshold=pruning_threshold) + detach_index=detach_index) sequence_output = encoded_layers[-1] for p in self.pooler.parameters(): if p is None: @@ -709,18 +651,6 @@ class DecodeLayer(nn.Module): std=config.initializer_range, num_layers=config.num_hidden_layers) - self_pruning_method = config.pruning_method - cross_pruning_method = config.pruning_method - ffn_pruning_method = config.pruning_method - - if config.ft_module is not None: - if 'decoder_self' in config.ft_module: - self_pruning_method = 'finetune' - if 'decoder_cross' in config.ft_module: - cross_pruning_method = 'finetune' - if 'decoder_ffn' in config.ft_module: - ffn_pruning_method = 'finetune' - self.attention = mpu.GPT2ParallelSelfAttention( hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, @@ -728,13 +658,6 @@ class DecodeLayer(nn.Module): output_dropout_prob=config.hidden_dropout_prob, init_method=init_method, output_layer_init_method=output_layer_init_method, - pruning_method=self_pruning_method if config.pruning_module in [ - 'all', 'decoder', 'decoder_self', 'decoder_self+ffn' - ] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank, ) self.cross_attention = mpu.PalmParallelCrossAttention( @@ -745,12 +668,6 @@ class DecodeLayer(nn.Module): init_method=init_method, attn_separate=False, output_layer_init_method=output_layer_init_method, - pruning_method=cross_pruning_method, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - pruning_module=config.pruning_module, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank, ) self.input_layernorm = BertLayerNorm( @@ -765,12 +682,6 @@ class DecodeLayer(nn.Module): config.intermediate_size, gather_output=False, init_method=init_method, - pruning_method=ffn_pruning_method if config.pruning_module - in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank, ) self.intermediate_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) else config.hidden_act @@ -779,12 +690,6 @@ class DecodeLayer(nn.Module): config.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, - pruning_method=ffn_pruning_method if config.pruning_module - in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank, ) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) @@ -804,8 +709,7 @@ class DecodeLayer(nn.Module): enc_hidden_states, enc_attn_mask, dec_attn_mask, - is_infer=False, - pruning_threshold=None): + is_infer=False): residual = hidden_states 
previous_type = hidden_states.type() hidden_states = self.input_layernorm( @@ -813,10 +717,7 @@ class DecodeLayer(nn.Module): if self.fp32_layernorm: hidden_states = hidden_states.type(previous_type) hidden_states = self.attention( - hidden_states, - dec_attn_mask, - is_infer=is_infer, - pruning_threshold=pruning_threshold) + hidden_states, dec_attn_mask, is_infer=is_infer) hidden_states = residual + hidden_states @@ -825,23 +726,18 @@ class DecodeLayer(nn.Module): self.type_converter(hidden_states)) if self.fp32_layernorm: hidden_states = hidden_states.type(previous_type) - hidden_states = self.cross_attention( - hidden_states, - enc_hidden_states, - enc_attn_mask, - pruning_threshold=pruning_threshold) + hidden_states = self.cross_attention(hidden_states, enc_hidden_states, + enc_attn_mask) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_cross_attention_layernorm( self.type_converter(hidden_states)) if self.fp32_layernorm: hidden_states = hidden_states.type(previous_type) - hidden_states = self.intermediate( - hidden_states, pruning_threshold=pruning_threshold) + hidden_states = self.intermediate(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) - hidden_states = self.output( - hidden_states, pruning_threshold=pruning_threshold) + hidden_states = self.output(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = residual + hidden_states @@ -866,8 +762,7 @@ class BertDecoder(nn.Module): dec_attn_mask, checkpoint_activations=False, output_all_encoded_layers=False, - is_infer=False, - pruning_threshold=None): + is_infer=False): def custom(start, end): @@ -880,8 +775,7 @@ class BertDecoder(nn.Module): inputs[1], inputs[2], dec_attn_mask * 1, - is_infer=is_infer, - pruning_threshold=pruning_threshold) + is_infer=is_infer) return x_ return custom_forward @@ -904,8 +798,7 @@ class BertDecoder(nn.Module): enc_hidden_states, enc_attn_mask, dec_attn_mask, - is_infer=is_infer, - pruning_threshold=pruning_threshold) + is_infer=is_infer) previous_type = hidden_states.type() if self.fp32_layernorm: @@ -932,8 +825,7 @@ class DecodeModel(PreTrainedBertModel): enc_attn_mask=None, dec_attn_mask=None, checkpoint_activations=False, - is_infer=False, - pruning_threshold=None): + is_infer=False): extended_attention_mask = enc_attn_mask.unsqueeze(1).unsqueeze(2) extended_attention_mask = extended_attention_mask.to( dtype=next(self.decoder.parameters()).dtype) # fp16 compatibility @@ -946,8 +838,7 @@ class DecodeModel(PreTrainedBertModel): extended_attention_mask, dec_attn_mask, checkpoint_activations=False, - is_infer=is_infer, - pruning_threshold=pruning_threshold) + is_infer=is_infer) return sequence_output[-1] @@ -972,16 +863,14 @@ class PalmForPreTraining(PreTrainedBertModel): checkpoint_activations=False, is_infer=False, sequence_output=None, - parallel_output=True, - pruning_threshold=None): + parallel_output=True): if sequence_output is None: sequence_output, pooled_output = self.bert( input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, - checkpoint_activations=checkpoint_activations, - pruning_threshold=pruning_threshold) + checkpoint_activations=checkpoint_activations) prediction_scores, seq_relationship_score = self.cls( sequence_output, pooled_output) else: @@ -998,8 +887,7 @@ class PalmForPreTraining(PreTrainedBertModel): attention_mask, decode_attention_mask, checkpoint_activations=checkpoint_activations, - is_infer=is_infer, - pruning_threshold=pruning_threshold) + is_infer=is_infer) 
transformer_output_parallel = mpu.copy_to_model_parallel_region( decode_output) @@ -1017,6 +905,29 @@ class PlugModel(torch.nn.Module): + """ + The bare Plug Model transformer outputting raw hidden-states without any specific head on top. + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage + and behavior. + Parameters: + config ([`PlugNLGConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~DistributedPlug.initialize_model`] method to load the model weights. + Example: + + ```python + >>> # The PLUG model has 27B parameters and usually needs to run on multiple GPUs. The example given + >>> # here only initializes a slice of the model on a single GPU. + >>> # Check out the [`~DistributedPipeline.__init__`] method to initialize the entire PLUG model. + >>> from modelscope.models.nlp.plug import PlugNLGConfig, PlugModel + + >>> # Initializing a Plug configuration + >>> configuration = PlugNLGConfig() + + >>> # Initializing a model from the configuration + >>> model = PlugModel(configuration) + """ def __init__(self, config): super(PlugModel, self).__init__() @@ -1034,6 +945,58 @@ class PlugModel(torch.nn.Module): is_infer=False, sequence_output=None, parallel_output=True): + """ + Parameters: + input_tokens (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`): + `input_tokens_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. + Indices can be obtained using transformers [`BertTokenizer`]. See + [`TextGenerationPreprocessor.__call__`] for details. + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`, *optional*, defaults to + None): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + target_tokens (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None): + Target token ids (labels) for language modeling. Note that the labels **are shifted** inside the model, + i.e. you can set `target_tokens = input_tokens`. Indices are selected in + `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` are ignored (masked); the loss is only + computed for labels in `[0, ..., config.vocab_size]`. + + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None): + Indices of positions of each input sequence token in the position embeddings. Selected in the range + `[0, config.max_position_embeddings - 1]`. + + decode_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults + to None): + Mask to avoid performing attention on padding token indices of target tokens. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + checkpoint_activations (`boolean`, *optional*, defaults to `False`): + Whether gradient checkpointing is activated for this model or not. + is_infer (`boolean`, *optional*, defaults to `False`): + Whether or not to run the forward pass in inference (generation) mode. + sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, + defaults to None): + Also known as last_hidden_state: the sequence of hidden-states at the output of the last layer of the + model. A single forward() call can produce a single token; to generate the current token, the + sequence_output returned by the `forward()` call of the previous token is required. + parallel_output (`boolean`, *optional*, defaults to `True`): + Whether to return the model-parallel output as is, or to gather it across model-parallel partitions + before returning. + + + """ return self.model( input_tokens, token_type_ids, diff --git a/modelscope/models/nlp/plug/configuration_plug.py b/modelscope/models/nlp/plug/configuration.py similarity index 51% rename from modelscope/models/nlp/plug/configuration_plug.py rename to modelscope/models/nlp/plug/configuration.py index 64807392..c3a526a9 100644 --- a/modelscope/models/nlp/plug/configuration_plug.py +++ b/modelscope/models/nlp/plug/configuration.py @@ -40,8 +40,6 @@ class PlugNLUConfig(PretrainedConfig): max_position_embeddings=2048, type_vocab_size=3, initializer_range=0.00707, - deep_init=False, - deepspeed=False, lr_decay_style='linear', weight_decay=1e-2, clip_grad=1.0, @@ -53,20 +51,7 @@ class PlugNLUConfig(PretrainedConfig): fp32_tokentypes=False, layernorm_epsilon=1e-5, dec_hidden_layers=6, - pruning_method=None, - pruning_mask_init='constant', - pruning_mask_scale=0.0, - pruning_initial_threshold=1.0, - pruning_final_threshold=0.01, - pruning_initial_warmup=1, - pruning_final_warmup=20, - pruning_module='decoder', - pruning_decay_step=50, - pruning_decay_type='exp', - ft_module=None, attn_separate=False, - LR_weight_rank=8, - LR_mask_rank=8, **kwargs): super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) @@ -82,8 +67,6 @@ class PlugNLUConfig(PretrainedConfig): self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range - self.deep_init = deep_init - self.deepspeed = deepspeed self.lr_decay_style = lr_decay_style self.weight_decay = weight_decay self.clip_grad = clip_grad @@ -95,20 +78,7 @@ class PlugNLUConfig(PretrainedConfig): self.layernorm_epsilon = layernorm_epsilon self.fp32_tokentypes = fp32_tokentypes self.dec_hidden_layers = dec_hidden_layers - self.pruning_method = pruning_method - self.pruning_mask_init = pruning_mask_init - self.pruning_mask_scale = pruning_mask_scale - self.pruning_module = pruning_module - self.pruning_initial_threshold = pruning_initial_threshold - self.pruning_final_threshold = pruning_final_threshold - self.pruning_initial_warmup = pruning_initial_warmup - self.pruning_final_warmup = pruning_final_warmup - self.pruning_decay_step = pruning_decay_step - self.pruning_decay_type = pruning_decay_type - self.ft_module = ft_module self.attn_separate = attn_separate - self.LR_weight_rank = LR_weight_rank - self.LR_mask_rank = LR_mask_rank @classmethod def from_dict(cls, json_object): @@ -148,47 +118,115 @@ class PlugNLUConfig(PretrainedConfig): class PlugNLGConfig(PlugNLUConfig): + """ + This is the configuration class to store the configuration of a [`PlugModel`]. It is used to instantiate a + PLUG generation model according to the specified arguments, defining the model architecture. 
Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [PLUG](https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary) architecture. + + Configuration objects inherit from [`PlugNLUConfig`] and can be used to control the model outputs. Read the + documentation from [`PlugNLUConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 21504): + Padded vocabulary size of the PLUG model, used for vocabulary tensor parallelism. Defines the number of different tokens + that can be represented by the `input_ids` passed when calling [`PlugModel`]. + original_vocab_size (`int`, *optional*, defaults to 21128): + True vocabulary size of the PLUG model. Defines the number of different tokens that can be represented. + hidden_size (`int`, *optional*, defaults to 8192): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + dec_hidden_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 128): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 32768): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 3): + The vocabulary size of the `token_type_ids` passed when calling [`PlugModel`]. + initializer_range (`float`, *optional*, defaults to 0.00707): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + lr_decay_style (`str`, *optional*, defaults to 'linear'): + The decay style of the learning rate during fine-tuning. If string, `"linear"`, `"cosine"`, `"exponential"`, + `"constant"`, `"None"` are supported. + weight_decay (`float`, *optional*, defaults to 1e-2): + Decoupled weight decay to apply. + clip_grad (`float`, *optional*, defaults to 1.0): + Maximum gradient norm for gradient clipping. + warmup (`float`, *optional*, defaults to 0.01): + Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. + pre_ln (`boolean`, *optional*, defaults to `True`): + Whether or not to apply LayerNorm to the input instead of the output in the blocks. + fp16 (`boolean`, *optional*, defaults to `True`): + Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training. + fp32_layernorm (`boolean`, *optional*, defaults to `True`): + Whether to use fp32 32-bit precision LayerNorm training while the argument `fp16` is set to `True`. 
+ fp32_embedding (`boolean`, *optional*, defaults to `False`): + Whether to use fp32 32-bit precision Embedding training while the argument `fp16` is set to `True`. + fp32_tokentypes (`boolean`, *optional*, defaults to `False`): + Whether to use fp32 32-bit precision token types training while the argument `fp16` is set to `True`. + layernorm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + attn_separate (`boolean`, *optional*, defaults to `False`): + Whether or not to split the fused query-key-value projection into separate query, key and value + projections in the attention. + + Example: + + ```python + >>> # The PLUG model has 27B parameters and usually needs to run on multiple GPUs. The example given + >>> # here only initializes a slice of the model on a single GPU. + >>> # Check out the [`~DistributedPipeline.__init__`] method to initialize the entire PLUG model. + >>> from modelscope.models.nlp.plug import PlugNLGConfig, PlugModel + + >>> # Initializing a Plug configuration + >>> configuration = PlugNLGConfig() + + >>> # Initializing a model from the configuration + >>> model = PlugModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + model_type = 'plugNLG' def __init__(self, vocab_size=21504, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, + original_vocab_size=21128, + hidden_size=8192, + num_hidden_layers=24, + dec_hidden_layers=6, + num_attention_heads=128, + intermediate_size=32768, hidden_act='gelu', hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, + max_position_embeddings=2048, + type_vocab_size=3, initializer_range=0.00707, - deep_init=False, - deepspeed=False, lr_decay_style='linear', weight_decay=1e-2, clip_grad=1.0, warmup=0.01, - pre_ln=False, - fp16=False, - fp32_layernorm=False, + pre_ln=True, + fp16=True, + fp32_layernorm=True, fp32_embedding=False, fp32_tokentypes=False, - layernorm_epsilon=1e-12, - dec_hidden_layers=6, - pruning_method=None, - pruning_mask_init='constant', - pruning_mask_scale=0.0, - pruning_initial_threshold=1.0, - pruning_final_threshold=0.01, - pruning_initial_warmup=1, - pruning_final_warmup=20, - pruning_module='decoder', - pruning_decay_step=50, - pruning_decay_type='exp', - ft_module=None, + layernorm_epsilon=1e-5, attn_separate=False, - LR_weight_rank=8, - LR_mask_rank=8, **kwargs): super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) @@ -203,8 +241,6 @@ class PlugNLGConfig(PlugNLUConfig): self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range - self.deep_init = deep_init - self.deepspeed = deepspeed self.lr_decay_style = lr_decay_style self.weight_decay = weight_decay self.clip_grad = clip_grad @@ -216,17 +252,4 @@ class PlugNLGConfig(PlugNLUConfig): self.layernorm_epsilon = layernorm_epsilon self.fp32_tokentypes = fp32_tokentypes self.dec_hidden_layers = dec_hidden_layers - self.pruning_method = pruning_method - self.pruning_mask_init = pruning_mask_init - self.pruning_mask_scale = pruning_mask_scale - self.pruning_module = pruning_module - self.pruning_initial_threshold = pruning_initial_threshold - self.pruning_final_threshold = pruning_final_threshold - self.pruning_initial_warmup = pruning_initial_warmup - self.pruning_final_warmup = pruning_final_warmup - self.pruning_decay_step = pruning_decay_step - self.pruning_decay_type = pruning_decay_type - self.ft_module = ft_module 
self.attn_separate = attn_separate - self.LR_weight_rank = LR_weight_rank - self.LR_mask_rank = LR_mask_rank diff --git a/modelscope/models/nlp/plug/distributed_plug.py b/modelscope/models/nlp/plug/distributed_plug.py index 2992f595..c72e92ba 100644 --- a/modelscope/models/nlp/plug/distributed_plug.py +++ b/modelscope/models/nlp/plug/distributed_plug.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os from typing import Dict @@ -14,12 +15,54 @@ from modelscope.utils.nlp.distributed import initialize_distributed from modelscope.utils.nlp.load_checkpoint import pre_load from modelscope.utils.torch_utils import set_random_seed_mpu from . import PlugModel -from .configuration_plug import PlugNLGConfig +from .configuration import PlugNLGConfig logger = get_logger(__name__) class DistributedPlug(TorchModel): + """ + The wrapper class of the PLUG model, which initializes the parallel environment, loads model weights and + generates sentences. + Parameters: + model_dir (`str`, *required*): + Path to model damo/nlp_plug_text-generation_27B. + The model structure in model_dir should be like this: + model_dir + |_ config.json + |_ configuration.json + |_ ds_zero-offload_10B_config.json + |_ vocab.txt + |_ model <-- an empty directory + + Model binaries shall be downloaded separately to populate the model directory, so that + the model directory would contain the following binaries: + |_ model + |_ mp_rank_00_model_states.pt + |_ mp_rank_01_model_states.pt + |_ mp_rank_02_model_states.pt + |_ mp_rank_03_model_states.pt + |_ mp_rank_04_model_states.pt + |_ mp_rank_05_model_states.pt + |_ mp_rank_06_model_states.pt + |_ mp_rank_07_model_states.pt + rank (`int`, *required*): + Used to identify different GPUs in a tensor parallel environment, e.g. the rank of GPU #0 is 0, and the + model file `mp_rank_00_model_states.pt` will be loaded on this GPU. + world_size (`int`, *required*, defaults to 8): + The parallel size in total. + model_parallel_size (`int`, *required*, defaults to 8): + The model (tensor) parallel size. + master_ip (`str`, *required*): + The master IP, can usually be set to `"127.0.0.1"`, used as part of + [`~torch.distributed.init_process_group`] method parameter `init_method`. + `init_method` = `"tcp://{master_ip}:{master_port}"` + master_port (`str`, *required*): + The master port, can usually be set to `"29500"`, used as part of + [`~torch.distributed.init_process_group`] method parameter `init_method`. + `init_method` = `"tcp://{master_ip}:{master_port}"` + seed (`int`, *optional*, defaults to 42): + Random seed to control sampling. 
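A minimal launch sketch for the DistributedPlug parameters documented above, assuming eight local GPUs; it is a hedged illustration rather than part of the patch, the model path is a placeholder, and real deployments normally go through the ModelScope text-generation pipeline instead of constructing the class directly.

```python
# Hedged sketch only: one worker process per GPU, eight-way tensor parallelism.
# model_dir is a placeholder; master_ip/master_port follow the docstring examples.
import torch.multiprocessing as mp

from modelscope.models.nlp.plug import DistributedPlug


def worker(rank):
    # Each rank loads its own weight slice, e.g. rank 0 loads model/mp_rank_00_model_states.pt.
    DistributedPlug(
        model_dir='/path/to/damo/nlp_plug_text-generation_27B',
        rank=rank,
        world_size=8,
        model_parallel_size=8,
        master_ip='127.0.0.1',
        master_port='29500',
        seed=42)


if __name__ == '__main__':
    mp.spawn(worker, nprocs=8)
```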
+ """ def __init__(self, model_dir, rank, **kwargs): super().__init__(model_dir, **kwargs) @@ -29,7 +72,7 @@ class DistributedPlug(TorchModel): initialize_distributed(rank, mpu, kwargs['world_size'], kwargs['model_parallel_size'], kwargs['master_ip'], kwargs['master_port']) - seed = 0 if 'seed' not in kwargs else kwargs['seed'] + seed = 42 if 'seed' not in kwargs else kwargs['seed'] set_random_seed_mpu(seed) self.iteration = 0 self.dist_model = self.initialize_model(path_load_tag='model') diff --git a/modelscope/models/nlp/ponet/__init__.py b/modelscope/models/nlp/ponet/__init__.py index 6d26b194..df996167 100644 --- a/modelscope/models/nlp/ponet/__init__.py +++ b/modelscope/models/nlp/ponet/__init__.py @@ -18,16 +18,16 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .configuration_ponet import PoNetConfig - from .modeling_ponet import (PoNetForMaskedLM, PoNetModel, - PoNetPreTrainedModel) - from .tokenization_ponet import PoNetTokenizer + from .configuration import PoNetConfig + from .backbone import (PoNetModel, PoNetPreTrainedModel) + from .tokenization import PoNetTokenizer + from .fill_mask import PoNetForMaskedLM else: _import_structure = { - 'configuration_ponet': ['PoNetConfig'], - 'modeling_ponet': - ['PoNetForMaskedLM', 'PoNetModel', 'PoNetPreTrainedModel'], - 'tokenization_ponet': ['PoNetTokenizer'], + 'configuration': ['PoNetConfig'], + 'backbone': ['PoNetModel', 'PoNetPreTrainedModel'], + 'fill_mask': ['PoNetForMaskedLM'], + 'tokenization': ['PoNetTokenizer'], } import sys diff --git a/modelscope/models/nlp/ponet/modeling_ponet.py b/modelscope/models/nlp/ponet/backbone.py similarity index 55% rename from modelscope/models/nlp/ponet/modeling_ponet.py rename to modelscope/models/nlp/ponet/backbone.py index f37954db..fcc62fa2 100644 --- a/modelscope/models/nlp/ponet/modeling_ponet.py +++ b/modelscope/models/nlp/ponet/backbone.py @@ -16,43 +16,32 @@ """PyTorch PoNet model. 
""" import math -from dataclasses import dataclass from distutils.version import LooseVersion -from typing import Optional, Tuple import torch import torch.utils.checkpoint from packaging import version from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN -from transformers.file_utils import (ModelOutput, add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings) -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, MaskedLMOutput, - SequenceClassifierOutput, TokenClassifierOutput) +from transformers.modeling_outputs import \ + BaseModelOutputWithPastAndCrossAttentions from transformers.modeling_utils import (PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer) -from transformers.models.bert.modeling_bert import \ - load_tf_weights_in_bert as load_tf_weights_in_ponet +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionBackboneModelOutput +from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger -from .configuration_ponet import PoNetConfig +from .configuration import PoNetConfig logger = get_logger(__name__) is_pytorch_12plus = LooseVersion(torch.__version__) >= LooseVersion('1.12.0') -_CHECKPOINT_FOR_DOC = 'ponet-base-uncased' -_CONFIG_FOR_DOC = 'PoNetConfig' -_TOKENIZER_FOR_DOC = 'PoNetTokenizer' - CLS_ID = 101 EOS_ID = 102 @@ -609,82 +598,20 @@ class PoNetPooler(nn.Module): return pooled_output -class PoNetPredictionHeadTransform(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class PoNetLMPredictionHead(nn.Module): - - def __init__(self, config): - super().__init__() - self.transform = PoNetPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. 
- self.decoder = nn.Linear( - config.hidden_size, config.vocab_size, bias=False) - - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class PoNetOnlyMLMHead(nn.Module): - - def __init__(self, config): - super().__init__() - self.predictions = PoNetLMPredictionHead(config) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class PoNetPreTrainingHeads(nn.Module): - - def __init__(self, config): - super().__init__() - self.predictions = PoNetLMPredictionHead(config) - self.seq_relationship = nn.Linear(config.hidden_size, 3) - - def forward(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class PoNetPreTrainedModel(PreTrainedModel): +class PoNetPreTrainedModel(TorchModel, PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = PoNetConfig - load_tf_weights = load_tf_weights_in_ponet base_model_prefix = 'ponet' _keys_to_ignore_on_load_missing = [r'position_ids'] + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): @@ -703,51 +630,22 @@ class PoNetPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - -@dataclass -class PoNetForPreTrainingOutput(ModelOutput): - """ - Output type of :class:`~transformers.PoNetForPreTraining`. - - Args: - loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): - Total loss as the sum of the masked language modeling loss and the next sequence prediction - (classification) loss. - mlm_loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): - Masked language modeling loss. - sop_loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): - sop loss. - prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states - (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed - or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
- attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed - or when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - mlm_loss: Optional[torch.FloatTensor] = None - sop_loss: Optional[torch.FloatTensor] = None - prediction_logits: torch.FloatTensor = None - seq_relationship_logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None + @classmethod + def _instantiate(cls, **kwargs): + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + ponet_config = PoNetConfig(**kwargs) + model = cls(ponet_config) + else: + model = super( + Model, + cls).from_pretrained(pretrained_model_name_or_path=model_dir) + return model -PONET_START_DOCSTRING = r""" +@MODELS.register_module(Tasks.backbone, module_name=Models.ponet) +class PoNetModel(PoNetPreTrainedModel): + """The bare PoNet Model transformer outputting raw hidden-states without any specific head on top. This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, @@ -763,65 +661,6 @@ PONET_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. -""" - -PONET_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using :class:`~modelscope.models.nlp.ponet.PoNetTokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for - details. - - `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): - Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, - 1]``: - - - 0 corresponds to a `sentence A` token, - - 1 corresponds to a `sentence B` token. - - `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, - config.max_position_embeddings - 1]``. - - `What are position IDs? <../glossary.html#position-ids>`_ - head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. 
- - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert :obj:`input_ids` indices into associated - vectors than the model's internal embedding lookup matrix. - output_attentions (:obj:`bool`, `optional`): - Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned - tensors for more detail. - output_hidden_states (:obj:`bool`, `optional`): - Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for - more detail. - return_dict (:obj:`bool`, `optional`): - Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. -""" - - -@add_start_docstrings( - 'The bare PoNet Model transformer outputting raw hidden-states without any specific head on top.', - PONET_START_DOCSTRING, -) -class PoNetModel(PoNetPreTrainedModel): - """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in `Attention is @@ -834,8 +673,8 @@ class PoNetModel(PoNetPreTrainedModel): input to the forward pass. """ - def __init__(self, config, add_pooling_layer=True): - super().__init__(config) + def __init__(self, config, add_pooling_layer=True, **kwargs): + super().__init__(config, **kwargs) self.config = config self.embeddings = PoNetEmbeddings(config) @@ -859,14 +698,6 @@ class PoNetModel(PoNetPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - @add_start_docstrings_to_model_forward( - PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPoolingAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) def forward( self, input_ids=None, @@ -885,6 +716,49 @@ class PoNetModel(PoNetPreTrainedModel): return_dict=None, ): r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.ponet.PoNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. 
+ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -906,6 +780,16 @@ class PoNetModel(PoNetPreTrainedModel): use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). + + Returns: + Returns `modelscope.outputs.AttentionBackboneModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base', task='backbone') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base') + >>> print(model(**preprocessor('这是个测试'))) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1006,7 +890,7 @@ class PoNetModel(PoNetPreTrainedModel): if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPoolingAndCrossAttentions( + return AttentionBackboneModelOutput( last_hidden_state=sequence_output, pooler_output=pooled_output, past_key_values=encoder_outputs.past_key_values, @@ -1014,578 +898,3 @@ class PoNetModel(PoNetPreTrainedModel): attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, ) - - -@add_start_docstrings( - """ - PoNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next - sentence prediction (classification)` head. 
- """, - PONET_START_DOCSTRING, -) -class PoNetForPreTraining(PoNetPreTrainedModel): - - def __init__(self, config): - super().__init__(config) - - self.ponet = PoNetModel(config) - self.cls = PoNetPreTrainingHeads(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward( - PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @replace_return_docstrings( - output_type=PoNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - segment_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - next_sentence_label=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): - Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., - config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored - (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair - (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: - - - 0 indicates sequence B is a continuation of sequence A, - - 1 indicates sequence B is a random sequence. - kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): - Used to hide legacy arguments that have been deprecated. 
- - Returns: - - Example:: - - >>> from transformers import PoNetTokenizer, PoNetForPreTraining - >>> import torch - - >>> tokenizer = PoNetTokenizer.from_pretrained('ponet-base-uncased') - >>> model = PoNetForPreTraining.from_pretrained('ponet-base-uncased') - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.prediction_logits - >>> seq_relationship_logits = outputs.seq_relationship_logits - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.ponet( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - segment_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls( - sequence_output, pooled_output) - - total_loss = None - masked_lm_loss = None - next_sentence_loss = None - if labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1)) - next_sentence_loss = loss_fct( - seq_relationship_score.view(-1, 3), - next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss, masked_lm_loss, next_sentence_loss) - + output) if total_loss is not None else output - - return PoNetForPreTrainingOutput( - loss=total_loss, - mlm_loss=masked_lm_loss, - sop_loss=next_sentence_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """PoNet Model with a `language modeling` head on top for CLM fine-tuning. 
""", - PONET_START_DOCSTRING) -class PoNetLMHeadModel(PoNetPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r'pooler'] - _keys_to_ignore_on_load_missing = [ - r'position_ids', r'predictions.decoder.bias' - ] - - def __init__(self, config): - super().__init__(config) - - if not config.is_decoder: - logger.warning( - 'If you want to use `PoNetLMHeadModel` as a standalone, add `is_decoder=True.`' - ) - - self.ponet = PoNetModel(config, add_pooling_layer=False) - self.cls = PoNetOnlyMLMHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward( - PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @replace_return_docstrings( - output_type=CausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - segment_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj: - `(batch_size, sequence_length, hidden_size)`, `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are - ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` - with each tuple having 4 tensors of shape : - obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). 
- - Returns: - - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - - outputs = self.ponet( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - segment_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, : - -1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct( - shifted_prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1)) - - if not return_dict: - output = (prediction_scores, ) + outputs[2:] - return ((lm_loss, ) + output) if lm_loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def prepare_inputs_for_generation(self, - input_ids, - past=None, - attention_mask=None, - **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # cut decoder_input_ids if past is used - if past is not None: - input_ids = input_ids[:, -1:] - - return { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'past_key_values': past - } - - def _reorder_cache(self, past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += (tuple( - past_state.index_select(0, beam_idx) - for past_state in layer_past), ) - return reordered_past - - -@add_start_docstrings( - """PoNet Model with a `language modeling` head on top. 
""", - PONET_START_DOCSTRING) -class PoNetForMaskedLM(PoNetPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r'pooler'] - _keys_to_ignore_on_load_missing = [ - r'position_ids', r'predictions.decoder.bias' - ] - - def __init__(self, config): - super().__init__(config) - - if config.is_decoder: - logger.warning( - 'If you want to use `PoNetForMaskedLM` make sure `config.is_decoder=False` for ' - 'bi-directional self-attention.') - - self.ponet = PoNetModel(config, add_pooling_layer=False) - self.cls = PoNetOnlyMLMHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward( - PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=MaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - segment_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., - config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored - (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.ponet( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - segment_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - masked_lm_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1)) - - if not return_dict: - output = (prediction_scores, ) + outputs[2:] - return ((masked_lm_loss, ) - + output) if masked_lm_loss is not None else output - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - PoNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled - output) e.g. for GLUE tasks. 
- """, - PONET_START_DOCSTRING, -) -class PoNetForSequenceClassification(PoNetPreTrainedModel): - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.ponet = PoNetModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - self.init_weights() - - @add_start_docstrings_to_model_forward( - PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=SequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - segment_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., - config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), - If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.ponet( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - segment_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = 'regression' - elif self.num_labels > 1 and (labels.dtype == torch.long - or labels.dtype == torch.int): - self.config.problem_type = 'single_label_classification' - else: - self.config.problem_type = 'multi_label_classification' - - if self.config.problem_type == 'regression': - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == 'single_label_classification': - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == 'multi_label_classification': - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - if not return_dict: - output = (logits, ) + outputs[2:] - return ((loss, ) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - PoNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. 
- """, - PONET_START_DOCSTRING, -) -class PoNetForTokenClassification(PoNetPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r'pooler'] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.ponet = PoNetModel(config, add_pooling_layer=False) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - self.init_weights() - - @add_start_docstrings_to_model_forward( - PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - segment_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - - 1]``. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.ponet( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - segment_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels) - active_labels = torch.where( - active_loss, labels.view(-1), - torch.tensor(loss_fct.ignore_index).type_as(labels)) - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct( - logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits, ) + outputs[2:] - return ((loss, ) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/modelscope/models/nlp/ponet/configuration_ponet.py b/modelscope/models/nlp/ponet/configuration.py similarity index 96% rename from modelscope/models/nlp/ponet/configuration_ponet.py rename to modelscope/models/nlp/ponet/configuration.py index 70294fc2..7dfaba48 100644 --- a/modelscope/models/nlp/ponet/configuration_ponet.py +++ b/modelscope/models/nlp/ponet/configuration.py @@ -34,8 +34,7 @@ class PoNetConfig(PretrainedConfig): Args: vocab_size (:obj:`int`, `optional`, defaults to 30522): Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the - :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or - :class:`~transformers.TFBertModel`. + :obj:`inputs_ids` passed. hidden_size (:obj:`int`, `optional`, defaults to 768): Dimensionality of the encoder layers and the pooler layer. 
num_hidden_layers (:obj:`int`, `optional`, defaults to 12): @@ -55,8 +54,7 @@ class PoNetConfig(PretrainedConfig): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). type_vocab_size (:obj:`int`, `optional`, defaults to 2): - The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or - :class:`~transformers.TFBertModel`. + The vocabulary size of the :obj:`token_type_ids` passed. initializer_range (:obj:`float`, `optional`, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): diff --git a/modelscope/models/nlp/ponet/fill_mask.py b/modelscope/models/nlp/ponet/fill_mask.py new file mode 100644 index 00000000..fb09efc0 --- /dev/null +++ b/modelscope/models/nlp/ponet/fill_mask.py @@ -0,0 +1,252 @@ +# Copyright 2021-2022 The Alibaba DAMO Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionFillMaskModelOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from .backbone import PoNetModel, PoNetPreTrainedModel + +logger = get_logger(__name__) + + +class PoNetPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class PoNetLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = PoNetPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
+ self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class PoNetOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = PoNetLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@MODELS.register_module(Tasks.fill_mask, module_name=Models.ponet) +class PoNetForMaskedLM(PoNetPreTrainedModel): + r"""PoNet Model with a `language modeling` head on top. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the fill_mask model of PoNet, the preprocessor of this model + is `modelscope.preprocessors.FillMaskPoNetPreprocessor`. + + Parameters: + config (:class:`~modelscope.models.nlp.ponet.PoNetConfig`): + Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + """ + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config, **kwargs): + super().__init__(config) + + if config.is_decoder: + logger.warning( + 'If you want to use `PoNetForMaskedLM` make sure `config.is_decoder=False` for ' + 'bi-directional self-attention.') + + self.ponet = PoNetModel(config, add_pooling_layer=False) + self.cls = PoNetOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + segment_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`('batch_size, sequence_length')`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.ponet.PoNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`('batch_size, sequence_length')`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`('batch_size, sequence_length')`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`('batch_size, sequence_length')`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`('batch_size, sequence_length', hidden_size)`, + `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + + Returns: + Returns `modelscope.outputs.AttentionFillMaskModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base') + >>> # Call the model, return some tensors + >>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。'))) + >>> # Call the pipeline + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) + >>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。')) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ponet( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + segment_ids=segment_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + return AttentionFillMaskModelOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + input_ids=input_ids, + ) diff --git a/modelscope/models/nlp/ponet/tokenization_ponet.py b/modelscope/models/nlp/ponet/tokenization.py similarity index 98% rename from modelscope/models/nlp/ponet/tokenization_ponet.py rename to modelscope/models/nlp/ponet/tokenization.py index 21544886..2da91545 100644 --- a/modelscope/models/nlp/ponet/tokenization_ponet.py +++ b/modelscope/models/nlp/ponet/tokenization.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from transformers.file_utils import PaddingStrategy from transformers.models.bert.tokenization_bert import BertTokenizer +from transformers.tokenization_utils import BatchEncoding, EncodedInput from modelscope.utils.constant import ModelFile from modelscope.utils.logger import get_logger diff --git a/modelscope/models/nlp/ponet_for_masked_language.py b/modelscope/models/nlp/ponet_for_masked_language.py deleted file mode 100644 index 11f4bc11..00000000 --- a/modelscope/models/nlp/ponet_for_masked_language.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
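# --- Editorial sketch (illustrative, not part of the diff) ---
# The masked-LM loss in the new fill_mask.py above relies on the default
# ignore_index=-100 of torch.nn.CrossEntropyLoss: label positions set to -100
# (every token that was not masked) contribute nothing to the loss, so only the
# [MASK] positions are scored. The toy shapes and values below are invented for
# illustration and do not come from any PoNet checkpoint.
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 8
prediction_scores = torch.randn(1, 4, vocab_size)  # (batch, seq_len, vocab_size)
labels = torch.tensor([[-100, 5, -100, -100]])     # only position 1 was masked
loss = CrossEntropyLoss()(  # ignore_index defaults to -100
    prediction_scores.view(-1, vocab_size), labels.view(-1))
# loss is the cross entropy at position 1 only; the -100 positions are skipped.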
- -from typing import Any, Dict - -from modelscope.metainfo import Models -from modelscope.models.base import TorchModel -from modelscope.models.builder import MODELS -from modelscope.models.nlp.ponet import \ - PoNetForMaskedLM as PoNetForMaskedLMTransformer -from modelscope.outputs import OutputKeys -from modelscope.utils.constant import Tasks - -__all__ = ['PoNetForMaskedLM'] - - -@MODELS.register_module(Tasks.fill_mask, module_name=Models.ponet) -class PoNetForMaskedLM(TorchModel, PoNetForMaskedLMTransformer): - """PoNet for MLM model.'. - - Inherited from ponet.PoNetForMaskedLM and TorchModel, so this class can be registered into Model sets. - """ - - def __init__(self, config, model_dir): - super(TorchModel, self).__init__(model_dir) - PoNetForMaskedLMTransformer.__init__(self, config) - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - segment_ids=None, - position_ids=None, - head_mask=None, - labels=None): - output = PoNetForMaskedLMTransformer.forward( - self, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - segment_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask, - labels=labels) - output[OutputKeys.INPUT_IDS] = input_ids - return output - - @classmethod - def _instantiate(cls, **kwargs): - model_dir = kwargs.get('model_dir') - return super(PoNetForMaskedLMTransformer, - PoNetForMaskedLM).from_pretrained( - pretrained_model_name_or_path=model_dir, - model_dir=model_dir) diff --git a/modelscope/models/nlp/sentence_embedding.py b/modelscope/models/nlp/sentence_embedding.py deleted file mode 100644 index 340c133f..00000000 --- a/modelscope/models/nlp/sentence_embedding.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
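# --- Editorial sketch (illustrative, not part of the diff) ---
# The wrapper deleted above existed mainly to register PoNet's MLM head for the
# fill-mask task; after this change the registration decorator sits directly on
# the transformers-style class in fill_mask.py. A much simplified picture of
# what such a task-keyed registry does (the real logic lives in
# modelscope.models.builder.MODELS; everything below is a stand-in):
from collections import defaultdict

MODEL_REGISTRY = defaultdict(dict)  # task name -> {module name: model class}

def register_module(task, module_name):
    def decorator(cls):
        MODEL_REGISTRY[task][module_name] = cls
        return cls
    return decorator

@register_module('fill-mask', 'ponet')
class DummyPoNetForMaskedLM:  # stand-in for the real PoNetForMaskedLM
    pass

assert MODEL_REGISTRY['fill-mask']['ponet'] is DummyPoNetForMaskedLM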
- -from typing import Any, Dict - -import numpy as np - -from modelscope.metainfo import Models -from modelscope.models import TorchModel -from modelscope.models.builder import MODELS -from modelscope.models.nlp.structbert import SbertPreTrainedModel -from modelscope.utils.constant import Tasks - -__all__ = ['SentenceEmbedding'] - - -@MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert) -class SentenceEmbedding(TorchModel, SbertPreTrainedModel): - base_model_prefix: str = 'bert' - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_missing = [r'position_ids'] - - def __init__(self, config, model_dir): - super().__init__(model_dir) - self.config = config - setattr(self, self.base_model_prefix, self.build_base_model()) - - def build_base_model(self): - from .structbert import SbertModel - return SbertModel(self.config, add_pooling_layer=False) - - def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: - """return the result by the model - - Args: - input (Dict[str, Any]): the preprocessed data - - Returns: - Dict[str, np.ndarray]: results - Example: - { - 'predictions': array([1]), # lable 0-negative 1-positive - 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), - 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value - } - """ - return self.base_model(**input) - - def postprocess(self, inputs: Dict[str, np.ndarray], - **kwargs) -> Dict[str, np.ndarray]: - embs = inputs['last_hidden_state'][:, 0].cpu().numpy() - num_sent = embs.shape[0] - if num_sent >= 2: - scores = np.dot(embs[0:1, ], np.transpose(embs[1:, ], - (1, 0))).tolist()[0] - else: - scores = [] - result = {'text_embedding': embs, 'scores': scores} - - return result - - @classmethod - def _instantiate(cls, **kwargs): - """Instantiate the model. - - @param kwargs: Input args. - model_dir: The model dir used to load the checkpoint and the label information. - @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained - """ - model_args = {} - - return super(SbertPreTrainedModel, SentenceEmbedding).from_pretrained( - pretrained_model_name_or_path=kwargs.get('model_dir'), - model_dir=kwargs.get('model_dir'), - **model_args) diff --git a/modelscope/models/nlp/sequence_classification.py b/modelscope/models/nlp/sequence_classification.py deleted file mode 100644 index 156c615c..00000000 --- a/modelscope/models/nlp/sequence_classification.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -from abc import abstractmethod - -from torch import nn - -from modelscope.metainfo import Models -from modelscope.models.base import TorchModel -from modelscope.models.builder import MODELS -from modelscope.models.nlp.bert import BertPreTrainedModel -from modelscope.models.nlp.structbert import SbertPreTrainedModel -from modelscope.models.nlp.veco import \ - VecoForSequenceClassification as VecoForSequenceClassificationTransform -from modelscope.outputs import OutputKeys -from modelscope.utils.constant import Tasks -from modelscope.utils.hub import parse_label_mapping -from modelscope.utils.tensor_utils import (torch_nested_detach, - torch_nested_numpify) - -__all__ = [ - 'SbertForSequenceClassification', 'VecoForSequenceClassification', - 'BertForSequenceClassification' -] - - -class SequenceClassificationBase(TorchModel): - """A sequence classification base class for all the fitted sequence classification models. 
- """ - base_model_prefix: str = 'bert' - - def __init__(self, config, model_dir): - super().__init__(model_dir) - self.num_labels = config.num_labels - self.config = config - setattr(self, self.base_model_prefix, self.build_base_model()) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - @abstractmethod - def build_base_model(self): - """Build the backbone model. - - Returns: the backbone instance. - """ - pass - - @property - def base_model(self): - return getattr(self, self.base_model_prefix) - - def forward(self, **kwargs): - labels = None - if OutputKeys.LABEL in kwargs: - labels = kwargs.pop(OutputKeys.LABEL) - elif OutputKeys.LABELS in kwargs: - labels = kwargs.pop(OutputKeys.LABELS) - - outputs = self.base_model.forward(**kwargs) - - # backbone model should return pooled_output as its second output - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - if labels is not None: - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return {OutputKeys.LOGITS: logits, OutputKeys.LOSS: loss} - return {OutputKeys.LOGITS: logits} - - def postprocess(self, input, **kwargs): - logits = input[OutputKeys.LOGITS] - probs = torch_nested_numpify(torch_nested_detach(logits.softmax(-1))) - pred = torch_nested_numpify(torch_nested_detach(logits.argmax(-1))) - logits = torch_nested_numpify(torch_nested_detach(logits)) - res = { - OutputKeys.PREDICTIONS: pred, - OutputKeys.PROBABILITIES: probs, - OutputKeys.LOGITS: logits - } - return res - - -@MODELS.register_module( - Tasks.sentence_similarity, module_name=Models.structbert) -@MODELS.register_module( - Tasks.sentiment_classification, module_name=Models.structbert) -@MODELS.register_module(Tasks.nli, module_name=Models.structbert) -@MODELS.register_module( - Tasks.zero_shot_classification, module_name=Models.structbert) -class SbertForSequenceClassification(SequenceClassificationBase, - SbertPreTrainedModel): - """Sbert sequence classification model. - - Inherited from SequenceClassificationBase. - """ - base_model_prefix: str = 'bert' - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_missing = [r'position_ids'] - - def __init__(self, config, model_dir): - if hasattr(config, 'base_model_prefix'): - SbertForSequenceClassification.base_model_prefix = config.base_model_prefix - super().__init__(config, model_dir) - - def build_base_model(self): - from .structbert import SbertModel - return SbertModel(self.config, add_pooling_layer=True) - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - labels=None, - **kwargs): - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - labels=labels) - - @classmethod - def _instantiate(cls, **kwargs): - """Instantiate the model. - - @param kwargs: Input args. - model_dir: The model dir used to load the checkpoint and the label information. - num_labels: An optional arg to tell the model how many classes to initialize. - Method will call utils.parse_label_mapping if num_labels not supplied. - If num_labels is not found, the model will use the default setting (2 classes). 
- @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained - """ - - model_dir = kwargs.get('model_dir') - num_labels = kwargs.get('num_labels') - if num_labels is None: - label2id = parse_label_mapping(model_dir) - if label2id is not None and len(label2id) > 0: - num_labels = len(label2id) - cls.id2label = {id: label for label, id in label2id.items()} - model_args = {} if num_labels is None else {'num_labels': num_labels} - return super(SbertPreTrainedModel, - SbertForSequenceClassification).from_pretrained( - pretrained_model_name_or_path=kwargs.get('model_dir'), - model_dir=kwargs.get('model_dir'), - **model_args) - - -@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.veco) -@MODELS.register_module( - Tasks.sentiment_classification, module_name=Models.veco) -@MODELS.register_module(Tasks.nli, module_name=Models.veco) -class VecoForSequenceClassification(TorchModel, - VecoForSequenceClassificationTransform): - """Veco sequence classification model. - - Inherited from VecoForSequenceClassification and TorchModel, so this class can be registered into the model set. - This model cannot be inherited from SequenceClassificationBase, because Veco/XlmRoberta's classification structure - is different. - """ - - def __init__(self, config, model_dir): - super().__init__(model_dir) - VecoForSequenceClassificationTransform.__init__(self, config) - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - **kwargs): - return VecoForSequenceClassificationTransform.forward( - self, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - labels=labels) - - @classmethod - def _instantiate(cls, **kwargs): - """Instantiate the model. - - @param kwargs: Input args. - model_dir: The model dir used to load the checkpoint and the label information. - num_labels: An optional arg to tell the model how many classes to initialize. - Method will call utils.parse_label_mapping if num_labels not supplied. - If num_labels is not found, the model will use the default setting (2 classes). - @return: The loaded model, which is initialized by veco.VecoForSequenceClassification.from_pretrained - """ - - model_dir = kwargs.get('model_dir') - num_labels = kwargs.get('num_labels') - if num_labels is None: - label2id = parse_label_mapping(model_dir) - if label2id is not None and len(label2id) > 0: - num_labels = len(label2id) - - model_args = {} if num_labels is None else {'num_labels': num_labels} - return super(VecoForSequenceClassificationTransform, - VecoForSequenceClassification).from_pretrained( - pretrained_model_name_or_path=kwargs.get('model_dir'), - model_dir=kwargs.get('model_dir'), - **model_args) - - -@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.bert) -@MODELS.register_module( - Tasks.sentiment_classification, module_name=Models.bert) -@MODELS.register_module(Tasks.nli, module_name=Models.bert) -@MODELS.register_module(Tasks.text_classification, module_name=Models.bert) -class BertForSequenceClassification(SequenceClassificationBase, - BertPreTrainedModel): - """Bert sequence classification model. - - Inherited from SequenceClassificationBase. 
- """ - base_model_prefix: str = 'bert' - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_missing = [r'position_ids'] - - def __init__(self, config, model_dir): - if hasattr(config, 'base_model_prefix'): - BertForSequenceClassification.base_model_prefix = config.base_model_prefix - super().__init__(config, model_dir) - - def build_base_model(self): - from .bert import BertModel - return BertModel(self.config, add_pooling_layer=True) - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs): - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - labels=labels, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict) - - @classmethod - def _instantiate(cls, **kwargs): - """Instantiate the model. - - @param kwargs: Input args. - model_dir: The model dir used to load the checkpoint and the label information. - num_labels: An optional arg to tell the model how many classes to initialize. - Method will call utils.parse_label_mapping if num_labels not supplied. - If num_labels is not found, the model will use the default setting (2 classes). - @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained - """ - - model_dir = kwargs.get('model_dir') - num_labels = kwargs.get('num_labels') - if num_labels is None: - label2id = parse_label_mapping(model_dir) - if label2id is not None and len(label2id) > 0: - num_labels = len(label2id) - - model_args = {} if num_labels is None else {'num_labels': num_labels} - return super(BertPreTrainedModel, - BertForSequenceClassification).from_pretrained( - pretrained_model_name_or_path=kwargs.get('model_dir'), - model_dir=kwargs.get('model_dir'), - **model_args) diff --git a/modelscope/models/nlp/space/__init__.py b/modelscope/models/nlp/space/__init__.py index 45f856c1..32713c34 100644 --- a/modelscope/models/nlp/space/__init__.py +++ b/modelscope/models/nlp/space/__init__.py @@ -1,20 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
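# --- Editorial sketch (illustrative, not part of the diff) ---
# Each _instantiate classmethod in the sequence-classification wrappers deleted
# above resolves num_labels the same way: use the value passed by the caller,
# otherwise derive it from the label2id mapping shipped with the model dir
# (parse_label_mapping in the real code), otherwise fall back to the default
# two-class head. The helper and mapping below are made up for illustration.
def resolve_num_labels(num_labels=None, label2id=None):
    if num_labels is None and label2id:
        num_labels = len(label2id)
    return {} if num_labels is None else {'num_labels': num_labels}

assert resolve_num_labels(label2id={'negative': 0, 'positive': 1}) == {'num_labels': 2}
assert resolve_num_labels(num_labels=3) == {'num_labels': 3}
assert resolve_num_labels() == {}  # model keeps its default (2 classes)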
from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .model import SpaceGenerator - from .model import SpaceModelBase, SpaceTokenizer, SpaceConfig - from .space_for_dialog_intent_prediction import SpaceForDialogIntent - from .space_for_dialog_modeling import SpaceForDialogModeling - from .space_for_dialog_state_tracking import SpaceForDialogStateTracking + from .model import SpaceModelBase, SpaceTokenizer + from .dialog_intent_prediction import SpaceForDialogIntent + from .dialog_modeling import SpaceForDialogModeling + from .dialog_state_tracking import SpaceForDST + from .configuration import SpaceConfig else: _import_structure = { - 'model': - ['SpaceGenerator', 'SpaceModelBase', 'SpaceTokenizer', 'SpaceConfig'], - 'space_for_dialog_intent_prediction': ['SpaceForDialogIntent'], - 'space_for_dialog_modeling': ['SpaceForDialogModeling'], - 'space_for_dialog_state_tracking': ['SpaceForDialogStateTracking'], + 'model': ['SpaceGenerator', 'SpaceModelBase', 'SpaceTokenizer'], + 'dialog_intent_prediction': ['SpaceForDialogIntent'], + 'dialog_modeling': ['SpaceForDialogModeling'], + 'dialog_state_tracking': ['SpaceForDST'], + 'configuration': ['SpaceConfig'] } import sys diff --git a/modelscope/models/nlp/space/model/configuration_space.py b/modelscope/models/nlp/space/configuration.py similarity index 100% rename from modelscope/models/nlp/space/model/configuration_space.py rename to modelscope/models/nlp/space/configuration.py diff --git a/modelscope/models/nlp/space/space_for_dialog_intent_prediction.py b/modelscope/models/nlp/space/dialog_intent_prediction.py similarity index 66% rename from modelscope/models/nlp/space/space_for_dialog_intent_prediction.py rename to modelscope/models/nlp/space/dialog_intent_prediction.py index b93a6d83..79ff01cd 100644 --- a/modelscope/models/nlp/space/space_for_dialog_intent_prediction.py +++ b/modelscope/models/nlp/space/dialog_intent_prediction.py @@ -8,7 +8,7 @@ from modelscope.models import TorchModel from modelscope.models.base import Tensor from modelscope.models.builder import MODELS from modelscope.models.nlp.space import SpaceGenerator, SpaceModelBase -from modelscope.preprocessors.space import IntentBPETextField +from modelscope.preprocessors.nlp import IntentBPETextField from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks @@ -24,6 +24,10 @@ class SpaceForDialogIntent(TorchModel): Args: model_dir (str): the model path. + text_field (`BPETextField`, *optional*, defaults to `IntentBPETextField`): + The text field. + config (`Config`, *optional*, defaults to config in model hub): + The config. 
""" super().__init__(model_dir, *args, **kwargs) @@ -72,10 +76,21 @@ class SpaceForDialogIntent(TorchModel): Example: { 'pred': array([2.62349960e-03 4.12110658e-03 4.12748595e-05 3.77560973e-05 - 1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04 - 6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01 - 2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32) + 1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04 + 6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01 + 2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32), } + Example: + >>> from modelscope.hub.snapshot_download import snapshot_download + >>> from modelscope.models.nlp import SpaceForDialogIntent + >>> from modelscope.preprocessors import DialogIntentPredictionPreprocessor + >>> cache_path = snapshot_download('damo/nlp_space_dialog-intent-prediction') + >>> preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) + >>> model = SpaceForDialogIntent( + model_dir=cache_path, + text_field=preprocessor.text_field, + config=preprocessor.config) + >>> print(model(preprocessor("What do I need to do for the card activation?"))) """ import numpy as np pred = self.trainer.forward(input) diff --git a/modelscope/models/nlp/space/space_for_dialog_modeling.py b/modelscope/models/nlp/space/dialog_modeling.py similarity index 73% rename from modelscope/models/nlp/space/space_for_dialog_modeling.py rename to modelscope/models/nlp/space/dialog_modeling.py index efa9b851..16e9dc53 100644 --- a/modelscope/models/nlp/space/space_for_dialog_modeling.py +++ b/modelscope/models/nlp/space/dialog_modeling.py @@ -8,7 +8,7 @@ from modelscope.models import TorchModel from modelscope.models.base import Tensor from modelscope.models.builder import MODELS from modelscope.models.nlp.space import SpaceGenerator, SpaceModelBase -from modelscope.preprocessors.space import MultiWOZBPETextField +from modelscope.preprocessors.nlp import MultiWOZBPETextField from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks @@ -23,7 +23,12 @@ class SpaceForDialogModeling(TorchModel): """initialize the test generation model from the `model_dir` path. Args: - model_dir (str): the model path. + model_dir (`str`): + The model path. + text_field (`BPETextField`, *optional*, defaults to `MultiWOZBPETextField`): + The text field. + config (`Config`, *optional*, defaults to config in model hub): + The config. 
""" super().__init__(model_dir, *args, **kwargs) @@ -82,6 +87,19 @@ class SpaceForDialogModeling(TorchModel): 'aspn': array([47,8345,32,29,1983]), 'db': array([19, 24, 20]), } + Examples: + >>> from modelscope.hub.snapshot_download import snapshot_download + >>> from modelscope.models.nlp import SpaceForDialogModeling + >>> from modelscope.preprocessors import DialogModelingPreprocessor + >>> cache_path = snapshot_download('damo/nlp_space_dialog-modeling') + >>> preprocessor = DialogModelingPreprocessor(model_dir=cache_path) + >>> model = SpaceForDialogModeling(model_dir=cache_path, + text_field=preprocessor.text_field, + config=preprocessor.config) + >>> print(model(preprocessor({ + 'user_input': 'i would like a taxi from saint john \'s college to pizza hut fen ditton .', + 'history': {} + }))) """ first_turn = input['first_turn'] diff --git a/modelscope/models/nlp/space/model/modeling_space.py b/modelscope/models/nlp/space/dialog_state_tracking.py similarity index 57% rename from modelscope/models/nlp/space/model/modeling_space.py rename to modelscope/models/nlp/space/dialog_state_tracking.py index f093cbc5..9a713a59 100644 --- a/modelscope/models/nlp/space/model/modeling_space.py +++ b/modelscope/models/nlp/space/dialog_state_tracking.py @@ -1,6 +1,6 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. # Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. -# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,14 +16,22 @@ # limitations under the License. """PyTorch Space model. mainly copied from :module:`~transformers.modeling_xlm_roberta`""" +from typing import Dict + import torch from torch import nn from torch.nn import CrossEntropyLoss from transformers.file_utils import add_start_docstrings +from transformers.modeling_utils import PreTrainedModel -from modelscope.models.nlp.structbert.modeling_sbert import ( - SbertForMaskedLM, SbertModel, SbertPreTrainedModel) -from .configuration_space import SpaceConfig +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.models.nlp.structbert import (SbertForMaskedLM, SbertModel, + SbertPreTrainedModel) +from modelscope.utils.constant import Tasks +from .configuration import SpaceConfig SPACE_START_DOCSTRING = r""" @@ -57,6 +65,63 @@ class SpaceModel(SbertModel): config_class = SpaceConfig +class SpacePreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SpaceConfig + base_model_prefix = 'bert' + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + @param kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels is not input. + label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists). + + @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + config = SpaceConfig(**kwargs) + model = cls(config) + else: + model_kwargs = {} + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + return model + + @add_start_docstrings( """ Space Model transformer with Dialog state tracking heads on top (a inform projection @@ -65,7 +130,9 @@ class SpaceModel(SbertModel): """, SPACE_START_DOCSTRING, ) -class SpaceForDST(SbertPreTrainedModel): +@MODELS.register_module( + Tasks.task_oriented_conversation, module_name=Models.space_dst) +class SpaceForDST(SpacePreTrainedModel): def __init__(self, config): super(SpaceForDST, self).__init__(config) @@ -113,18 +180,105 @@ class SpaceForDST(SbertPreTrainedModel): self.init_weights() - def forward(self, - input_ids, - input_mask=None, - segment_ids=None, - position_ids=None, - head_mask=None, - start_pos=None, - end_pos=None, - inform_slot_id=None, - refer_id=None, - class_label_id=None, - diag_state=None): + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'inputs': dict(input_ids, input_masks,start_pos), # tracking states + 'outputs': dict(slots_logits), + 'unique_ids': str(test-example.json-0), # default value + 'input_ids_unmasked': array([101, 7632, 1010,0,0,0]) + 'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), + 'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), + 'prefix': str('final'), #default value + 'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]) + } + + Example: + >>> from modelscope.hub.snapshot_download import snapshot_download + >>> from modelscope.models.nlp import SpaceForDST + >>> from modelscope.preprocessors import DialogStateTrackingPreprocessor + >>> cache_path = snapshot_download('damo/nlp_space_dialog-state-tracking') + >>> model 
= SpaceForDST.from_pretrained(cache_path) + >>> preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) + >>> print(model(preprocessor({ + 'utter': { + 'User-1': "Hi, I'm looking for a train that is going" + "to cambridge and arriving there by 20:45, is there anything like that?" + }, + 'history_states': [{}] + }))) + """ + import numpy as np + import torch + + # self.model.eval() ???? + batch = input['batch'] + + features = input['features'] + diag_state = input['diag_state'] + turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]] + reset_diag_state = np.where(np.array(turn_itrs) == '0')[0] + for slot in self.config.dst_slot_list: + for i in reset_diag_state: + diag_state[slot][i] = 0 + + with torch.no_grad(): + inputs = { + 'input_ids': batch[0], + 'input_mask': batch[1], + 'segment_ids': batch[2], + 'start_pos': batch[3], + 'end_pos': batch[4], + 'inform_slot_id': batch[5], + 'refer_id': batch[6], + 'diag_state': diag_state, + 'class_label_id': batch[8] + } + unique_ids = [features[i.item()].guid for i in batch[9]] + values = [features[i.item()].values for i in batch[9]] + input_ids_unmasked = [ + features[i.item()].input_ids_unmasked for i in batch[9] + ] + inform = [features[i.item()].inform for i in batch[9]] + outputs = self._forward(**inputs) + + # Update dialog state for next turn. + for slot in self.config.dst_slot_list: + updates = outputs[2][slot].max(1)[1] + for i, u in enumerate(updates): + if u != 0: + diag_state[slot][i] = u + + return { + 'inputs': inputs, + 'outputs': outputs, + 'unique_ids': unique_ids, + 'input_ids_unmasked': input_ids_unmasked, + 'values': values, + 'inform': inform, + 'prefix': 'final', + 'ds': input['ds'] + } + + def _forward(self, + input_ids, + input_mask=None, + segment_ids=None, + position_ids=None, + head_mask=None, + start_pos=None, + end_pos=None, + inform_slot_id=None, + refer_id=None, + class_label_id=None, + diag_state=None): outputs = self.bert( input_ids, attention_mask=input_mask, @@ -132,8 +286,8 @@ class SpaceForDST(SbertPreTrainedModel): position_ids=position_ids, head_mask=head_mask) - sequence_output = outputs[0] - pooled_output = outputs[1] + sequence_output = outputs.last_hidden_state + pooled_output = outputs.pooler_output sequence_output = self.dropout(sequence_output) pooled_output = self.dropout(pooled_output) @@ -233,36 +387,6 @@ class SpaceForDST(SbertPreTrainedModel): per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, - ) + outputs[2:] + ) + (outputs.embedding_output, ) return outputs - - -@add_start_docstrings( - 'The Space Model Model with a `language modeling` head on tops', - SPACE_START_DOCSTRING, -) -class SpaceForMaskedLM(SbertForMaskedLM): - """ - This class overrides [`SbertForMaskedLM`]. Please check the superclass for the - appropriate documentation alongside usage examples. - """ - - config_class = SpaceConfig - - -@add_start_docstrings( - """ - Space Model with only one head on top as done during the pretraining: a `masked language modeling` head. 
- """, - SPACE_START_DOCSTRING, -) -class SpaceForPreTraining(SbertPreTrainedModel): - - def __init__(self, model_name_or_path: str): - super(SpaceForPreTraining, self).__init__() - self.bert_model = SpaceForMaskedLM.from_pretrained(model_name_or_path) - - def forward(self, input_ids: torch.tensor, mlm_labels: torch.tensor): - outputs = self.bert_model(input_ids, masked_lm_labels=mlm_labels) - return outputs[0] diff --git a/modelscope/models/nlp/space/model/__init__.py b/modelscope/models/nlp/space/model/__init__.py index bb1d18e4..cfff335d 100644 --- a/modelscope/models/nlp/space/model/__init__.py +++ b/modelscope/models/nlp/space/model/__init__.py @@ -1,10 +1,8 @@ -from .configuration_space import SpaceConfig +# Copyright (c) Alibaba, Inc. and its affiliates. from .gen_unified_transformer import GenUnifiedTransformer from .generator import SpaceGenerator from .intent_unified_transformer import IntentUnifiedTransformer from .model_base import SpaceModelBase -from .modeling_space import (SpaceForDST, SpaceForMaskedLM, - SpaceForPreTraining, SpaceModel) from .tokenization_space import (BasicTokenizer, SpaceTokenizer, WordpieceTokenizer) from .unified_transformer import UnifiedTransformer diff --git a/modelscope/models/nlp/space/model/generator.py b/modelscope/models/nlp/space/model/generator.py index 0e7833e6..2e05b545 100644 --- a/modelscope/models/nlp/space/model/generator.py +++ b/modelscope/models/nlp/space/model/generator.py @@ -71,14 +71,11 @@ class SpaceGenerator(object): return def __call__(self, step_fn, state): - """ - Running generation. - - @param : step_fn : decoding one step - @type : function + """Running generation. - @param : state : initial state - @type : dict + Args: + step_fn (`function`) : decoding one step + state(`dict`) : initial state """ raise NotImplementedError @@ -104,11 +101,9 @@ class BeamSearch(SpaceGenerator): """ Running beam search. - @param : step_fn : decoding one step - @type : function - - @param : state : initial state - @type : dict + Args: + step_fn(`function`) : decoding one step + state(`dict`) : initial state """ if prev_input is not None: diff --git a/modelscope/models/nlp/space/model/model_base.py b/modelscope/models/nlp/space/model/model_base.py index d3d0baa4..b7812182 100644 --- a/modelscope/models/nlp/space/model/model_base.py +++ b/modelscope/models/nlp/space/model/model_base.py @@ -64,8 +64,8 @@ class SpaceModelBase(nn.Module): """ Forward process, include real forward, collect metrices and optimize(optional) - @params : inputs : input data - @type : dict of numpy.ndarray/int/float/... + Args: + inputs(`dict` of numpy.ndarray/int/float/...) : input data """ if is_training: self.train() @@ -85,11 +85,10 @@ class SpaceModelBase(nn.Module): eos_id=None, max_gen_len=None, prev_input=None): - """ - Inference process. + """Inference process. - @params : inputs : input data - @type : dict of numpy.ndarray/int/float/... + Args: + inputs(`dict` of numpy.ndarray/int/float/...) : input data """ self.eval() results = self._infer( diff --git a/modelscope/models/nlp/space/model/tokenization_space.py b/modelscope/models/nlp/space/model/tokenization_space.py index 84712b7b..e3b358d4 100644 --- a/modelscope/models/nlp/space/model/tokenization_space.py +++ b/modelscope/models/nlp/space/model/tokenization_space.py @@ -1,5 +1,5 @@ -# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. 
+# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/modelscope/models/nlp/space/model/unified_transformer.py b/modelscope/models/nlp/space/model/unified_transformer.py index b0775541..19069971 100644 --- a/modelscope/models/nlp/space/model/unified_transformer.py +++ b/modelscope/models/nlp/space/model/unified_transformer.py @@ -119,15 +119,12 @@ class UnifiedTransformer(SpaceModelBase): input_mask, append_head=False, auto_regressive=False): - """ - Create attention mask. + """Create attention mask. from sequence to matrix:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] - @param : input_mask - @type : Variable(shape: [batch_size, max_seq_len]) - - @param : auto_regressive - @type : bool + Args: + input_mask (Variable(shape: [batch_size, max_seq_len])) + auto_regressive(bool) """ seq_len = input_mask.shape[1] @@ -150,15 +147,12 @@ class UnifiedTransformer(SpaceModelBase): return mask def _join_mask(self, mask1, mask2): - """ - Merge source attention mask and target attention mask. + """Merge source attention mask and target attention mask. There are four parts:left upper (lu) / right upper (ru) / left below (lb) / right below (rb) - @param : mask1 : source attention mask - @type : Variable(shape: [batch_size, max_src_len, max_src_len]) - - @param : mask1 : target attention mask - @type : Variable(shape: [batch_size, max_tgt_len, max_tgt_len]) + Args: + mask1(Variable(shape: [batch_size, max_src_len, max_src_len])) : source attention mask + mask2(Variable(shape: [batch_size, max_tgt_len, max_tgt_len])) : target attention mask """ batch_size = mask1.shape[0] seq_len1 = mask1.shape[1] diff --git a/modelscope/models/nlp/space/modules/transformer_block.py b/modelscope/models/nlp/space/modules/transformer_block.py index 37f968d9..3044963a 100644 --- a/modelscope/models/nlp/space/modules/transformer_block.py +++ b/modelscope/models/nlp/space/modules/transformer_block.py @@ -30,18 +30,13 @@ class TransformerBlock(nn.Module): return def forward(self, inp, mask=None, cache=None): - """ - Forward process on one transformer layer. - - @param : x - @type : Variable(shape: [batch_size, seq_len, hidden_size]) - - @param : memory - @type : Variable(shape: [batch_size, seq_len, hidden_size]) - - @param : mask + """Forward process on one transformer layer. 
- @param : cache + Args: + x(Variable(shape: [batch_size, seq_len, hidden_size])) + memory(Variable(shape: [batch_size, seq_len, hidden_size])) + mask + cache """ attn_out = self.attn(inp, mask, cache) attn_out = self.dropout_layer(attn_out) diff --git a/modelscope/models/nlp/space/space_for_dialog_state_tracking.py b/modelscope/models/nlp/space/space_for_dialog_state_tracking.py deleted file mode 100644 index 4b9cf5c3..00000000 --- a/modelscope/models/nlp/space/space_for_dialog_state_tracking.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import Dict - -from modelscope.metainfo import Models -from modelscope.models import TorchModel -from modelscope.models.base import Tensor -from modelscope.models.builder import MODELS -from modelscope.utils.constant import Tasks - -__all__ = ['SpaceForDialogStateTracking'] - - -@MODELS.register_module( - Tasks.task_oriented_conversation, module_name=Models.space_dst) -class SpaceForDialogStateTracking(TorchModel): - - def __init__(self, model_dir: str, *args, **kwargs): - """initialize the test generation model from the `model_dir` path. - - Args: - model_dir (str): the model path. - """ - - super().__init__(model_dir, *args, **kwargs) - - from modelscope.models.nlp.space.model import SpaceForDST, SpaceConfig - self.model_dir = model_dir - - self.config = SpaceConfig.from_pretrained(self.model_dir) - self.model = SpaceForDST.from_pretrained(self.model_dir) - - def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: - """return the result by the model - - Args: - input (Dict[str, Tensor]): the preprocessed data - - Returns: - Dict[str, Tensor]: results - Example: - { - 'inputs': dict(input_ids, input_masks,start_pos), # tracking states - 'outputs': dict(slots_logits), - 'unique_ids': str(test-example.json-0), # default value - 'input_ids_unmasked': array([101, 7632, 1010,0,0,0]) - 'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), - 'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), - 'prefix': str('final'), #default value - 'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]) - } - """ - import numpy as np - import torch - - self.model.eval() - batch = input['batch'] - - features = input['features'] - diag_state = input['diag_state'] - turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]] - reset_diag_state = np.where(np.array(turn_itrs) == '0')[0] - for slot in self.config.dst_slot_list: - for i in reset_diag_state: - diag_state[slot][i] = 0 - - with torch.no_grad(): - inputs = { - 'input_ids': batch[0], - 'input_mask': batch[1], - 'segment_ids': batch[2], - 'start_pos': batch[3], - 'end_pos': batch[4], - 'inform_slot_id': batch[5], - 'refer_id': batch[6], - 'diag_state': diag_state, - 'class_label_id': batch[8] - } - unique_ids = [features[i.item()].guid for i in batch[9]] - values = [features[i.item()].values for i in batch[9]] - input_ids_unmasked = [ - features[i.item()].input_ids_unmasked for i in batch[9] - ] - inform = [features[i.item()].inform for i in batch[9]] - outputs = self.model(**inputs) - - # Update dialog state for next turn. 
- for slot in self.config.dst_slot_list: - updates = outputs[2][slot].max(1)[1] - for i, u in enumerate(updates): - if u != 0: - diag_state[slot][i] = u - - return { - 'inputs': inputs, - 'outputs': outputs, - 'unique_ids': unique_ids, - 'input_ids_unmasked': input_ids_unmasked, - 'values': values, - 'inform': inform, - 'prefix': 'final', - 'ds': input['ds'] - } diff --git a/modelscope/models/nlp/space_T_cn/__init__.py b/modelscope/models/nlp/space_T_cn/__init__.py new file mode 100644 index 00000000..b9deb700 --- /dev/null +++ b/modelscope/models/nlp/space_T_cn/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .table_question_answering import TableQuestionAnswering +else: + _import_structure = { + 'table_question_answering': ['TableQuestionAnswering'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/star3/modeling_star3.py b/modelscope/models/nlp/space_T_cn/backbone.py similarity index 98% rename from modelscope/models/nlp/star3/modeling_star3.py rename to modelscope/models/nlp/space_T_cn/backbone.py index 13f7136a..5afde06e 100644 --- a/modelscope/models/nlp/star3/modeling_star3.py +++ b/modelscope/models/nlp/space_T_cn/backbone.py @@ -1,6 +1,6 @@ +# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved. # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ import numpy as np import torch from torch import nn -from modelscope.models.nlp.star3.configuration_star3 import Star3Config +from modelscope.models.nlp.space_T_cn.configuration import SpaceTCnConfig from modelscope.utils.constant import ModelFile from modelscope.utils.logger import get_logger @@ -609,9 +609,9 @@ class PreTrainedBertModel(nn.Module): def __init__(self, config, *inputs, **kwargs): super(PreTrainedBertModel, self).__init__() - if not isinstance(config, Star3Config): + if not isinstance(config, SpaceTCnConfig): raise ValueError( - 'Parameter config in `{}(config)` should be an instance of class `Star3Config`. ' + 'Parameter config in `{}(config)` should be an instance of class `SpaceTCnConfig`. ' 'To create a model from a Google pretrained model use ' '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( self.__class__.__name__, self.__class__.__name__)) @@ -676,7 +676,7 @@ class PreTrainedBertModel(nn.Module): serialization_dir = tempdir # Load config config_file = os.path.join(serialization_dir, CONFIG_NAME) - config = Star3Config.from_json_file(config_file) + config = SpaceTCnConfig.from_json_file(config_file) logger.info('Model config {}'.format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) @@ -742,11 +742,11 @@ class PreTrainedBertModel(nn.Module): return model -class Star3Model(PreTrainedBertModel): - """Star3Model model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR3.0"). 
+class SpaceTCnModel(PreTrainedBertModel): + """SpaceTCnModel model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR-T-CN"). Params: - config: a Star3Config class instance with the configuration to build a new model + config: a SpaceTCnConfig class instance with the configuration to build a new model Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] @@ -780,16 +780,16 @@ class Star3Model(PreTrainedBertModel): input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - config = modeling.Star3Config(vocab_size_or_config_json_file=32000, hidden_size=768, + config = modeling.SpaceTCnConfig(vocab_size_or_config_json_file=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - model = modeling.Star3Model(config=config) + model = modeling.SpaceTCnModel(config=config) all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) ``` """ def __init__(self, config, schema_link_module='none'): - super(Star3Model, self).__init__(config) + super(SpaceTCnModel, self).__init__(config) self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder( config, schema_link_module=schema_link_module) diff --git a/modelscope/models/nlp/star3/configuration_star3.py b/modelscope/models/nlp/space_T_cn/configuration.py similarity index 91% rename from modelscope/models/nlp/star3/configuration_star3.py rename to modelscope/models/nlp/space_T_cn/configuration.py index 4c5ae677..e698b310 100644 --- a/modelscope/models/nlp/star3/configuration_star3.py +++ b/modelscope/models/nlp/space_T_cn/configuration.py @@ -1,6 +1,6 @@ +# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved. # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,8 +24,8 @@ import json logger = logging.getLogger(__name__) -class Star3Config(object): - """Configuration class to store the configuration of a `Star3Model`. +class SpaceTCnConfig(object): + """Configuration class to store the configuration of a `SpaceTCnModel`. """ def __init__(self, @@ -40,10 +40,10 @@ class Star3Config(object): max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02): - """Constructs Star3Config. + """Constructs SpaceTCnConfig. Args: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `Star3Model`. + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `SpaceTCnConfig`. hidden_size: Size of the encoder layers and the pooler layer. num_hidden_layers: Number of hidden layers in the Transformer encoder. num_attention_heads: Number of attention heads for each attention layer in @@ -59,7 +59,7 @@ class Star3Config(object): max_position_embeddings: The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). - type_vocab_size: The vocabulary size of the `token_type_ids` passed into `Star3Model`. + type_vocab_size: The vocabulary size of the `token_type_ids` passed into `SpaceTCnConfig`. initializer_range: The sttdev of the truncated_normal_initializer for initializing all weight matrices. 
""" @@ -89,15 +89,15 @@ class Star3Config(object): @classmethod def from_dict(cls, json_object): - """Constructs a `Star3Config` from a Python dictionary of parameters.""" - config = Star3Config(vocab_size_or_config_json_file=-1) + """Constructs a `SpaceTCnConfig` from a Python dictionary of parameters.""" + config = SpaceTCnConfig(vocab_size_or_config_json_file=-1) for key, value in json_object.items(): config.__dict__[key] = value return config @classmethod def from_json_file(cls, json_file): - """Constructs a `Star3Config` from a json file of parameters.""" + """Constructs a `SpaceTCnConfig` from a json file of parameters.""" with open(json_file, 'r', encoding='utf-8') as reader: text = reader.read() return cls.from_dict(json.loads(text)) diff --git a/modelscope/models/nlp/table_question_answering.py b/modelscope/models/nlp/space_T_cn/table_question_answering.py similarity index 92% rename from modelscope/models/nlp/table_question_answering.py rename to modelscope/models/nlp/space_T_cn/table_question_answering.py index 3c91a518..a3f504b7 100644 --- a/modelscope/models/nlp/table_question_answering.py +++ b/modelscope/models/nlp/space_T_cn/table_question_answering.py @@ -11,17 +11,17 @@ from transformers import BertTokenizer from modelscope.metainfo import Models from modelscope.models.base import Model, Tensor from modelscope.models.builder import MODELS -from modelscope.models.nlp.star3.configuration_star3 import Star3Config -from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model -from modelscope.preprocessors.star3.fields.struct import Constant +from modelscope.preprocessors.nlp.space_T_cn.fields.struct import Constant from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.device import verify_device +from .backbone import Seq2SQL, SpaceTCnModel +from .configuration import SpaceTCnConfig __all__ = ['TableQuestionAnswering'] @MODELS.register_module( - Tasks.table_question_answering, module_name=Models.star3) + Tasks.table_question_answering, module_name=Models.space_T_cn) class TableQuestionAnswering(Model): def __init__(self, model_dir: str, *args, **kwargs): @@ -41,9 +41,9 @@ class TableQuestionAnswering(Model): os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE), map_location='cpu') - self.backbone_config = Star3Config.from_json_file( + self.backbone_config = SpaceTCnConfig.from_json_file( os.path.join(self.model_dir, ModelFile.CONFIGURATION)) - self.backbone_model = Star3Model( + self.backbone_model = SpaceTCnModel( config=self.backbone_config, schema_link_module='rat') self.backbone_model.load_state_dict(state_dict['backbone_model']) @@ -82,7 +82,6 @@ class TableQuestionAnswering(Model): if ntok.startswith('##'): ntok = ntok.replace('##', '') - tok = nlu1[idx:idx + 1].lower() if ntok == tok: conv_dict[i] = [idx, idx + 1] @@ -690,11 +689,11 @@ class TableQuestionAnswering(Model): sels.append(l_hs[ib] - 1) aggs.append(sql['agg'][ia]) continue - sels.append(sel) + sels.append(int(sel)) if sql['agg'][ia] == -1: aggs.append(0) else: - aggs.append(sql['agg'][ia]) + aggs.append(int(sql['agg'][ia])) if len(sels) == 0: sels.append(l_hs[ib] - 1) aggs.append(0) @@ -711,7 +710,7 @@ class TableQuestionAnswering(Model): for i in range(wl): if wc_os[i] == -1: continue - conds.append([wc_os[i], wo_os[i], pr_wvi_str[ib][i]]) + conds.append([int(wc_os[i]), int(wo_os[i]), pr_wvi_str[ib][i]]) if len(conds) == 0: conds.append([l_hs[ib] - 1, 2, 'Nulll']) sql['conds'] = conds @@ -733,9 +732,41 @@ class TableQuestionAnswering(Model): Args: input 
(Dict[str, Tensor]): the preprocessed data + Returns: Dict[str, Tensor]: results Example: + { + 'result': + { + 'question_tok': ['有', '哪', '些', '风', '险', '类', '型', '?'], + 'question': '有哪些风险类型?', + 'table_id': 'fund', + 'sql': { + 'cond_conn_op': 0, + 'sel': [5], + 'agg': [0], + 'conds': [[10, 2, 'Nulll']] + }, + 'action': 10, + 'model_out': [ + [6, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [2, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0] + ] + }, + 'history_sql': None + } + + Example: + >>> from modelscope.models.nlp import TableQuestionAnswering + >>> from modelscope.preprocessors import TableQuestionAnsweringPreprocessor + >>> model = TableQuestionAnswering.from_pretrained('damo/nlp_convai_text2sql_pretrain_cn') + >>> preprocessor = TableQuestionAnsweringPreprocessor(model_dir=model.model_dir) + >>> print(model(preprocessor({'question': '有哪些风险类型?'}))) """ result = self.predict(input['datas'])[0] diff --git a/modelscope/models/nlp/space_T_en/__init__.py b/modelscope/models/nlp/space_T_en/__init__.py new file mode 100644 index 00000000..46c8b38c --- /dev/null +++ b/modelscope/models/nlp/space_T_en/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .text_to_sql import StarForTextToSql +else: + _import_structure = { + 'text_to_sql': ['StarForTextToSql'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/star_text_to_sql.py b/modelscope/models/nlp/space_T_en/text_to_sql.py similarity index 57% rename from modelscope/models/nlp/star_text_to_sql.py rename to modelscope/models/nlp/space_T_en/text_to_sql.py index eef76e8a..ca2d2596 100644 --- a/modelscope/models/nlp/star_text_to_sql.py +++ b/modelscope/models/nlp/space_T_en/text_to_sql.py @@ -4,14 +4,13 @@ import os from typing import Dict, Optional import torch -import torch.nn as nn from text2sql_lgesql.asdl.asdl import ASDLGrammar from text2sql_lgesql.asdl.transition_system import TransitionSystem from text2sql_lgesql.model.model_constructor import Text2SQL -from text2sql_lgesql.utils.constants import GRAMMAR_FILEPATH from modelscope.metainfo import Models -from modelscope.models.base import Model, Tensor +from modelscope.models import TorchModel +from modelscope.models.base import Tensor from modelscope.models.builder import MODELS from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks @@ -20,8 +19,8 @@ __all__ = ['StarForTextToSql'] @MODELS.register_module( - Tasks.conversational_text_to_sql, module_name=Models.star) -class StarForTextToSql(Model): + Tasks.table_question_answering, module_name=Models.space_T_en) +class StarForTextToSql(TorchModel): def __init__(self, model_dir: str, *args, **kwargs): """initialize the star model from the `model_dir` path. 
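The two new `__init__.py` files above (space_T_cn and space_T_en) both use `LazyImportModule` so that importing the package stays cheap until a symbol such as `StarForTextToSql` is actually requested. A simplified sketch of the same idea using PEP 562 module `__getattr__`, not the `LazyImportModule` implementation itself:

```python
# Simplified illustration of the lazy-import pattern used by the new
# space_T_cn / space_T_en __init__.py files. ModelScope's LazyImportModule
# swaps in a lazy module object; this sketch gets a similar effect with a
# module-level __getattr__ and is not the library implementation.
import importlib

_import_structure = {'text_to_sql': ['StarForTextToSql']}
_attr_to_module = {
    attr: mod for mod, attrs in _import_structure.items() for attr in attrs
}

def __getattr__(name):
    # The heavy submodule (and its torch / text2sql_lgesql dependencies) is
    # only imported the first time the attribute is accessed.
    if name in _attr_to_module:
        module = importlib.import_module('.' + _attr_to_module[name], __name__)
        return getattr(module, name)
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
```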
@@ -59,6 +58,33 @@ class StarForTextToSql(Model): Returns: Dict[str, Tensor]: results Example: + + Example: + >>> from modelscope.hub.snapshot_download import snapshot_download + >>> from modelscope.models.nlp import StarForTextToSql + >>> from modelscope.preprocessors import ConversationalTextToSqlPreprocessor + >>> test_case = { + 'database_id': 'employee_hire_evaluation', + 'local_db_path': None, + 'utterance': [ + "I'd like to see Shop names.", 'Which of these are hiring?', + 'Which shop is hiring the highest number of employees?' + ' | do you want the name of the shop ? | Yes' + ] + } + >>> cache_path = snapshot_download('damo/nlp_star_conversational-text-to-sql') + >>> preprocessor = ConversationalTextToSqlPreprocessor( + model_dir=cache_path, + database_id=test_case['database_id'], + db_content=True) + >>> model = StarForTextToSql(cache_path, config=preprocessor.config) + >>> print(model(preprocessor({ + 'utterance': "I'd like to see Shop names.", + 'history': [], + 'last_sql': '', + 'database_id': 'employee_hire_evaluation', + 'local_db_path': None + }))) """ self.model.eval() hyps = self.model.parse(input['batch'], self.beam_size) # diff --git a/modelscope/models/nlp/structbert/__init__.py b/modelscope/models/nlp/structbert/__init__.py index d42db83c..60d369e0 100644 --- a/modelscope/models/nlp/structbert/__init__.py +++ b/modelscope/models/nlp/structbert/__init__.py @@ -18,20 +18,26 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .configuration_sbert import SbertConfig - from .modeling_sbert import (SbertForMaskedLM, SbertModel, - SbertPreTrainedModel) - from .tokenization_sbert import (BasicTokenizer, SbertTokenizer, - WordpieceTokenizer) - from .tokenization_sbert_fast import SbertTokenizerFast + from .backbone import (SbertModel, SbertPreTrainedModel) + from .configuration import SbertConfig + from .faq_question_answering import SbertForFaqQuestionAnswering + from .fill_mask import SbertForMaskedLM + from .text_classification import SbertForSequenceClassification + from .token_classification import SbertForTokenClassification + from .tokenization import (BasicTokenizer, SbertTokenizer, + WordpieceTokenizer) + from .tokenization_fast import SbertTokenizerFast else: _import_structure = { - 'configuration_sbert': ['SbertConfig'], - 'modeling_sbert': - ['SbertForMaskedLM', 'SbertModel', 'SbertPreTrainedModel'], - 'tokenization_sbert': + 'backbone': ['SbertModel', 'SbertPreTrainedModel'], + 'configuration': ['SbertConfig'], + 'fill_mask': ['SbertForMaskedLM'], + 'faq_question_answering': ['SbertForFaqQuestionAnswering'], + 'text_classification': ['SbertForSequenceClassification'], + 'token_classification': ['SbertForTokenClassification'], + 'tokenization': ['BasicTokenizer', 'SbertTokenizer', 'WordpieceTokenizer'], - 'tokenization_sbert_fast': ['SbertTokenizerFast'], + 'tokenization_fast': ['SbertTokenizerFast'], } import sys diff --git a/modelscope/models/nlp/structbert/adv_utils.py b/modelscope/models/nlp/structbert/adv_utils.py index 44aae85c..91a4cb82 100644 --- a/modelscope/models/nlp/structbert/adv_utils.py +++ b/modelscope/models/nlp/structbert/adv_utils.py @@ -98,7 +98,7 @@ def compute_adv_loss(embedding, if is_nan: logger.warning('Nan occured when calculating adv loss.') return ori_loss - emb_grad = emb_grad / emb_grad_norm + emb_grad = emb_grad / (emb_grad_norm + 1e-6) embedding_2 = embedding_1 + adv_grad_factor * emb_grad embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2) embedding_2 = 
torch.min(embedding_1 + adv_bound, embedding_2) diff --git a/modelscope/models/nlp/structbert/backbone.py b/modelscope/models/nlp/structbert/backbone.py new file mode 100755 index 00000000..039db3ce --- /dev/null +++ b/modelscope/models/nlp/structbert/backbone.py @@ -0,0 +1,932 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch StructBERT model. mainly copied from :module:`~transformers.modeling_bert`""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from packaging import version +from transformers.activations import ACT2FN +from transformers.modeling_outputs import \ + BaseModelOutputWithPastAndCrossAttentions +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionBackboneModelOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping +from modelscope.utils.logger import get_logger +from .configuration import SbertConfig + +logger = get_logger(__name__) + + +class SbertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse('1.6.0'): + self.register_buffer( + 'token_type_ids', + torch.zeros( + self.position_ids.size(), + dtype=torch.long, + device=self.position_ids.device), + persistent=False, + ) + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + return_inputs_embeds=False): + if input_ids is not None: + input_shape = input_ids.size() + 
else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, + past_key_values_length:seq_length + + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users + # when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, 'token_type_ids'): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, + dtype=torch.long, + device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + if not return_inputs_embeds: + return embeddings + else: + return embeddings, inputs_embeds + + +class SbertSelfAttention(nn.Module): + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' + f'heads ({config.num_attention_heads})') + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
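`transpose_for_scores` above splits the hidden dimension into per-head slices before the attention product. A quick shape check with illustrative sizes (hidden_size=768, 12 heads; these numbers are just an example configuration):

```python
# Shape check for the transpose_for_scores reshape defined above:
# [batch, seq, hidden] -> [batch, heads, seq, head_size].
import torch

batch, seq, hidden, heads = 2, 5, 768, 12
head_size = hidden // heads                      # 64

x = torch.randn(batch, seq, hidden)
x = x.view(batch, seq, heads, head_size)         # split hidden into heads
x = x.permute(0, 2, 1, 3)                        # move heads before seq
assert x.shape == (batch, heads, seq, head_size)
```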
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in SbertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
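The `relative_key` / `relative_key_query` branch above indexes a distance embedding of size `2 * max_position_embeddings - 1` with shifted pairwise offsets. A toy-sized sketch of that indexing:

```python
# Sketch of the relative-position indexing used in the 'relative_key' branch:
# pairwise distances are shifted into [0, 2*max_pos - 2] so they can index an
# embedding table of size 2*max_pos - 1. Sizes here are toy numbers.
import torch

max_pos, seq_len, head_size = 512, 4, 64
distance_embedding = torch.nn.Embedding(2 * max_pos - 1, head_size)

pos_l = torch.arange(seq_len).view(-1, 1)    # query positions, column vector
pos_r = torch.arange(seq_len).view(1, -1)    # key positions, row vector
distance = pos_l - pos_r                     # [seq_len, seq_len], values in [-(L-1), L-1]
rel_emb = distance_embedding(distance + max_pos - 1)
assert rel_emb.shape == (seq_len, seq_len, head_size)
```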
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + if self.is_decoder: + outputs = outputs + (past_key_value, ) + return outputs + + +class SbertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class SbertAttention(nn.Module): + + def __init__(self, config): + super().__init__() + self.self = SbertSelfAttention(config) + self.output = SbertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class SbertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SbertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + 
hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class SbertLayer(nn.Module): + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = SbertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError( + f'{self} should be used as a decoder model if cross attention is added' + ) + self.crossattention = SbertAttention(config) + self.intermediate = SbertIntermediate(config) + self.output = SbertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[ + 1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, 'crossattention'): + raise ValueError( + f'If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention ' + f'layers by setting `config.add_cross_attention=True`') + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class SbertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [SbertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + 
past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + ( + layer_outputs[2], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class SbertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class SbertPreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SbertConfig + base_model_prefix = 'bert' + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, SbertEncoder): + module.gradient_checkpointing = value + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels is not input. + label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists). + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + config = SbertConfig(**kwargs) + model = cls(config) + else: + model_kwargs = {} + label2id = kwargs.get('label2id', parse_label_mapping(model_dir)) + id2label = kwargs.get( + 'id2label', None if label2id is None else + {id: label + for label, id in label2id.items()}) + if id2label is not None and label2id is None: + label2id = {label: id for id, label in id2label.items()} + + num_labels = kwargs.get( + 'num_labels', None if label2id is None else len(label2id)) + if num_labels is not None: + model_kwargs['num_labels'] = num_labels + if label2id is not None: + model_kwargs['label2id'] = label2id + if id2label is not None: + model_kwargs['id2label'] = id2label + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + return model + + +@dataclass +class AttentionBackboneModelOutputWithEmbedding(AttentionBackboneModelOutput): + embedding_output: torch.FloatTensor = None + logits: Optional[Union[tuple, torch.FloatTensor]] = None + kwargs: dict = None + + +@MODELS.register_module(Tasks.backbone, module_name=Models.structbert) +class SbertModel(SbertPreTrainedModel): + """The StructBERT Model transformer outputting raw hidden-states without any specific head on top. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. 
+ + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration + set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config: SbertConfig, add_pooling_layer=True, **kwargs): + super().__init__(config) + self.config = config + + self.embeddings = SbertEmbeddings(config) + self.encoder = SbertEncoder(config) + + self.pooler = SbertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. 
+ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, + `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple + having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ + Returns: + Returns `modelscope.outputs.AttentionBackboneModelOutputWithEmbedding` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='backbone') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_backbone_base_std') + >>> print(model(**preprocessor('这是个测试'))) + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, 'token_type_ids'): + buffered_token_type_ids = self.embeddings.token_type_ids[:, : + seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
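The comment above refers to `get_extended_attention_mask`, which is called immediately below. For a plain 2D padding mask it follows the usual transformers convention of turning the mask into an additive bias that broadcasts over heads and query positions (older transformers versions use -10000.0 as the bias; newer ones use the dtype minimum). A sketch with the older constant:

```python
# Sketch of the extended attention mask for a 2D padding mask:
# 1 -> keep (bias 0.0), 0 -> padding (large negative bias).
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])          # [batch, seq_len], 0 = padding
extended = attention_mask[:, None, None, :].float()    # [batch, 1, 1, seq_len]
extended = (1.0 - extended) * -10000.0                 # 0.0 for real tokens, -10000.0 for pads
# Adding `extended` to attention_scores of shape
# [batch, num_heads, seq_len, seq_len] drives the softmax weight of padded keys to ~0.
```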
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + embedding_output, orignal_embeds = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + return_inputs_embeds=True, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, + pooled_output) + encoder_outputs[1:] + (orignal_embeds, ) + + return AttentionBackboneModelOutputWithEmbedding( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + embedding_output=orignal_embeds) diff --git a/modelscope/models/nlp/structbert/configuration_sbert.py b/modelscope/models/nlp/structbert/configuration.py similarity index 94% rename from modelscope/models/nlp/structbert/configuration_sbert.py rename to modelscope/models/nlp/structbert/configuration.py index a727a978..8f095f9d 100644 --- a/modelscope/models/nlp/structbert/configuration_sbert.py +++ b/modelscope/models/nlp/structbert/configuration.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" SBERT model configuration, mainly copied from :class:`~transformers.BertConfig` """ +""" StructBERT model configuration, mainly copied from :class:`~transformers.BertConfig` """ from transformers import PretrainedConfig from modelscope.utils import logger as logging @@ -26,7 +26,7 @@ class SbertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a :class:`~modelscope.models.nlp.structbert.SbertModel`. - It is used to instantiate a SBERT model according to the specified arguments. 
+ It is used to instantiate a StructBERT model according to the specified arguments. Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. @@ -74,15 +74,15 @@ class SbertConfig(PretrainedConfig): relevant if ``config.is_decoder=True``. classifier_dropout (:obj:`float`, `optional`): The dropout ratio for the classification head. - adv_grad_factor (:obj:`float`, `optional`): This factor will be multipled by the KL loss grad and then + adv_grad_factor (:obj:`float`, `optional`): This factor will be multiplied by the KL loss grad and then the result will be added to the original embedding. More details please check:https://arxiv.org/abs/1908.04577 - The range of this value always be 1e-3~1e-7 + The range of this value should between 1e-3~1e-7 adv_bound (:obj:`float`, `optional`): adv_bound is used to cut the top and the bottom bound of the produced embedding. - If not proveded, 2 * sigma will be used as the adv_bound factor + If not provided, 2 * sigma will be used as the adv_bound factor sigma (:obj:`float`, `optional`): The std factor used to produce a 0 mean normal distribution. - If adv_bound not proveded, 2 * sigma will be used as the adv_bound factor + If adv_bound not provided, 2 * sigma will be used as the adv_bound factor """ model_type = 'structbert' diff --git a/modelscope/models/nlp/sbert_for_faq_question_answering.py b/modelscope/models/nlp/structbert/faq_question_answering.py similarity index 74% rename from modelscope/models/nlp/sbert_for_faq_question_answering.py rename to modelscope/models/nlp/structbert/faq_question_answering.py index 23ccdcc5..c8dbf302 100644 --- a/modelscope/models/nlp/sbert_for_faq_question_answering.py +++ b/modelscope/models/nlp/structbert/faq_question_answering.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
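The corrected `SbertConfig` docs above describe `adv_grad_factor`, `adv_bound` and `sigma`, and the `adv_utils` change earlier in this diff adds a `1e-6` epsilon to the gradient norm. A toy sketch of how these pieces fit together; the exact norm axis and real default values are not shown in these hunks, so the numbers below are illustrative only:

```python
# Illustrative adversarial perturbation: normalize the KL-loss gradient with a
# small epsilon, scale it, and clamp the perturbed embedding to +/- adv_bound.
import torch

sigma, adv_grad_factor = 0.02, 1e-4        # illustrative values, not library defaults
adv_bound = 2 * sigma                      # documented fallback when adv_bound is unset

embedding_1 = torch.randn(2, 8, 16)        # original embeddings [batch, seq, hidden]
emb_grad = torch.randn_like(embedding_1)   # stand-in for the KL-loss gradient
emb_grad_norm = emb_grad.norm(dim=-1, keepdim=True)   # assumed norm axis
emb_grad = emb_grad / (emb_grad_norm + 1e-6)          # epsilon avoids NaN on zero gradients

embedding_2 = embedding_1 + adv_grad_factor * emb_grad
embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)   # clamp from below
embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)   # clamp from above
```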
+ import math import os from collections import namedtuple @@ -15,103 +17,6 @@ from modelscope.models.nlp.task_models.task_model import BaseTaskModel from modelscope.utils.config import Config, ConfigFields from modelscope.utils.constant import ModelFile, Tasks -__all__ = ['SbertForFaqQuestionAnswering'] - - -class SbertForFaqQuestionAnsweringBase(BaseTaskModel): - """base class for faq models - """ - - def __init__(self, model_dir, *args, **kwargs): - super(SbertForFaqQuestionAnsweringBase, - self).__init__(model_dir, *args, **kwargs) - - backbone_cfg = SbertConfig.from_pretrained(model_dir) - self.bert = SbertModel(backbone_cfg) - - model_config = Config.from_file( - os.path.join(model_dir, - ModelFile.CONFIGURATION)).get(ConfigFields.model, {}) - - metric = model_config.get('metric', 'cosine') - pooling_method = model_config.get('pooling', 'avg') - - Arg = namedtuple('args', [ - 'metrics', 'proj_hidden_size', 'hidden_size', 'dropout', 'pooling' - ]) - args = Arg( - metrics=metric, - proj_hidden_size=self.bert.config.hidden_size, - hidden_size=self.bert.config.hidden_size, - dropout=0.0, - pooling=pooling_method) - - self.metrics_layer = MetricsLayer(args) - self.pooling = PoolingLayer(args) - - def _get_onehot_labels(self, labels, support_size, num_cls): - labels_ = labels.view(support_size, 1) - target_oh = torch.zeros(support_size, num_cls).to(labels) - target_oh.scatter_(dim=1, index=labels_, value=1) - return target_oh.view(support_size, num_cls).float() - - def forward_sentence_embedding(self, inputs: Dict[str, Tensor]): - input_ids = inputs['input_ids'] - input_mask = inputs['attention_mask'] - if not isinstance(input_ids, Tensor): - input_ids = torch.IntTensor(input_ids) - if not isinstance(input_mask, Tensor): - input_mask = torch.IntTensor(input_mask) - rst = self.bert(input_ids, input_mask) - last_hidden_states = rst.last_hidden_state - if len(input_mask.shape) == 2: - input_mask = input_mask.unsqueeze(-1) - pooled_representation = self.pooling(last_hidden_states, input_mask) - return pooled_representation - - -@MODELS.register_module( - Tasks.faq_question_answering, module_name=Models.structbert) -class SbertForFaqQuestionAnswering(SbertForFaqQuestionAnsweringBase): - _backbone_prefix = '' - - def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: - assert not self.training - query = input['query'] - support = input['support'] - if isinstance(query, list): - query = torch.stack(query) - if isinstance(support, list): - support = torch.stack(support) - n_query = query.shape[0] - n_support = support.shape[0] - query_mask = torch.ne(query, 0).view([n_query, -1]) - support_mask = torch.ne(support, 0).view([n_support, -1]) - - support_labels = input['support_labels'] - num_cls = torch.max(support_labels) + 1 - onehot_labels = self._get_onehot_labels(support_labels, n_support, - num_cls) - - input_ids = torch.cat([query, support]) - input_mask = torch.cat([query_mask, support_mask], dim=0) - pooled_representation = self.forward_sentence_embedding({ - 'input_ids': - input_ids, - 'attention_mask': - input_mask - }) - z_query = pooled_representation[:n_query] - z_support = pooled_representation[n_query:] - cls_n_support = torch.sum(onehot_labels, dim=-2) + 1e-5 - protos = torch.matmul(onehot_labels.transpose(0, 1), - z_support) / cls_n_support.unsqueeze(-1) - scores = self.metrics_layer(z_query, protos).view([n_query, num_cls]) - if self.metrics_layer.name == 'relation': - scores = torch.sigmoid(scores) - return {'scores': scores} - - activations = { 'relu': F.relu, 'tanh': 
torch.tanh,
@@ -247,3 +152,142 @@ class PoolingLayer(nn.Module):
     def forward(self, x, mask):
         return self.pooling(x, mask)
+
+
+@MODELS.register_module(
+    Tasks.faq_question_answering, module_name=Models.structbert)
+class SbertForFaqQuestionAnswering(BaseTaskModel):
+    _backbone_prefix = ''
+
+    @classmethod
+    def _instantiate(cls, **kwargs):
+        model = cls(kwargs.get('model_dir'))
+        model.load_checkpoint(kwargs.get('model_dir'))
+        return model
+
+    def __init__(self, model_dir, *args, **kwargs):
+        super().__init__(model_dir, *args, **kwargs)
+
+        backbone_cfg = SbertConfig.from_pretrained(model_dir)
+        self.bert = SbertModel(backbone_cfg)
+
+        model_config = Config.from_file(
+            os.path.join(model_dir,
+                         ModelFile.CONFIGURATION)).get(ConfigFields.model, {})
+
+        metric = model_config.get('metric', 'cosine')
+        pooling_method = model_config.get('pooling', 'avg')
+
+        Arg = namedtuple('args', [
+            'metrics', 'proj_hidden_size', 'hidden_size', 'dropout', 'pooling'
+        ])
+        args = Arg(
+            metrics=metric,
+            proj_hidden_size=self.bert.config.hidden_size,
+            hidden_size=self.bert.config.hidden_size,
+            dropout=0.0,
+            pooling=pooling_method)
+
+        self.metrics_layer = MetricsLayer(args)
+        self.pooling = PoolingLayer(args)
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """
+        Args:
+            input (Dict[str, Tensor]): the preprocessed data, which contains the following keys:
+                query(:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+                    The query to be predicted.
+                support(:obj:`torch.LongTensor` of shape :obj:`(support_size, sequence_length)`):
+                    The support set.
+                support_labels(:obj:`torch.LongTensor` of shape :obj:`(support_size, )`):
+                    The labels of the support set.
+
+        Returns:
+            Dict[str, Tensor]: the result, which contains the following key:
+                scores(:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_cls)`):
+                    Predicted scores of all classes for each query.
+ Examples: + >>> from modelscope.hub.snapshot_download import snapshot_download + >>> from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor + >>> from modelscope.models.nlp import SbertForFaqQuestionAnswering + >>> cache_path = snapshot_download('damo/nlp_structbert_faq-question-answering_chinese-base') + >>> preprocessor = FaqQuestionAnsweringPreprocessor.from_pretrained(cache_path) + >>> model = SbertForFaqQuestionAnswering.from_pretrained(cache_path) + >>> param = { + >>> 'query_set': ['如何使用优惠券', '在哪里领券', '在哪里领券'], + >>> 'support_set': [{ + >>> 'text': '卖品代金券怎么用', + >>> 'label': '6527856' + >>> }, { + >>> 'text': '怎么使用优惠券', + >>> 'label': '6527856' + >>> }, { + >>> 'text': '这个可以一起领吗', + >>> 'label': '1000012000' + >>> }, { + >>> 'text': '付款时送的优惠券哪里领', + >>> 'label': '1000012000' + >>> }, { + >>> 'text': '购物等级怎么长', + >>> 'label': '13421097' + >>> }, { + >>> 'text': '购物等级二心', + >>> 'label': '13421097' + >>> }] + >>> } + >>> result = model(preprocessor(param)) + """ + assert not self.training + query = input['query'] + support = input['support'] + if isinstance(query, list): + query = torch.stack(query) + if isinstance(support, list): + support = torch.stack(support) + n_query = query.shape[0] + n_support = support.shape[0] + query_mask = torch.ne(query, 0).view([n_query, -1]) + support_mask = torch.ne(support, 0).view([n_support, -1]) + + support_labels = input['support_labels'] + num_cls = torch.max(support_labels) + 1 + onehot_labels = self._get_onehot_labels(support_labels, n_support, + num_cls) + + input_ids = torch.cat([query, support]) + input_mask = torch.cat([query_mask, support_mask], dim=0) + pooled_representation = self.forward_sentence_embedding({ + 'input_ids': + input_ids, + 'attention_mask': + input_mask + }) + z_query = pooled_representation[:n_query] + z_support = pooled_representation[n_query:] + cls_n_support = torch.sum(onehot_labels, dim=-2) + 1e-5 + protos = torch.matmul(onehot_labels.transpose(0, 1), + z_support) / cls_n_support.unsqueeze(-1) + scores = self.metrics_layer(z_query, protos).view([n_query, num_cls]) + if self.metrics_layer.name == 'relation': + scores = torch.sigmoid(scores) + return {'scores': scores} + + def _get_onehot_labels(self, labels, support_size, num_cls): + labels_ = labels.view(support_size, 1) + target_oh = torch.zeros(support_size, num_cls).to(labels) + target_oh.scatter_(dim=1, index=labels_, value=1) + return target_oh.view(support_size, num_cls).float() + + def forward_sentence_embedding(self, inputs: Dict[str, Tensor]): + input_ids = inputs['input_ids'] + input_mask = inputs['attention_mask'] + if not isinstance(input_ids, Tensor): + input_ids = torch.IntTensor(input_ids) + if not isinstance(input_mask, Tensor): + input_mask = torch.IntTensor(input_mask) + rst = self.bert(input_ids, input_mask) + last_hidden_states = rst.last_hidden_state + if len(input_mask.shape) == 2: + input_mask = input_mask.unsqueeze(-1) + pooled_representation = self.pooling(last_hidden_states, input_mask) + return pooled_representation diff --git a/modelscope/models/nlp/structbert/fill_mask.py b/modelscope/models/nlp/structbert/fill_mask.py new file mode 100644 index 00000000..e611aa88 --- /dev/null +++ b/modelscope/models/nlp/structbert/fill_mask.py @@ -0,0 +1,284 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionFillMaskModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .backbone import SbertModel, SbertPreTrainedModel +from .configuration import SbertConfig + +logger = logging.get_logger(__name__) + + +class SbertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class SbertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = SbertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class SbertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = SbertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class SbertPreTrainingHeads(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = SbertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) +class SbertForMaskedLM(SbertPreTrainedModel): + r"""StructBERT Model with a `language modeling` head on top. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. 
+ + Preprocessor: + This is the fill_mask model of StructBERT, the preprocessor of this model + is `modelscope.preprocessors.NLPPreprocessor`. + + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + """ + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config: SbertConfig, **kwargs): + super().__init__(config) + + if config.is_decoder: + logger.warning( + 'If you want to use `SbertForMaskedLM` make sure `config.is_decoder=False` for ' + 'bi-directional self-attention.') + + self.bert = SbertModel(config) + self.cls = SbertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + + Returns: + Returns `modelscope.outputs.AttentionFillMaskModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor, NLPPreprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_fill-mask_chinese-large') + >>> preprocessor = NLPPreprocessor('damo/nlp_structbert_fill-mask_chinese-large') + >>> # Call the model, return some tensors + >>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。'))) + >>> # Call the pipeline + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) + >>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。')) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:-1] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + return AttentionFillMaskModelOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + input_ids=input_ids, + ) + + def prepare_inputs_for_generation(self, + input_ids, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert self.config.pad_token_id is not None, 'The PAD token should be defined for generation' + attention_mask_zero = attention_mask.new_zeros( + (attention_mask.shape[0], 1)) + attention_mask = torch.cat([attention_mask, attention_mask_zero], + dim=-1) + dummy_token = torch.full((effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {'input_ids': input_ids, 'attention_mask': 
attention_mask} diff --git a/modelscope/models/nlp/structbert/modeling_sbert.py b/modelscope/models/nlp/structbert/modeling_sbert.py deleted file mode 100755 index e789037a..00000000 --- a/modelscope/models/nlp/structbert/modeling_sbert.py +++ /dev/null @@ -1,1963 +0,0 @@ -# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch SBERT model. mainly copied from :module:`~transformers.modeling_bert`""" - -import math -import warnings -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.utils.checkpoint -from packaging import version -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.activations import ACT2FN -from transformers.file_utils import (ModelOutput, add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings) -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, MaskedLMOutput, - MultipleChoiceModelOutput, NextSentencePredictorOutput, - QuestionAnsweringModelOutput, SequenceClassifierOutput, - TokenClassifierOutput) -from transformers.modeling_utils import (PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer) - -from modelscope.metainfo import Models -from modelscope.models.builder import BACKBONES -from modelscope.utils.constant import Fields -from modelscope.utils.logger import get_logger -from .adv_utils import compute_adv_loss, compute_adv_loss_pair -from .configuration_sbert import SbertConfig - -logger = get_logger(__name__) - -_CHECKPOINT_FOR_DOC = 'nlp_structbert_backbone_base_std' -_CONFIG_FOR_DOC = 'SbertConfig' -_TOKENIZER_FOR_DOC = 'SbertTokenizer' - - -class SbertEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - def __init__(self, config): - super().__init__() - self.word_embeddings = nn.Embedding( - config.vocab_size, - config.hidden_size, - padding_idx=config.pad_token_id) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, - config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, - config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, - 'position_embedding_type', - 'absolute') - 
self.register_buffer( - 'position_ids', - torch.arange(config.max_position_embeddings).expand((1, -1))) - if version.parse(torch.__version__) > version.parse('1.6.0'): - self.register_buffer( - 'token_type_ids', - torch.zeros( - self.position_ids.size(), - dtype=torch.long, - device=self.position_ids.device), - persistent=False, - ) - - def forward(self, - input_ids=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None, - past_key_values_length=0, - return_inputs_embeds=False): - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - - if position_ids is None: - position_ids = self.position_ids[:, - past_key_values_length:seq_length - + past_key_values_length] - - # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs - # when its auto-generated, registered buffer helps users - # when tracing the model without passing token_type_ids, solves - # issue #5664 - if token_type_ids is None: - if hasattr(self, 'token_type_ids'): - buffered_token_type_ids = self.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand( - input_shape[0], seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros( - input_shape, - dtype=torch.long, - device=self.position_ids.device) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == 'absolute': - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - if not return_inputs_embeds: - return embeddings - else: - return embeddings, inputs_embeds - - -class SbertSelfAttention(nn.Module): - - def __init__(self, config): - super().__init__() - if config.hidden_size % config.num_attention_heads != 0 and not hasattr( - config, 'embedding_size'): - raise ValueError( - f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' - f'heads ({config.num_attention_heads})') - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size - / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, - 'position_embedding_type', - 'absolute') - if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding( - 2 * config.max_position_embeddings - 1, - self.attention_head_size) - - self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, - self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - 
output_attentions=False, - ): - mixed_query_layer = self.query(hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores( - self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores( - self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. 
- attention_scores = torch.matmul(query_layer, - key_layer.transpose(-1, -2)) - - if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange( - seq_length, dtype=torch.long, - device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange( - seq_length, dtype=torch.long, - device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding( - distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to( - dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == 'relative_key': - relative_position_scores = torch.einsum( - 'bhld,lrd->bhlr', query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == 'relative_key_query': - relative_position_scores_query = torch.einsum( - 'bhld,lrd->bhlr', query_layer, positional_embedding) - relative_position_scores_key = torch.einsum( - 'bhrd,lrd->bhlr', key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt( - self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in SbertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + ( - self.all_head_size, ) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = (context_layer, - attention_probs) if output_attentions else (context_layer, ) - - if self.is_decoder: - outputs = outputs + (past_key_value, ) - return outputs - - -class SbertSelfOutput(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class SbertAttention(nn.Module): - - def __init__(self, config): - super().__init__() - self.self = SbertSelfAttention(config) - self.output = SbertSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.self.num_attention_heads, - self.self.attention_head_size, self.pruned_heads) - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len( - heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - self_outputs = self.self( - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output, - ) + self_outputs[1:] # add attentions if we output them - return outputs - - -class SbertIntermediate(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -class SbertOutput(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - 
hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class SbertLayer(nn.Module): - - def __init__(self, config): - super().__init__() - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.attention = SbertAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError( - f'{self} should be used as a decoder model if cross attention is added' - ) - self.crossattention = SbertAttention(config) - self.intermediate = SbertIntermediate(config) - self.output = SbertOutput(config) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[: - 2] if past_key_value is not None else None - self_attention_outputs = self.attention( - hidden_states, - attention_mask, - head_mask, - output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, - ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[ - 1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, 'crossattention'): - raise ValueError( - f'If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention ' - f'layers by setting `config.add_cross_attention=True`') - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[ - -2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[ - 1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - - layer_output = apply_chunking_to_forward(self.feed_forward_chunk, - self.chunk_size_feed_forward, - self.seq_len_dim, - attention_output) - outputs = (layer_output, ) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value, ) - - return outputs - - def feed_forward_chunk(self, attention_output): - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output - - -class SbertEncoder(nn.Module): - - def __init__(self, config): - super().__init__() - self.config = config - self.layer = nn.ModuleList( - [SbertLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - 
past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = ( - ) if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states, ) - - layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[ - i] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - if use_cache: - logger.warning( - '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' - ) - use_cache = False - - def create_custom_forward(module): - - def custom_forward(*inputs): - return module(*inputs, past_key_value, - output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1], ) - if output_attentions: - all_self_attentions = all_self_attentions + ( - layer_outputs[1], ) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + ( - layer_outputs[2], ) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states, ) - - if not return_dict: - return tuple(v for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - -class SbertPooler(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. 
- first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class SbertPredictionHeadTransform(nn.Module): - - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class SbertLMPredictionHead(nn.Module): - - def __init__(self, config): - super().__init__() - self.transform = SbertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size) - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class SbertOnlyMLMHead(nn.Module): - - def __init__(self, config): - super().__init__() - self.predictions = SbertLMPredictionHead(config) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class SbertOnlyNSPHead(nn.Module): - - def __init__(self, config): - super().__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output): - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -class SbertPreTrainingHeads(nn.Module): - - def __init__(self, config): - super().__init__() - self.predictions = SbertLMPredictionHead(config) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class SbertPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = SbertConfig - base_model_prefix = 'bert' - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_missing = [r'position_ids'] - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_( - mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_( - mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, SbertEncoder): - module.gradient_checkpointing = value - - -@dataclass -class SbertForPreTrainingOutput(ModelOutput): - """ - Output type of :class:`~transformers.BertForPreTraining`. 
- - Args: - loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): - Total loss as the sum of the masked language modeling loss and the next sequence prediction - (classification) loss. - prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` - is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` - is passed or when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - prediction_logits: torch.FloatTensor = None - seq_relationship_logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -SBERT_START_DOCSTRING = r""" - - This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic - methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, - pruning heads etc.) - - This model is also a PyTorch `torch.nn.Module `__ - subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to - general usage and behavior. - - Parameters: - config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with - all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model - weights. -""" - -SBERT_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for - details. - - `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): - Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): - Segment token indices to indicate first and second portions of the inputs. 
Indices are selected in ``[0, - 1]``: - - - 0 corresponds to a `sentence A` token, - - 1 corresponds to a `sentence B` token. - - `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, - config.max_position_embeddings - 1]``. - - `What are position IDs? <../glossary.html#position-ids>`_ - head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert :obj:`input_ids` indices into associated - vectors than the model's internal embedding lookup matrix. - output_attentions (:obj:`bool`, `optional`): - Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned - tensors for more detail. - output_hidden_states (:obj:`bool`, `optional`): - Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for - more detail. - return_dict (:obj:`bool`, `optional`): - Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. -""" - - -@dataclass -class BaseModelOutputWithPoolingAndCrossAttentionsWithEmbedding( - BaseModelOutputWithPoolingAndCrossAttentions): - embedding_output: torch.FloatTensor = None - logits: Optional[Union[tuple, torch.FloatTensor]] = None - kwargs: dict = None - - -@add_start_docstrings( - 'The Sbert Model transformer outputting raw hidden-states without any specific head on top.', - SBERT_START_DOCSTRING, -) -class SbertModel(SbertPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in `Attention is - all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, - Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - - To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration - set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` - argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an - input to the forward pass. - """ - - def __init__(self, config: SbertConfig, add_pooling_layer=True): - super().__init__(config) - self.config = config - - self.embeddings = SbertEmbeddings(config) - self.encoder = SbertEncoder(config) - - self.pooler = SbertPooler(config) if add_pooling_layer else None - - self.init_weights() - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - @add_start_docstrings_to_model_forward( - SBERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPoolingAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, - `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple - having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). 
- """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - 'You cannot specify both input_ids and inputs_embeds at the same time' - ) - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError( - 'You have to specify either input_ids or inputs_embeds') - - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[ - 2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones( - ((batch_size, seq_length + past_key_values_length)), - device=device) - - if token_type_ids is None: - if hasattr(self.embeddings, 'token_type_ids'): - buffered_token_type_ids = self.embeddings.token_type_ids[:, : - seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand( - batch_size, seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros( - input_shape, dtype=torch.long, device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( - attention_mask, input_shape, device) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( - ) - encoder_hidden_shape = (encoder_batch_size, - encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask( - encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, - self.config.num_hidden_layers) - - embedding_output, orignal_embeds = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - return_inputs_embeds=True, - ) - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler( - sequence_output) if self.pooler is not None else None - - if not return_dict: - return (sequence_output, - pooled_output) + encoder_outputs[1:] + (orignal_embeds, ) - - return BaseModelOutputWithPoolingAndCrossAttentionsWithEmbedding( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - embedding_output=orignal_embeds) - - -@add_start_docstrings( - """ - Sbert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next - sentence prediction (classification)` head. 
- """, - SBERT_START_DOCSTRING, -) -class SbertForPreTraining(SbertPreTrainedModel): - - def __init__(self, config: SbertConfig): - super().__init__(config) - - self.bert = SbertModel(config) - self.cls = SbertPreTrainingHeads(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward( - SBERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @replace_return_docstrings( - output_type=SbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - next_sentence_label=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): - Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., - config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored - (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair - (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: - - - 0 indicates sequence B is a continuation of sequence A, - - 1 indicates sequence B is a random sequence. - kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): - Used to hide legacy arguments that have been deprecated. - - Returns: - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls( - sequence_output, pooled_output) - - total_loss = None - if labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1)) - next_sentence_loss = loss_fct( - seq_relationship_score.view(-1, 2), - next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - - if not return_dict: - output = (prediction_scores, - seq_relationship_score) + outputs[2:-1] - return ((total_loss, ) - + output) if total_loss is not None else output - - return SbertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """Sbert Model with a `language modeling` head on top for CLM fine-tuning. 
""", - SBERT_START_DOCSTRING) -class SbertLMHeadModel(SbertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r'pooler'] - _keys_to_ignore_on_load_missing = [ - r'position_ids', r'predictions.decoder.bias' - ] - - def __init__(self, config: SbertConfig): - super().__init__(config) - - if not config.is_decoder: - logger.warning( - 'If you want to use `SbertLMHeadModel` as a standalone, add `is_decoder=True.`' - ) - - self.bert = SbertModel(config, add_pooling_layer=False) - self.cls = SbertOnlyMLMHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward( - SBERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @replace_return_docstrings( - output_type=CausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, - `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are - ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` - with each tuple having 4 tensors of - shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). 
- - Returns: - - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, : - -1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct( - shifted_prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1)) - if not return_dict: - output = (prediction_scores, ) + outputs[2:-1] - return ((lm_loss, ) + output) if lm_loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def prepare_inputs_for_generation(self, - input_ids, - past=None, - attention_mask=None, - **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # cut decoder_input_ids if past is used - if past is not None: - input_ids = input_ids[:, -1:] - - return { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'past_key_values': past - } - - def _reorder_cache(self, past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += (tuple( - past_state.index_select(0, beam_idx) - for past_state in layer_past), ) - return reordered_past - - -@add_start_docstrings( - """Sbert Model with a `language modeling` head on top. 
""", - SBERT_START_DOCSTRING) -class SbertForMaskedLM(SbertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r'pooler'] - _keys_to_ignore_on_load_missing = [ - r'position_ids', r'predictions.decoder.bias' - ] - - def __init__(self, config: SbertConfig): - super().__init__(config) - - if config.is_decoder: - logger.warning( - 'If you want to use `SbertForMaskedLM` make sure `config.is_decoder=False` for ' - 'bi-directional self-attention.') - - self.bert = SbertModel(config) - self.cls = SbertOnlyMLMHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - @add_start_docstrings_to_model_forward( - SBERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=MaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., - config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored - (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - masked_lm_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1)) - - if not return_dict: - output = (prediction_scores, ) + outputs[2:-1] - return ((masked_lm_loss, ) - + output) if masked_lm_loss is not None else output - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation(self, - input_ids, - attention_mask=None, - **model_kwargs): - input_shape = input_ids.shape - effective_batch_size = input_shape[0] - - # add a dummy token - assert self.config.pad_token_id is not None, 'The PAD token should be defined for generation' - attention_mask_zero = attention_mask.new_zeros( - (attention_mask.shape[0], 1)) - attention_mask = torch.cat([attention_mask, attention_mask_zero], - dim=-1) - dummy_token = torch.full((effective_batch_size, 1), - self.config.pad_token_id, - dtype=torch.long, - device=input_ids.device) - input_ids = torch.cat([input_ids, dummy_token], dim=1) - - return {'input_ids': input_ids, 'attention_mask': attention_mask} - - -@add_start_docstrings( - """Sbert Model 
with a `next sentence prediction (classification)` head on top. """, - SBERT_START_DOCSTRING, -) -class SbertForNextSentencePrediction(SbertPreTrainedModel): - - def __init__(self, config: SbertConfig): - super().__init__(config) - - self.bert = SbertModel(config) - self.cls = SbertOnlyNSPHead(config) - - self.init_weights() - - @add_start_docstrings_to_model_forward( - SBERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @replace_return_docstrings( - output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair - (see ``input_ids`` docstring). Indices should be in ``[0, 1]``: - - - 0 indicates sequence B is a continuation of sequence A, - - 1 indicates sequence B is a random sequence. - - Returns: - - """ - - if 'next_sentence_label' in kwargs: - warnings.warn( - 'The `next_sentence_label` argument is deprecated and will be removed ' - 'in a future version, use `labels` instead.', - FutureWarning, - ) - labels = kwargs.pop('next_sentence_label') - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - - seq_relationship_scores = self.cls(pooled_output) - - next_sentence_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - next_sentence_loss = loss_fct( - seq_relationship_scores.view(-1, 2), labels.view(-1)) - - if not return_dict: - output = (seq_relationship_scores, ) + outputs[2:-1] - return ((next_sentence_loss, ) - + output) if next_sentence_loss is not None else output - - return NextSentencePredictorOutput( - loss=next_sentence_loss, - logits=seq_relationship_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Sbert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled - output) e.g. for GLUE tasks. 
- """, - SBERT_START_DOCSTRING, -) -class SbertForSequenceClassification(SbertPreTrainedModel): - - def __init__(self, config: SbertConfig): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - if self.config.adv_grad_factor is None: - logger.warning( - 'Adv parameters not set, skipping compute_adv_loss.') - self.bert = SbertModel(config) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None - else config.hidden_dropout_prob) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self.init_weights() - - def _forward_call(self, **kwargs): - outputs = self.bert(**kwargs) - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - outputs['logits'] = logits - outputs.kwargs = kwargs - return outputs - - @add_start_docstrings_to_model_forward( - SBERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=SequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., - config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), - If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if not return_dict: - logger.error('Return tuple in sbert is not supported now.') - outputs = self._forward_call( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict) - return self.compute_loss(outputs, labels, **outputs.kwargs) - - def compute_loss(self, outputs, labels, **kwargs): - logits = outputs.logits - embedding_output = outputs.embedding_output - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = 'regression' - elif self.num_labels > 1 and (labels.dtype == torch.long - or labels.dtype == torch.int): - self.config.problem_type = 'single_label_classification' - else: - self.config.problem_type = 'multi_label_classification' - - if self.config.problem_type == 'regression': - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == 'single_label_classification': - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.num_labels), labels.view(-1)) - if self.config.adv_grad_factor is not None and self.training: - loss = compute_adv_loss( - embedding=embedding_output, - model=self._forward_call, - ori_logits=logits, - ori_loss=loss, - adv_bound=self.config.adv_bound, - adv_grad_factor=self.config.adv_grad_factor, - sigma=self.config.sigma, - **kwargs) - elif self.config.problem_type == 'multi_label_classification': - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Sbert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. 
- """, - SBERT_START_DOCSTRING, -) -class SbertForMultipleChoice(SbertPreTrainedModel): - - def __init__(self, config: SbertConfig): - super().__init__(config) - self.config = config - if self.config.adv_grad_factor is None: - logger.warning( - 'Adv parameters not set, skipping compute_adv_loss.') - self.bert = SbertModel(config) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None - else config.hidden_dropout_prob) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, 1) - - self.init_weights() - - def _forward_call(self, num_choices, **kwargs): - outputs = self.bert(**kwargs) - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - outputs['logits'] = logits.view(-1, num_choices) - kwargs['num_choices'] = num_choices - outputs.kwargs = kwargs - return outputs - - @add_start_docstrings_to_model_forward( - SBERT_INPUTS_DOCSTRING.format( - 'batch_size, num_choices, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=MultipleChoiceModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., - num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See - :obj:`input_ids` above) - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if not return_dict: - logger.error('Return tuple in sbert is not supported now.') - - num_choices = input_ids.shape[ - 1] if input_ids is not None else inputs_embeds.shape[1] - - input_ids = input_ids.view( - -1, input_ids.size(-1)) if input_ids is not None else None - attention_mask = attention_mask.view( - -1, - attention_mask.size(-1)) if attention_mask is not None else None - token_type_ids = token_type_ids.view( - -1, - token_type_ids.size(-1)) if token_type_ids is not None else None - position_ids = position_ids.view( - -1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = ( - inputs_embeds.view(-1, inputs_embeds.size(-2), - inputs_embeds.size(-1)) - if inputs_embeds is not None else None) - - outputs = self._forward_call( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - num_choices=num_choices) - - reshaped_logits = outputs.logits - kwargs = outputs.kwargs - embedding_output = outputs.embedding_output - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - if self.config.adv_grad_factor is not None and self.training: - loss = compute_adv_loss( - embedding=embedding_output, - model=self._forward_call, - ori_logits=reshaped_logits, - ori_loss=loss, - adv_bound=self.config.adv_bound, - adv_grad_factor=self.config.adv_grad_factor, - sigma=self.config.sigma, - **kwargs) - - return MultipleChoiceModelOutput( - loss=loss, - 
logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Sbert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - SBERT_START_DOCSTRING, -) -class SbertForTokenClassification(SbertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r'pooler'] - - def __init__(self, config: SbertConfig): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - if self.config.adv_grad_factor is None: - logger.warning( - 'Adv parameters not set, skipping compute_adv_loss.') - self.bert = SbertModel(config, add_pooling_layer=False) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None - else config.hidden_dropout_prob) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - self.init_weights() - - def _forward_call(self, **kwargs): - outputs = self.bert(**kwargs) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - outputs['logits'] = logits - outputs.kwargs = kwargs - return outputs - - @add_start_docstrings_to_model_forward( - SBERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - - 1]``. 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if not return_dict: - logger.error('Return tuple in sbert is not supported now.') - - outputs = self._forward_call( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict) - - logits = outputs.logits - embedding_output = outputs.embedding_output - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels) - active_labels = torch.where( - active_loss, labels.view(-1), - torch.tensor(loss_fct.ignore_index).type_as(labels)) - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct( - logits.view(-1, self.num_labels), labels.view(-1)) - if self.config.adv_grad_factor is not None and self.training: - loss = compute_adv_loss( - embedding=embedding_output, - model=self._forward_call, - ori_logits=logits, - ori_loss=loss, - adv_bound=self.config.adv_bound, - adv_grad_factor=self.config.adv_grad_factor, - sigma=self.config.sigma, - with_attention_mask=attention_mask is not None, - **outputs.kwargs) - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - Sbert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - SBERT_START_DOCSTRING, -) -class SbertForQuestionAnswering(SbertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r'pooler'] - - def __init__(self, config: SbertConfig): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - if self.config.adv_grad_factor is None: - logger.warning( - 'Adv parameters not set, skipping compute_adv_loss.') - self.bert = SbertModel(config, add_pooling_layer=False) - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - - self.init_weights() - - def _forward_call(self, **kwargs): - outputs = self.bert(**kwargs) - sequence_output = outputs[0] - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - outputs['logits'] = (start_logits, end_logits) - outputs.kwargs = kwargs - return outputs - - @add_start_docstrings_to_model_forward( - SBERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=QuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - start_positions=None, - end_positions=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for position (index) of the start of the labelled span for computing the token classification loss. 
- Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the - sequence are not taken into account for computing the loss. - end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the - sequence are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if not return_dict: - logger.error('Return tuple in sbert is not supported now.') - - outputs = self._forward_call( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict) - return self.compute_loss(outputs, start_positions, end_positions, - **outputs.kwargs) - - def compute_loss(self, - outputs, - start_positions=None, - end_positions=None, - **kwargs): - start_logits, end_logits = outputs.logits - embedding_output = outputs.embedding_output - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - if self.config.adv_grad_factor is not None and self.training: - total_loss = compute_adv_loss_pair( - embedding=embedding_output, - model=self._forward_call, - start_logits=start_logits, - end_logits=end_logits, - ori_loss=total_loss, - adv_bound=self.config.adv_bound, - adv_grad_factor=self.config.adv_grad_factor, - sigma=self.config.sigma, - **kwargs) - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/modelscope/models/nlp/structbert/text_classification.py b/modelscope/models/nlp/structbert/text_classification.py new file mode 100644 index 00000000..044cf8d0 --- /dev/null +++ b/modelscope/models/nlp/structbert/text_classification.py @@ -0,0 +1,235 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionTextClassificationModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .adv_utils import compute_adv_loss +from .backbone import SbertModel, SbertPreTrainedModel +from .configuration import SbertConfig + +logger = logging.get_logger(__name__) + + +@MODELS.register_module( + Tasks.text_classification, module_name=Models.structbert) +@MODELS.register_module(Tasks.nli, module_name=Models.structbert) +@MODELS.register_module( + Tasks.sentiment_classification, module_name=Models.structbert) +@MODELS.register_module( + Tasks.sentence_similarity, module_name=Models.structbert) +@MODELS.register_module( + Tasks.zero_shot_classification, module_name=Models.structbert) +class SbertForSequenceClassification(SbertPreTrainedModel): + r"""StructBERT Model transformer with a sequence classification/regression head on top + (a linear layer on top of the pooled output) e.g. for GLUE tasks. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the text classification model of StructBERT, the preprocessor of this model + is `modelscope.preprocessors.SequenceClassificationPreprocessor`. + + Trainer: + This model is a normal PyTorch model, and can be trained by variable trainers, like EpochBasedTrainer, + NlpEpochBasedTrainer, or trainers from other frameworks. + The preferred trainer in ModelScope is NlpEpochBasedTrainer. + + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. 
+ """ + + def __init__(self, config: SbertConfig, **kwargs): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + if self.config.adv_grad_factor is None: + logger.warning( + 'Adv parameters not set, skipping compute_adv_loss.') + + SbertForSequenceClassification.base_model_prefix = getattr( + config, 'base_model_prefix', + SbertForSequenceClassification.base_model_prefix) + setattr(self, self.base_model_prefix, SbertModel(config)) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None + else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.init_weights() + + def _forward_call(self, **kwargs): + outputs = self.base_model(**kwargs) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + outputs['logits'] = logits + outputs.kwargs = kwargs + return outputs + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. 
+ return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + Returns `modelscope.outputs.AttentionTextClassificationModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') + >>> # Call the model, return some tensors + >>> print(model(**preprocessor(('这是个测试', '这也是个测试')))) + >>> # Call the pipeline + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('text-classification', model=model, preprocessor=preprocessor) + >>> print(pipeline_ins(('这是个测试', '这也是个测试'))) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if not return_dict: + logger.error('Return tuple in sbert is not supported now.') + outputs = self._forward_call( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + return self.compute_loss(outputs, labels, **outputs.kwargs) + + def compute_loss(self, outputs, labels, **kwargs): + logits = outputs.logits + embedding_output = outputs.embedding_output + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long + or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + if self.config.adv_grad_factor is not None and self.training: + loss = compute_adv_loss( + embedding=embedding_output, + model=self._forward_call, + ori_logits=logits, + ori_loss=loss, + adv_bound=self.config.adv_bound, + adv_grad_factor=self.config.adv_grad_factor, + sigma=self.config.sigma, + **kwargs) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return AttentionTextClassificationModelOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/modelscope/models/nlp/structbert/token_classification.py b/modelscope/models/nlp/structbert/token_classification.py new file mode 100644 index 00000000..a040ff3e --- /dev/null +++ b/modelscope/models/nlp/structbert/token_classification.py @@ -0,0 +1,229 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. 
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import TokenClassifierOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .adv_utils import compute_adv_loss +from .backbone import SbertModel, SbertPreTrainedModel +from .configuration import SbertConfig + +logger = logging.get_logger(__name__) + + +@MODELS.register_module( + Tasks.token_classification, module_name=Models.structbert) +@MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) +@MODELS.register_module(Tasks.part_of_speech, module_name=Models.structbert) +class SbertForTokenClassification(SbertPreTrainedModel): + r"""StructBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) + e.g. for Named-Entity-Recognition (NER) tasks. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the token-classification model of StructBERT, the preprocessor of this model + is `modelscope.preprocessors.TokenClassificationPreprocessor`. + + Trainer: + This model is a normal PyTorch model, and can be trained by variable trainers, like EpochBasedTrainer, + NlpEpochBasedTrainer, or trainers from other frameworks. + The preferred trainer in modelscope is NlpEpochBasedTrainer. + + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. 
+ """ + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + + def __init__(self, config: SbertConfig, **kwargs): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + if self.config.adv_grad_factor is None: + logger.warning( + 'Adv parameters not set, skipping compute_adv_loss.') + setattr(self, self.base_model_prefix, + SbertModel(config, add_pooling_layer=False)) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None + else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def _forward_call(self, **kwargs): + outputs = self.bert(**kwargs) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + outputs['logits'] = logits + outputs.kwargs = kwargs + return outputs + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + offset_mapping=None, + label_mask=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. 
+ return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the sentence. + Selected in the range ``[0, sequence_length - 1]``. + label_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask + values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + Returns `modelscope.outputs.TokenClassifierOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') + >>> print(model(**preprocessor(('This is a test', 'This is also a test')))) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if not return_dict: + logger.error('Return tuple in sbert is not supported now.') + + outputs = self._forward_call( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + logits = outputs.logits + embedding_output = outputs.embedding_output + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), + torch.tensor(loss_fct.ignore_index).type_as(labels)) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + if self.config.adv_grad_factor is not None and self.training: + loss = compute_adv_loss( + embedding=embedding_output, + model=self._forward_call, + ori_logits=logits, + ori_loss=loss, + adv_bound=self.config.adv_bound, + adv_grad_factor=self.config.adv_grad_factor, + sigma=self.config.sigma, + with_attention_mask=attention_mask is not None, + **outputs.kwargs) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + offset_mapping=offset_mapping, + ) diff --git a/modelscope/models/nlp/structbert/tokenization_sbert.py b/modelscope/models/nlp/structbert/tokenization.py similarity index 100% rename from modelscope/models/nlp/structbert/tokenization_sbert.py rename to modelscope/models/nlp/structbert/tokenization.py diff --git a/modelscope/models/nlp/structbert/tokenization_sbert_fast.py b/modelscope/models/nlp/structbert/tokenization_fast.py similarity index 99% rename from modelscope/models/nlp/structbert/tokenization_sbert_fast.py rename to modelscope/models/nlp/structbert/tokenization_fast.py index a0a81121..6f7b7ba7 100644 --- a/modelscope/models/nlp/structbert/tokenization_sbert_fast.py +++ 
b/modelscope/models/nlp/structbert/tokenization_fast.py @@ -24,7 +24,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast from modelscope.utils.constant import ModelFile from modelscope.utils.logger import get_logger -from .tokenization_sbert import SbertTokenizer +from .tokenization import SbertTokenizer logger = get_logger(__name__) diff --git a/modelscope/models/nlp/task_models/__init__.py b/modelscope/models/nlp/task_models/__init__.py index 90f22aa1..b8722a36 100644 --- a/modelscope/models/nlp/task_models/__init__.py +++ b/modelscope/models/nlp/task_models/__init__.py @@ -7,18 +7,34 @@ if TYPE_CHECKING: from .information_extraction import InformationExtractionModel from .feature_extraction import FeatureExtractionModel from .fill_mask import FillMaskModel + from .nncrf_for_named_entity_recognition import ( + LSTMCRFForNamedEntityRecognition, + TransformerCRFForNamedEntityRecognition, + ) + from .nncrf_for_word_segmentation import ( + LSTMCRFForWordSegmentation, + TransformerCRFForWordSegmentation, + ) from .sequence_classification import SequenceClassificationModel from .task_model import SingleBackboneTaskModelBase from .token_classification import TokenClassificationModel + from .text_generation import TaskModelForTextGeneration else: _import_structure = { 'information_extraction': ['InformationExtractionModel'], 'feature_extraction': ['FeatureExtractionModel'], 'fill_mask': ['FillMaskModel'], + 'nncrf_for_named_entity_recognition': [ + 'TransformerCRFForNamedEntityRecognition', + 'LSTMCRFForNamedEntityRecognition' + ], + 'nncrf_for_word_segmentation': + ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'], 'sequence_classification': ['SequenceClassificationModel'], 'task_model': ['SingleBackboneTaskModelBase'], 'token_classification': ['TokenClassificationModel'], + 'text_generation': ['TaskModelForTextGeneration'], } import sys diff --git a/modelscope/models/nlp/task_models/feature_extraction.py b/modelscope/models/nlp/task_models/feature_extraction.py index 069c37aa..9360ec08 100644 --- a/modelscope/models/nlp/task_models/feature_extraction.py +++ b/modelscope/models/nlp/task_models/feature_extraction.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict import numpy as np @@ -31,13 +32,8 @@ class FeatureExtractionModel(SingleBackboneTaskModelBase): self.build_backbone(self.backbone_cfg) def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: - # backbone do not need labels, only head need for loss compute - labels = input.pop(OutputKeys.LABELS, None) - + input.pop(OutputKeys.LABELS, None) outputs = super().forward(input) - sequence_output, pooled_output = self.extract_backbone_outputs(outputs) - if labels is not None: - input[OutputKeys.LABELS] = labels - + sequence_output = outputs.last_hidden_state return {OutputKeys.TEXT_EMBEDDING: sequence_output} diff --git a/modelscope/models/nlp/task_models/fill_mask.py b/modelscope/models/nlp/task_models/fill_mask.py index f7ef1cc2..0f7d3345 100644 --- a/modelscope/models/nlp/task_models/fill_mask.py +++ b/modelscope/models/nlp/task_models/fill_mask.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Any, Dict import numpy as np @@ -36,7 +37,7 @@ class FillMaskModel(SingleBackboneTaskModelBase): labels = input.pop(OutputKeys.LABELS, None) outputs = super().forward(input) - sequence_output, pooled_output = self.extract_backbone_outputs(outputs) + sequence_output = outputs.last_hidden_state outputs = self.head.forward(sequence_output) if labels is not None: diff --git a/modelscope/models/nlp/task_models/information_extraction.py b/modelscope/models/nlp/task_models/information_extraction.py index 0a7d5a47..ce0e21a3 100644 --- a/modelscope/models/nlp/task_models/information_extraction.py +++ b/modelscope/models/nlp/task_models/information_extraction.py @@ -16,6 +16,8 @@ __all__ = ['InformationExtractionModel'] @MODELS.register_module( Tasks.information_extraction, module_name=TaskModels.information_extraction) +@MODELS.register_module( + Tasks.relation_extraction, module_name=TaskModels.information_extraction) class InformationExtractionModel(SingleBackboneTaskModelBase): def __init__(self, model_dir: str, *args, **kwargs): @@ -31,7 +33,7 @@ class InformationExtractionModel(SingleBackboneTaskModelBase): def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: outputs = super().forward(input) - sequence_output, pooled_output = self.extract_backbone_outputs(outputs) + sequence_output = outputs.last_hidden_state outputs = self.head.forward(sequence_output, input['text'], input['offsets']) return {OutputKeys.SPO_LIST: outputs} diff --git a/modelscope/models/nlp/task_models/nncrf_for_named_entity_recognition.py b/modelscope/models/nlp/task_models/nncrf_for_named_entity_recognition.py new file mode 100644 index 00000000..017e35e5 --- /dev/null +++ b/modelscope/models/nlp/task_models/nncrf_for_named_entity_recognition.py @@ -0,0 +1,727 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. +# The CRF implementation borrows mostly from AllenNLP CRF module (https://github.com/allenai/allennlp) +# and pytorch-crf (https://github.com/kmkurn/pytorch-crf) with some modifications. 
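The header of this new module credits AllenNLP and pytorch-crf. As a rough sketch of the emissions-plus-CRF pattern the classes below build on, the upstream pytorch-crf package (pip install pytorch-crf) can be used as follows; all names and tensor sizes here are illustrative, and the vendored, modified CRF defined later in this file does not expose exactly the same interface (its decode output is handled as a tensor rather than the upstream list of lists).

import torch
import torch.nn as nn
from torchcrf import CRF  # upstream pytorch-crf package

num_tags, hidden_size = 5, 16
encoder_output = torch.randn(2, 7, hidden_size)               # (batch, seq_len, hidden) from any token encoder
emissions = nn.Linear(hidden_size, num_tags)(encoder_output)  # per-token tag scores
tags = torch.randint(0, num_tags, (2, 7))
mask = torch.ones(2, 7, dtype=torch.uint8)

crf = CRF(num_tags, batch_first=True)
loss = -crf(emissions, tags, mask=mask)        # negative log-likelihood, used as the training loss
best_paths = crf.decode(emissions, mask=mask)  # Viterbi decoding: a list of tag-id sequences
print(loss.item(), best_paths)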
+ +import os +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoModel + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import TokenClassifierWithPredictionsOutput +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = [ + 'TransformerCRFForNamedEntityRecognition', + 'LSTMCRFForNamedEntityRecognition' +] + + +class SequenceLabelingForNamedEntityRecognition(TorchModel): + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.model = self.init_model(model_dir, *args, **kwargs) + + model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + self.model.load_state_dict( + torch.load(model_ckpt, map_location=torch.device('cpu'))) + + def init_model(self, model_dir, *args, **kwargs): + raise NotImplementedError + + def train(self): + return self.model.train() + + def eval(self): + return self.model.eval() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + offset_mapping=None, + label_mask=None, + ) -> Dict[str, Any]: + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. 
+ output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the sentence. + Selected in the range ``[0, sequence_length - 1]``. + label_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask + values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + Returns `modelscope.outputs.TokenClassifierOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') + >>> print(model(**preprocessor(('This is a test', 'This is also a test')))) + """ + input_tensor = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'label_mask': label_mask, + } + output = { + 'offset_mapping': offset_mapping, + **input_tensor, + **self.model(input_tensor) + } + return output + + def postprocess(self, input: Dict[str, Any], **kwargs): + predicts = self.model.decode(input) + offset_len = len(input['offset_mapping']) + predictions = torch.narrow( + predicts, 1, 0, + offset_len) # index_select only move loc, not resize + return TokenClassifierWithPredictionsOutput( + loss=None, + logits=None, + hidden_states=None, + attentions=None, + offset_mapping=input['offset_mapping'], + predictions=predictions, + ) + + +@MODELS.register_module( + Tasks.named_entity_recognition, module_name=Models.tcrf) +class TransformerCRFForNamedEntityRecognition( + SequenceLabelingForNamedEntityRecognition): + """This model wraps the TransformerCRF model to register into model sets. + """ + + def init_model(self, model_dir, *args, **kwargs): + self.config = AutoConfig.from_pretrained(model_dir) + num_labels = self.config.num_labels + + model = TransformerCRF(model_dir, num_labels) + return model + + +@MODELS.register_module( + Tasks.named_entity_recognition, module_name=Models.lcrf) +class LSTMCRFForNamedEntityRecognition( + SequenceLabelingForNamedEntityRecognition): + """This model wraps the LSTMCRF model to register into model sets. + """ + + def init_model(self, model_dir, *args, **kwargs): + self.config = AutoConfig.from_pretrained(model_dir) + vocab_size = self.config.vocab_size + embed_width = self.config.embed_width + num_labels = self.config.num_labels + lstm_hidden_size = self.config.lstm_hidden_size + + model = LSTMCRF(vocab_size, embed_width, num_labels, lstm_hidden_size) + return model + + +class TransformerCRF(nn.Module): + """A transformer based model to NER tasks. + + This model will use transformers' backbones as its backbone. 
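
The `torch.narrow` call in `postprocess` above trims predictions that were padded to the batch's maximum length back down to the `offset_mapping` length, returning a view rather than a copy — which is what the "only move loc, not resize" comment is getting at. A small illustration with toy tensors:

```python
import torch

predicts = torch.tensor([[3, 1, 4, 1, 5, 0, 0, 0]])   # decoded tags, padded to length 8
offset_len = 5                                         # number of real tokens per offset_mapping
predictions = torch.narrow(predicts, 1, 0, offset_len)
print(predictions)                                     # tensor([[3, 1, 4, 1, 5]])
print(predictions.data_ptr() == predicts.data_ptr())   # True: narrow returns a view, no copy
```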
+ """ + + def __init__(self, model_dir, num_labels, **kwargs): + super(TransformerCRF, self).__init__() + + self.encoder = AutoModel.from_pretrained(model_dir) + self.linear = nn.Linear(self.encoder.config.hidden_size, num_labels) + self.crf = CRF(num_labels, batch_first=True) + + def forward(self, inputs): + embed = self.encoder( + inputs['input_ids'], attention_mask=inputs['attention_mask'])[0] + logits = self.linear(embed) + + if 'label_mask' in inputs: + mask = inputs['label_mask'] + masked_lengths = mask.sum(-1).long() + masked_logits = torch.zeros_like(logits) + for i in range(len(mask)): + masked_logits[ + i, :masked_lengths[i], :] = logits[i].masked_select( + mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) + logits = masked_logits + + outputs = {'logits': logits} + return outputs + + def decode(self, inputs): + seq_lens = inputs['label_mask'].sum(-1).long() + mask = torch.arange( + inputs['label_mask'].shape[1], + device=seq_lens.device)[None, :] < seq_lens[:, None] + predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) + return predicts + + +class LSTMCRF(nn.Module): + """ + A standard bilstm-crf model for fast prediction. + """ + + def __init__(self, + vocab_size, + embed_width, + num_labels, + lstm_hidden_size=100, + **kwargs): + super(LSTMCRF, self).__init__() + self.embedding = Embedding(vocab_size, embed_width) + self.lstm = nn.LSTM( + embed_width, + lstm_hidden_size, + num_layers=1, + bidirectional=True, + batch_first=True) + self.ffn = nn.Linear(lstm_hidden_size * 2, num_labels) + self.crf = CRF(num_labels, batch_first=True) + + def forward(self, inputs): + embedding = self.embedding(inputs['input_ids']) + lstm_output, _ = self.lstm(embedding) + logits = self.ffn(lstm_output) + + if 'label_mask' in inputs: + mask = inputs['label_mask'] + masked_lengths = mask.sum(-1).long() + masked_logits = torch.zeros_like(logits) + for i in range(len(mask)): + masked_logits[ + i, :masked_lengths[i], :] = logits[i].masked_select( + mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) + logits = masked_logits + + outputs = {'logits': logits} + return outputs + + def decode(self, inputs): + seq_lens = inputs['label_mask'].sum(-1).long() + mask = torch.arange( + inputs['label_mask'].shape[1], + device=seq_lens.device)[None, :] < seq_lens[:, None] + predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) + return predicts + + +class CRF(nn.Module): + """Conditional random field. + This module implements a conditional random field [LMP01]_. The forward computation + of this class computes the log likelihood of the given sequence of tags and + emission score tensor. This class also has `~CRF.decode` method which finds + the best tag sequence given an emission score tensor using `Viterbi algorithm`_. + Args: + num_tags: Number of tags. + batch_first: Whether the first dimension corresponds to the size of a minibatch. + Attributes: + start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size + ``(num_tags,)``. + end_transitions (`~torch.nn.Parameter`): End transition score tensor of size + ``(num_tags,)``. + transitions (`~torch.nn.Parameter`): Transition score tensor of size + ``(num_tags, num_tags)``. + .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). + "Conditional random fields: Probabilistic models for segmenting and + labeling sequence data". *Proc. 18th International Conf. on Machine + Learning*. Morgan Kaufmann. pp. 282–289. + .. 
_Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm + + """ + + def __init__(self, num_tags: int, batch_first: bool = False) -> None: + if num_tags <= 0: + raise ValueError(f'invalid number of tags: {num_tags}') + super().__init__() + self.num_tags = num_tags + self.batch_first = batch_first + self.start_transitions = nn.Parameter(torch.empty(num_tags)) + self.end_transitions = nn.Parameter(torch.empty(num_tags)) + self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Initialize the transition parameters. + The parameters will be initialized randomly from a uniform distribution + between -0.1 and 0.1. + """ + nn.init.uniform_(self.start_transitions, -0.1, 0.1) + nn.init.uniform_(self.end_transitions, -0.1, 0.1) + nn.init.uniform_(self.transitions, -0.1, 0.1) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(num_tags={self.num_tags})' + + def forward(self, + emissions: torch.Tensor, + tags: torch.LongTensor, + mask: Optional[torch.ByteTensor] = None, + reduction: str = 'mean') -> torch.Tensor: + """Compute the conditional log likelihood of a sequence of tags given emission scores. + Args: + emissions (`~torch.Tensor`): Emission score tensor of size + ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length, num_tags)`` otherwise. + tags (`~torch.LongTensor`): Sequence of tags tensor of size + ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length)`` otherwise. + mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + reduction: Specifies the reduction to apply to the output: + ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. + ``sum``: the output will be summed over batches. ``mean``: the output will be + averaged over batches. ``token_mean``: the output will be averaged over tokens. + Returns: + `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if + reduction is ``none``, ``()`` otherwise. + """ + if reduction not in ('none', 'sum', 'mean', 'token_mean'): + raise ValueError(f'invalid reduction: {reduction}') + if mask is None: + mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device) + if mask.dtype != torch.uint8: + mask = mask.byte() + self._validate(emissions, tags=tags, mask=mask) + + if self.batch_first: + emissions = emissions.transpose(0, 1) + tags = tags.transpose(0, 1) + mask = mask.transpose(0, 1) + + # shape: (batch_size,) + numerator = self._compute_score(emissions, tags, mask) + # shape: (batch_size,) + denominator = self._compute_normalizer(emissions, mask) + # shape: (batch_size,) + llh = numerator - denominator + + if reduction == 'none': + return llh + if reduction == 'sum': + return llh.sum() + if reduction == 'mean': + return llh.mean() + return llh.sum() / mask.float().sum() + + def decode(self, + emissions: torch.Tensor, + mask: Optional[torch.ByteTensor] = None, + nbest: Optional[int] = None, + pad_tag: Optional[int] = None) -> List[List[List[int]]]: + """Find the most likely tag sequence using Viterbi algorithm. + Args: + emissions (`~torch.Tensor`): Emission score tensor of size + ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length, num_tags)`` otherwise. 
+ mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + nbest (`int`): Number of most probable paths for each sequence + pad_tag (`int`): Tag at padded positions. Often input varies in length and + the length will be padded to the maximum length in the batch. Tags at + the padded positions will be assigned with a padding tag, i.e. `pad_tag` + Returns: + A PyTorch tensor of the best tag sequence for each batch of shape + (nbest, batch_size, seq_length) + """ + if nbest is None: + nbest = 1 + if mask is None: + mask = torch.ones( + emissions.shape[:2], + dtype=torch.uint8, + device=emissions.device) + if mask.dtype != torch.uint8: + mask = mask.byte() + self._validate(emissions, mask=mask) + + if self.batch_first: + emissions = emissions.transpose(0, 1) + mask = mask.transpose(0, 1) + + if nbest == 1: + return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0) + return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag) + + def _validate(self, + emissions: torch.Tensor, + tags: Optional[torch.LongTensor] = None, + mask: Optional[torch.ByteTensor] = None) -> None: + if emissions.dim() != 3: + raise ValueError( + f'emissions must have dimension of 3, got {emissions.dim()}') + if emissions.size(2) != self.num_tags: + raise ValueError( + f'expected last dimension of emissions is {self.num_tags}, ' + f'got {emissions.size(2)}') + + if tags is not None: + if emissions.shape[:2] != tags.shape: + raise ValueError( + 'the first two dimensions of emissions and tags must match, ' + f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}' + ) + + if mask is not None: + if emissions.shape[:2] != mask.shape: + raise ValueError( + 'the first two dimensions of emissions and mask must match, ' + f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}' + ) + no_empty_seq = not self.batch_first and mask[0].all() + no_empty_seq_bf = self.batch_first and mask[:, 0].all() + if not no_empty_seq and not no_empty_seq_bf: + raise ValueError('mask of the first timestep must all be on') + + def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor, + mask: torch.ByteTensor) -> torch.Tensor: + # emissions: (seq_length, batch_size, num_tags) + # tags: (seq_length, batch_size) + # mask: (seq_length, batch_size) + seq_length, batch_size = tags.shape + mask = mask.float() + + # Start transition score and first emission + # shape: (batch_size,) + score = self.start_transitions[tags[0]] + score += emissions[0, torch.arange(batch_size), tags[0]] + + for i in range(1, seq_length): + # Transition score to next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += self.transitions[tags[i - 1], tags[i]] * mask[i] + + # Emission score for next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] + + # End transition score + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + # shape: (batch_size,) + last_tags = tags[seq_ends, torch.arange(batch_size)] + # shape: (batch_size,) + score += self.end_transitions[last_tags] + + return score + + def _compute_normalizer(self, emissions: torch.Tensor, + mask: torch.ByteTensor) -> torch.Tensor: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + seq_length = emissions.size(0) + + # Start transition score and first emission; score has size of + # (batch_size, num_tags) where for 
each batch, the j-th column stores + # the score that the first timestep has tag j + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + + for i in range(1, seq_length): + # Broadcast score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emissions = emissions[i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j stores the sum of scores of all + # possible tag sequences so far that end with transitioning from tag i to tag j + # and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emissions + + # Sum over all possible current tags, but we're in score space, so a sum + # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of + # all possible tag sequences so far, that end in tag i + # shape: (batch_size, num_tags) + next_score = torch.logsumexp(next_score, dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # shape: (batch_size, num_tags) + score = torch.where(mask[i].unsqueeze(1), next_score, score) + + # End transition score + # shape: (batch_size, num_tags) + score += self.end_transitions + + # Sum (log-sum-exp) over all possible tags + # shape: (batch_size,) + return torch.logsumexp(score, dim=1) + + def _viterbi_decode(self, + emissions: torch.FloatTensor, + mask: torch.ByteTensor, + pad_tag: Optional[int] = None) -> List[List[int]]: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + # return: (batch_size, seq_length) + if pad_tag is None: + pad_tag = 0 + + device = emissions.device + seq_length, batch_size = mask.shape + + # Start transition and first emission + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + history_idx = torch.zeros((seq_length, batch_size, self.num_tags), + dtype=torch.long, + device=device) + oor_idx = torch.zeros((batch_size, self.num_tags), + dtype=torch.long, + device=device) + oor_tag = torch.full((seq_length, batch_size), + pad_tag, + dtype=torch.long, + device=device) + + # - score is a tensor of size (batch_size, num_tags) where for every batch, + # value at column j stores the score of the best tag sequence so far that ends + # with tag j + # - history_idx saves where the best tags candidate transitioned from; this is used + # when we trace back the best tag sequence + # - oor_idx saves the best tags candidate transitioned from at the positions + # where mask is 0, i.e. 
out of range (oor) + + # Viterbi algorithm recursive case: we compute the score of the best tag sequence + # for every possible next tag + for i in range(1, seq_length): + # Broadcast viterbi score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emission = emissions[i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j stores the score of the best + # tag sequence so far that ends with transitioning from tag i to tag j and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emission + + # Find the maximum score over all possible current tag + # shape: (batch_size, num_tags) + next_score, indices = next_score.max(dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # and save the index that produces the next score + # shape: (batch_size, num_tags) + score = torch.where(mask[i].unsqueeze(-1), next_score, score) + indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx) + history_idx[i - 1] = indices + + # End transition score + # shape: (batch_size, num_tags) + end_score = score + self.end_transitions + _, end_tag = end_score.max(dim=1) + + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + + # insert the best tag at each sequence end (last position with mask == 1) + history_idx = history_idx.transpose(1, 0).contiguous() + history_idx.scatter_( + 1, + seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags), + end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags)) + history_idx = history_idx.transpose(1, 0).contiguous() + + # The most probable path for each sequence + best_tags_arr = torch.zeros((seq_length, batch_size), + dtype=torch.long, + device=device) + best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device) + for idx in range(seq_length - 1, -1, -1): + best_tags = torch.gather(history_idx[idx], 1, best_tags) + best_tags_arr[idx] = best_tags.data.view(batch_size) + + return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1) + + def _viterbi_decode_nbest( + self, + emissions: torch.FloatTensor, + mask: torch.ByteTensor, + nbest: int, + pad_tag: Optional[int] = None) -> List[List[List[int]]]: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + # return: (nbest, batch_size, seq_length) + if pad_tag is None: + pad_tag = 0 + + device = emissions.device + seq_length, batch_size = mask.shape + + # Start transition and first emission + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + history_idx = torch.zeros( + (seq_length, batch_size, self.num_tags, nbest), + dtype=torch.long, + device=device) + oor_idx = torch.zeros((batch_size, self.num_tags, nbest), + dtype=torch.long, + device=device) + oor_tag = torch.full((seq_length, batch_size, nbest), + pad_tag, + dtype=torch.long, + device=device) + + # + score is a tensor of size (batch_size, num_tags) where for every batch, + # value at column j stores the score of the best tag sequence so far that ends + # with tag j + # + history_idx saves where the best tags candidate transitioned from; this is used + # when we trace back the best tag sequence + # - oor_idx saves the best tags candidate transitioned from at the positions + # where mask is 0, i.e. 
out of range (oor) + + # Viterbi algorithm recursive case: we compute the score of the best tag sequence + # for every possible next tag + for i in range(1, seq_length): + if i == 1: + broadcast_score = score.unsqueeze(-1) + broadcast_emission = emissions[i].unsqueeze(1) + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emission + else: + broadcast_score = score.unsqueeze(-1) + broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2) + # shape: (batch_size, num_tags, nbest, num_tags) + next_score = broadcast_score + self.transitions.unsqueeze( + 1) + broadcast_emission + + # Find the top `nbest` maximum score over all possible current tag + # shape: (batch_size, nbest, num_tags) + next_score, indices = next_score.view(batch_size, -1, + self.num_tags).topk( + nbest, dim=1) + + if i == 1: + score = score.unsqueeze(-1).expand(-1, -1, nbest) + indices = indices * nbest + + # convert to shape: (batch_size, num_tags, nbest) + next_score = next_score.transpose(2, 1) + indices = indices.transpose(2, 1) + + # Set score to the next score if this timestep is valid (mask == 1) + # and save the index that produces the next score + # shape: (batch_size, num_tags, nbest) + score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), + next_score, score) + indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, + oor_idx) + history_idx[i - 1] = indices + + # End transition score shape: (batch_size, num_tags, nbest) + end_score = score + self.end_transitions.unsqueeze(-1) + _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1) + + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + + # insert the best tag at each sequence end (last position with mask == 1) + history_idx = history_idx.transpose(1, 0).contiguous() + history_idx.scatter_( + 1, + seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest), + end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest)) + history_idx = history_idx.transpose(1, 0).contiguous() + + # The most probable path for each sequence + best_tags_arr = torch.zeros((seq_length, batch_size, nbest), + dtype=torch.long, + device=device) + best_tags = torch.arange(nbest, dtype=torch.long, device=device) \ + .view(1, -1).expand(batch_size, -1) + for idx in range(seq_length - 1, -1, -1): + best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, + best_tags) + best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest + + return torch.where(mask.unsqueeze(-1), best_tags_arr, + oor_tag).permute(2, 1, 0) + + +class Embedding(nn.Module): + + def __init__(self, vocab_size, embed_width): + super(Embedding, self).__init__() + + self.embedding = nn.Embedding(vocab_size, embed_width) + + def forward(self, input_ids): + return self.embedding(input_ids) diff --git a/modelscope/models/nlp/nncrf_for_named_entity_recognition.py b/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py similarity index 96% rename from modelscope/models/nlp/nncrf_for_named_entity_recognition.py rename to modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py index 8b0c59b2..2a3f6cf4 100644 --- a/modelscope/models/nlp/nncrf_for_named_entity_recognition.py +++ b/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py @@ -12,15 +12,13 @@ from transformers import AutoConfig, AutoModel from modelscope.metainfo import Models from modelscope.models import TorchModel from modelscope.models.builder import MODELS +from modelscope.outputs import TokenClassifierWithPredictionsOutput 
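
For reference, a hedged usage sketch of the `CRF` module defined in the new file (assuming `CRF` can be imported from it directly; it is not listed in `__all__`). Shapes follow its docstrings: with `batch_first=True`, emissions are `(batch, seq_len, num_tags)`, `forward` returns the log likelihood (so a training loss would be its negation), and `decode` returns a `(nbest, batch, seq_len)` tensor of tag indices:

```python
import torch
from modelscope.models.nlp.task_models.nncrf_for_named_entity_recognition import CRF

num_tags, batch, seq_len = 5, 2, 7
crf = CRF(num_tags, batch_first=True)

emissions = torch.randn(batch, seq_len, num_tags)
tags = torch.randint(num_tags, (batch, seq_len))
# Second sequence is padded; the first timestep must be unmasked for every sequence.
mask = torch.tensor([[1] * 7, [1] * 5 + [0] * 2], dtype=torch.uint8)

log_likelihood = crf(emissions, tags, mask=mask)   # scalar with the default 'mean' reduction
best_paths = crf.decode(emissions, mask=mask)      # shape (1, batch, seq_len), nbest defaults to 1
print(log_likelihood.item(), best_paths.shape)
```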
from modelscope.utils.constant import ModelFile, Tasks -__all__ = [ - 'TransformerCRFForNamedEntityRecognition', - 'LSTMCRFForNamedEntityRecognition' -] +__all__ = ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'] -class SequenceLabelingForNamedEntityRecognition(TorchModel): +class SequenceLabelingForWordSegmentation(TorchModel): def __init__(self, model_dir, *args, **kwargs): super().__init__(model_dir, *args, **kwargs) @@ -46,27 +44,30 @@ class SequenceLabelingForNamedEntityRecognition(TorchModel): 'label_mask': input['label_mask'], } output = { - 'text': input['text'], 'offset_mapping': input['offset_mapping'], **input_tensor, **self.model(input_tensor) } return output - def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + def postprocess(self, input: Dict[str, Any], **kwargs): predicts = self.model.decode(input) - output = { - 'text': input['text'], - 'offset_mapping': input['offset_mapping'], - 'predicts': predicts['predicts'].squeeze(0).cpu().numpy(), - } - return output - - -@MODELS.register_module( - Tasks.named_entity_recognition, module_name=Models.tcrf) -class TransformerCRFForNamedEntityRecognition( - SequenceLabelingForNamedEntityRecognition): + offset_len = len(input['offset_mapping']) + predictions = torch.narrow( + predicts, 1, 0, + offset_len) # index_select only move loc, not resize + return TokenClassifierWithPredictionsOutput( + loss=None, + logits=None, + hidden_states=None, + attentions=None, + offset_mapping=input['offset_mapping'], + predictions=predictions, + ) + + +@MODELS.register_module(Tasks.word_segmentation, module_name=Models.tcrf_wseg) +class TransformerCRFForWordSegmentation(SequenceLabelingForWordSegmentation): """This model wraps the TransformerCRF model to register into model sets. """ @@ -78,10 +79,8 @@ class TransformerCRFForNamedEntityRecognition( return model -@MODELS.register_module( - Tasks.named_entity_recognition, module_name=Models.lcrf) -class LSTMCRFForNamedEntityRecognition( - SequenceLabelingForNamedEntityRecognition): +@MODELS.register_module(Tasks.word_segmentation, module_name=Models.lcrf_wseg) +class LSTMCRFForWordSegmentation(SequenceLabelingForWordSegmentation): """This model wraps the LSTMCRF model to register into model sets. """ @@ -133,8 +132,8 @@ class TransformerCRF(nn.Module): inputs['label_mask'].shape[1], device=seq_lens.device)[None, :] < seq_lens[:, None] predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) - outputs = {'predicts': predicts} - return outputs + + return predicts class LSTMCRF(nn.Module): diff --git a/modelscope/models/nlp/task_models/sequence_classification.py b/modelscope/models/nlp/task_models/sequence_classification.py index 1f5e46c3..6c0c09a2 100644 --- a/modelscope/models/nlp/task_models/sequence_classification.py +++ b/modelscope/models/nlp/task_models/sequence_classification.py @@ -1,8 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
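
The stacked `@MODELS.register_module(...)` decorators used throughout this patch (the word-segmentation and NER CRF wrappers here, relation extraction and part-of-speech elsewhere) register one class under several `(task, module_name)` keys. A simplified stand-in for that mechanism — not ModelScope's actual `Registry` implementation, and the key strings below are placeholders:

```python
_MODELS = {}

def register_module(task: str, module_name: str):
    """Register the decorated class under a (task, module_name) key."""
    def decorator(cls):
        _MODELS[(task, module_name)] = cls
        return cls
    return decorator

@register_module('word-segmentation', 'transformer-crf-wseg')   # placeholder keys
@register_module('named-entity-recognition', 'transformer-crf')
class DummyCRFModel:
    pass

# The same class resolves from either task key:
print(_MODELS[('word-segmentation', 'transformer-crf-wseg')] is DummyCRFModel)        # True
print(_MODELS[('named-entity-recognition', 'transformer-crf')] is DummyCRFModel)      # True
```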
-import os from typing import Any, Dict -import json import numpy as np from modelscope.metainfo import TaskModels @@ -16,11 +14,6 @@ from modelscope.utils.hub import parse_label_mapping __all__ = ['SequenceClassificationModel'] -@MODELS.register_module( - Tasks.sentence_similarity, module_name=TaskModels.text_classification) -@MODELS.register_module(Tasks.nli, module_name=TaskModels.text_classification) -@MODELS.register_module( - Tasks.sentiment_classification, module_name=TaskModels.text_classification) @MODELS.register_module( Tasks.text_classification, module_name=TaskModels.text_classification) class SequenceClassificationModel(SingleBackboneTaskModelBase): @@ -54,25 +47,10 @@ class SequenceClassificationModel(SingleBackboneTaskModelBase): labels = input.pop(OutputKeys.LABELS, None) outputs = super().forward(input) - sequence_output, pooled_output = self.extract_backbone_outputs(outputs) + pooled_output = outputs.pooler_output outputs = self.head.forward(pooled_output) if labels is not None: input[OutputKeys.LABELS] = labels loss = self.compute_loss(outputs, labels) outputs.update(loss) return outputs - - def extract_logits(self, outputs): - return outputs[OutputKeys.LOGITS].cpu().detach() - - def postprocess(self, input, **kwargs): - logits = self.extract_logits(input) - probs = logits.softmax(-1).numpy() - pred = logits.argmax(-1).numpy() - logits = logits.numpy() - res = { - OutputKeys.PREDICTIONS: pred, - OutputKeys.PROBABILITIES: probs, - OutputKeys.LOGITS: logits - } - return res diff --git a/modelscope/models/nlp/task_models/task_model.py b/modelscope/models/nlp/task_models/task_model.py index 0b43044f..8c83517a 100644 --- a/modelscope/models/nlp/task_models/task_model.py +++ b/modelscope/models/nlp/task_models/task_model.py @@ -404,7 +404,7 @@ class SingleBackboneTaskModelBase(BaseTaskModel): def build_backbone(self, cfg): if 'prefix' in cfg: self._backbone_prefix = cfg['prefix'] - backbone = build_backbone(cfg, field=Fields.nlp) + backbone = build_backbone(cfg) setattr(self, cfg['prefix'], backbone) def build_head(self, cfg): @@ -414,7 +414,7 @@ class SingleBackboneTaskModelBase(BaseTaskModel): ) if 'prefix' in cfg: self._head_prefix = cfg['prefix'] - head = build_head(cfg, group_key=self.group_key) + head = build_head(cfg, task_name=self.group_key) setattr(self, self._head_prefix, head) return head diff --git a/modelscope/models/nlp/task_models/text_generation.py b/modelscope/models/nlp/task_models/text_generation.py new file mode 100644 index 00000000..f17b0f6b --- /dev/null +++ b/modelscope/models/nlp/task_models/text_generation.py @@ -0,0 +1,79 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import addict +import numpy as np +from transformers.modeling_utils import PreTrainedModel + +from modelscope.metainfo import TaskModels +from modelscope.models.builder import MODELS +from modelscope.models.nlp.task_models.task_model import \ + SingleBackboneTaskModelBase +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + +__all__ = ['TaskModelForTextGeneration'] + + +@MODELS.register_module( + Tasks.text_generation, module_name=TaskModels.text_generation) +class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the text generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. 
+ """ + super().__init__(model_dir, *args, **kwargs) + if 'base_model_prefix' in kwargs: + self._base_model_prefix = kwargs['base_model_prefix'] + + self.build_backbone(self.backbone_cfg) + self.build_head(self.head_cfg) + if self.config.get('shared_embedding', False): + input_embeddings = self.backbone.get_input_embeddings() + output_embeddings = self.head.get_output_embeddings() + output_embeddings.weight = input_embeddings.weight + + def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: + # backbone do not need labels, only head need for loss compute + labels = input.pop(OutputKeys.LABELS, None) + + backbone_outputs = super().forward(input) + hidden_states = backbone_outputs[0] + + outputs = self.head.forward(hidden_states) + if labels is not None: + input[OutputKeys.LABELS] = labels + loss = self.compute_loss(outputs, labels) + outputs.update(loss) + return addict.Dict(outputs) + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get('attention_mask', None) + position_ids = kwargs.get('position_ids', None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + 'input_ids': input_ids, + 'past_key_values': past, + 'use_cache': kwargs.get('use_cache'), + 'position_ids': position_ids, + 'attention_mask': attention_mask, + } + + def generate(self, inputs, *args, **kwargs): + input_ids = inputs['input_ids'] if isinstance(inputs, Dict) else inputs + return super().generate(input_ids, *args, **kwargs) diff --git a/modelscope/models/nlp/task_models/token_classification.py b/modelscope/models/nlp/task_models/token_classification.py index f3930182..2739bf11 100644 --- a/modelscope/models/nlp/task_models/token_classification.py +++ b/modelscope/models/nlp/task_models/token_classification.py @@ -8,7 +8,7 @@ from modelscope.metainfo import TaskModels from modelscope.models.builder import MODELS from modelscope.models.nlp.task_models.task_model import \ SingleBackboneTaskModelBase -from modelscope.outputs import OutputKeys +from modelscope.outputs import OutputKeys, TokenClassifierOutput from modelscope.utils.constant import Tasks from modelscope.utils.hub import parse_label_mapping from modelscope.utils.tensor_utils import (torch_nested_detach, @@ -19,6 +19,8 @@ __all__ = ['TokenClassificationModel'] @MODELS.register_module( Tasks.token_classification, module_name=TaskModels.token_classification) +@MODELS.register_module( + Tasks.part_of_speech, module_name=TaskModels.token_classification) class TokenClassificationModel(SingleBackboneTaskModelBase): def __init__(self, model_dir: str, *args, **kwargs): @@ -51,27 +53,20 @@ class TokenClassificationModel(SingleBackboneTaskModelBase): labels = input.pop(OutputKeys.LABELS) outputs = super().forward(input) - sequence_output, pooled_output = self.extract_backbone_outputs(outputs) - outputs = self.head.forward(sequence_output) + sequence_output = outputs[0] + logits = self.head.forward(sequence_output) + loss = None if labels in input: loss = self.compute_loss(outputs, labels) - outputs.update(loss) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + 
attentions=outputs.attentions, + offset_mapping=input['offset_mapping'], + ) return outputs def extract_logits(self, outputs): return outputs[OutputKeys.LOGITS].cpu().detach() - - def extract_backbone_outputs(self, outputs): - sequence_output = None - pooled_output = None - if hasattr(self.backbone, 'extract_sequence_outputs'): - sequence_output = self.backbone.extract_sequence_outputs(outputs) - return sequence_output, pooled_output - - def postprocess(self, input, **kwargs): - logits = self.extract_logits(input) - pred = torch.argmax(logits[0], dim=-1) - pred = torch_nested_numpify(torch_nested_detach(pred)) - logits = torch_nested_numpify(torch_nested_detach(logits)) - res = {OutputKeys.PREDICTIONS: pred, OutputKeys.LOGITS: logits} - return res diff --git a/modelscope/models/nlp/token_classification.py b/modelscope/models/nlp/token_classification.py deleted file mode 100644 index c63e8037..00000000 --- a/modelscope/models/nlp/token_classification.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -from abc import abstractmethod -from typing import Dict - -import numpy as np -import torch -from torch import nn - -from modelscope.metainfo import Models -from modelscope.models.base import TorchModel -from modelscope.models.builder import MODELS -from modelscope.models.nlp.bert import BertPreTrainedModel -from modelscope.models.nlp.structbert import SbertPreTrainedModel -from modelscope.outputs import OutputKeys -from modelscope.utils.constant import Tasks -from modelscope.utils.hub import parse_label_mapping -from modelscope.utils.tensor_utils import (torch_nested_detach, - torch_nested_numpify) - -__all__ = ['SbertForTokenClassification'] - - -class TokenClassification(TorchModel): - """A token classification base class for all the fitted token classification models. - """ - - base_model_prefix: str = 'bert' - - def __init__(self, config, model_dir): - super().__init__(model_dir) - self.num_labels = config.num_labels - self.config = config - setattr(self, self.base_model_prefix, self.build_base_model()) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None - else config.hidden_dropout_prob) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - @abstractmethod - def build_base_model(self): - """Build the backbone model. - - Returns: the backbone instance. - """ - pass - - @property - def base_model(self): - return getattr(self, self.base_model_prefix) - - def compute_loss(self, logits, labels, **kwargs): - """Compute loss. - - For example, if backbone is pretrained model, there will be a 'attention_mask' parameter to skip - useless tokens. - - Args: - logits: The logits from the classifier - labels: The labels - **kwargs: Other input params. - - Returns: The loss. 
- - """ - pass - - def forward(self, **kwargs): - labels = None - if OutputKeys.LABEL in kwargs: - labels = kwargs.pop(OutputKeys.LABEL) - elif OutputKeys.LABELS in kwargs: - labels = kwargs.pop(OutputKeys.LABELS) - - outputs = self.base_model(**kwargs) - # base model should return the sequence_output as its first output - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - if labels is not None: - loss = self.compute_loss(logits, labels, **kwargs) - return {OutputKeys.LOGITS: logits, OutputKeys.LOSS: loss} - return {OutputKeys.LOGITS: logits} - - def postprocess(self, input: Dict[str, np.ndarray], - **kwargs) -> Dict[str, np.ndarray]: - logits = input[OutputKeys.LOGITS] - pred = torch.argmax(logits[0], dim=-1) - pred = torch_nested_numpify(torch_nested_detach(pred)) - logits = torch_nested_numpify(torch_nested_detach(logits)) - rst = {OutputKeys.PREDICTIONS: pred, OutputKeys.LOGITS: logits} - return rst - - -@MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) -@MODELS.register_module(Tasks.part_of_speech, module_name=Models.structbert) -@MODELS.register_module( - Tasks.token_classification, module_name=Models.structbert) -class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel): - """Sbert token classification model. - - Inherited from TokenClassification. - """ - - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_unexpected = [r'pooler'] - - def __init__(self, config, model_dir): - if hasattr(config, 'base_model_prefix'): - SbertForTokenClassification.base_model_prefix = config.base_model_prefix - super().__init__(config, model_dir) - - def build_base_model(self): - from .structbert import SbertModel - return SbertModel(self.config, add_pooling_layer=False) - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - labels=None, - **kwargs): - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - labels=labels) - - def compute_loss(self, logits, labels, attention_mask=None, **kwargs): - """Compute the loss with an attention mask. - - @param logits: The logits output from the classifier. - @param labels: The labels. - @param attention_mask: The attention_mask. - @param kwargs: Unused input args. - @return: The loss - """ - loss_fct = nn.CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels) - active_labels = torch.where( - active_loss, labels.view(-1), - torch.tensor(loss_fct.ignore_index).type_as(labels)) - return loss_fct(active_logits, active_labels) - else: - return loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - @classmethod - def _instantiate(cls, **kwargs): - """Instantiate the model. - - @param kwargs: Input args. - model_dir: The model dir used to load the checkpoint and the label information. - num_labels: An optional arg to tell the model how many classes to initialize. - Method will call utils.parse_label_mapping if num_labels not supplied. - If num_labels is not found, the model will use the default setting (2 classes). 
- @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained - """ - model_dir = kwargs.get('model_dir') - num_labels = kwargs.get('num_labels') - if num_labels is None: - label2id = parse_label_mapping(model_dir) - if label2id is not None and len(label2id) > 0: - num_labels = len(label2id) - - model_args = {} if num_labels is None else {'num_labels': num_labels} - return super(SbertPreTrainedModel, - SbertForTokenClassification).from_pretrained( - pretrained_model_name_or_path=kwargs.get('model_dir'), - model_dir=kwargs.get('model_dir'), - **model_args) - - -@MODELS.register_module(Tasks.word_segmentation, module_name=Models.bert) -@MODELS.register_module(Tasks.token_classification, module_name=Models.bert) -class BertForSequenceClassification(TokenClassification, BertPreTrainedModel): - """Bert token classification model. - - Inherited from TokenClassificationBase. - """ - base_model_prefix: str = 'bert' - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_missing = [r'position_ids'] - - def __init__(self, config, model_dir): - if hasattr(config, 'base_model_prefix'): - BertForSequenceClassification.base_model_prefix = config.base_model_prefix - super().__init__(config, model_dir) - - def build_base_model(self): - from .bert import BertModel - return BertModel(self.config, add_pooling_layer=True) - - def forward(self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs): - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - labels=labels, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - **kwargs) diff --git a/modelscope/models/nlp/veco/__init__.py b/modelscope/models/nlp/veco/__init__.py index 0fe786fd..0774e9b4 100644 --- a/modelscope/models/nlp/veco/__init__.py +++ b/modelscope/models/nlp/veco/__init__.py @@ -18,18 +18,22 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .configuration_veco import VecoConfig - from .modeling_veco import (VecoForMaskedLM, VecoForSequenceClassification, - VecoModel) - from .tokenization_veco import VecoTokenizer - from .tokenization_veco_fast import VecoTokenizerFast + from .configuration import VecoConfig + from .backbone import VecoModel + from .text_classification import VecoForSequenceClassification + from .token_classification import VecoForTokenClassification + from .fill_mask import VecoForMaskedLM + from .tokenization import VecoTokenizer + from .tokenization_fast import VecoTokenizerFast else: _import_structure = { - 'configuration_veco': ['VecoConfig'], - 'modeling_veco': - ['VecoForMaskedLM', 'VecoForSequenceClassification', 'VecoModel'], - 'tokenization_veco': ['VecoTokenizer'], - 'tokenization_veco_fast': ['VecoTokenizerFast'], + 'configuration': ['VecoConfig'], + 'backbone': ['VecoModel'], + 'text_classification': ['VecoForSequenceClassification'], + 'fill_mask': ['VecoForMaskedLM'], + 'token_classification': ['VecoForTokenClassification'], + 'tokenization': ['VecoTokenizer'], + 'tokenization_fast': ['VecoTokenizerFast'], } import sys diff --git a/modelscope/models/nlp/veco/backbone.py b/modelscope/models/nlp/veco/backbone.py new file mode 
100644 index 00000000..98d8c30a --- /dev/null +++ b/modelscope/models/nlp/veco/backbone.py @@ -0,0 +1,96 @@ +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Veco model. mainly copied from :module:`~transformers.modeling_xlm_roberta`""" + +from transformers import RobertaModel + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionBackboneModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .configuration import VecoConfig + +logger = logging.get_logger(__name__) + +VECO_PRETRAINED_MODEL_ARCHIVE_LIST = [] + + +@MODELS.register_module(Tasks.backbone, module_name=Models.veco) +class VecoModel(TorchModel, RobertaModel): + """The bare Veco Model transformer outputting raw hidden-states without any specific head on top. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config ([`VecoConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + + This class overrides [`RobertaModel`]. Please check the superclass for the appropriate + documentation alongside usage examples. 
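
The reorganized `veco/__init__.py` above keeps the same lazy-import approach: names listed in `_import_structure` only trigger an import of their submodule on first attribute access. A simplified stand-in for what `LazyImportModule` presumably does — illustrative only, not the actual implementation:

```python
import importlib

_import_structure = {'backbone': ['VecoModel'], 'fill_mask': ['VecoForMaskedLM']}

class _LazyModule:
    """Resolves exported names to their submodules on first access."""

    def __init__(self, package: str, structure: dict):
        self._package = package
        self._name_to_module = {
            name: module for module, names in structure.items() for name in names
        }

    def __getattr__(self, name):
        if name not in self._name_to_module:
            raise AttributeError(name)
        module = importlib.import_module(f'.{self._name_to_module[name]}', self._package)
        return getattr(module, name)

# Usage (requires modelscope to be installed):
# veco = _LazyModule('modelscope.models.nlp.veco', _import_structure)
# veco.VecoModel   # imports modelscope.models.nlp.veco.backbone only at this point
```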
+ """ + + config_class = VecoConfig + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def forward(self, *args, **kwargs): + """ + Returns: + Returns `modelscope.outputs.AttentionBackboneModelOutputWithEmbedding` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_veco_fill-mask-large', task='backbone') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_veco_fill-mask-large') + >>> print(model(**preprocessor('这是个测试'))) + + """ + kwargs['return_dict'] = True + outputs = super(Model, self).forward(*args, **kwargs) + return AttentionBackboneModelOutput( + last_hidden_state=outputs.last_hidden_state, + pooler_output=outputs.pooler_output, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @classmethod + def _instantiate(cls, **kwargs): + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + ponet_config = VecoConfig(**kwargs) + model = cls(ponet_config) + else: + model = super( + Model, + cls).from_pretrained(pretrained_model_name_or_path=model_dir) + return model diff --git a/modelscope/models/nlp/veco/configuration_veco.py b/modelscope/models/nlp/veco/configuration.py similarity index 100% rename from modelscope/models/nlp/veco/configuration_veco.py rename to modelscope/models/nlp/veco/configuration.py diff --git a/modelscope/models/nlp/veco/fill_mask.py b/modelscope/models/nlp/veco/fill_mask.py new file mode 100644 index 00000000..de2cdb4a --- /dev/null +++ b/modelscope/models/nlp/veco/fill_mask.py @@ -0,0 +1,99 @@ +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import RobertaForMaskedLM + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionFillMaskModelOutput +from modelscope.utils.constant import Tasks +from .configuration import VecoConfig + + +@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco) +class VecoForMaskedLM(TorchModel, RobertaForMaskedLM): + """Veco Model transformer with a masked language model head on top (a linear layer on top of the + pooled output). + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. 
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the fill_mask model of StructBERT, the preprocessor of this model + is `modelscope.preprocessors.NLPPreprocessor`. + + Parameters: + config ([`VecoConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + + This class overrides [`RobertaForMaskedLM`]. Please check the superclass for the + appropriate documentation alongside usage examples. + """ + + config_class = VecoConfig + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def forward(self, *args, **kwargs): + """ + Returns: + Returns `modelscope.outputs.AttentionFillMaskModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_veco_fill-mask-large') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_veco_fill-mask-large') + >>> # Call the model, return some tensors + >>> print(model(**preprocessor('你师父差得动你,你师父可不动我。'))) + >>> # Call the pipeline + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) + >>> print(pipeline_ins('你师父差得动你,你师父可不动我。')) + """ + + kwargs['return_dict'] = True + outputs = super(Model, self).forward(*args, **kwargs) + return AttentionFillMaskModelOutput( + loss=outputs.loss, + logits=outputs.logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + input_ids=kwargs['input_ids'], + ) + + @classmethod + def _instantiate(cls, **kwargs): + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + ponet_config = VecoConfig(**kwargs) + model = cls(ponet_config) + else: + model = super( + Model, + cls).from_pretrained(pretrained_model_name_or_path=model_dir) + return model diff --git a/modelscope/models/nlp/veco/modeling_veco.py b/modelscope/models/nlp/veco/modeling_veco.py deleted file mode 100644 index b519c236..00000000 --- a/modelscope/models/nlp/veco/modeling_veco.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. -# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Veco model. 
mainly copied from :module:`~transformers.modeling_xlm_roberta`""" - -from transformers import (RobertaForMaskedLM, RobertaForMultipleChoice, - RobertaForQuestionAnswering, - RobertaForSequenceClassification, - RobertaForTokenClassification, RobertaModel) -from transformers.file_utils import add_start_docstrings - -from modelscope.metainfo import Models -from modelscope.models.builder import BACKBONES -from modelscope.utils import logger as logging -from modelscope.utils.constant import Fields -from .configuration_veco import VecoConfig - -logger = logging.get_logger(__name__) - -VECO_PRETRAINED_MODEL_ARCHIVE_LIST = [] - -VECO_START_DOCSTRING = r""" - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic - methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, - pruning heads etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) - subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to - general usage and behavior. - - Parameters: - config ([`VecoConfig`]): Model configuration class with all the parameters of the - model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model - weights. -""" - - -@add_start_docstrings( - 'The bare Veco Model transformer outputting raw hidden-states without any specific head on top.', - VECO_START_DOCSTRING, -) -class VecoModel(RobertaModel): - """ - This class overrides [`RobertaModel`]. Please check the superclass for the appropriate - documentation alongside usage examples. - """ - - config_class = VecoConfig - - -@add_start_docstrings( - """ - Veco Model transformer with a sequence classification/regression head on top (a linear layer on top of the - pooled output) e.g. for GLUE tasks. - """, - VECO_START_DOCSTRING, -) -class VecoForSequenceClassification(RobertaForSequenceClassification): - """ - This class overrides [`RobertaForSequenceClassification`]. Please check the superclass for the - appropriate documentation alongside usage examples. - """ - - config_class = VecoConfig - - -@add_start_docstrings( - """ - Veco Model transformer with a masked language model head on top (a linear layer on top of the - pooled output). - """, - VECO_START_DOCSTRING, -) -class VecoForMaskedLM(RobertaForMaskedLM): - """ - This class overrides [`RobertaForMaskedLM`]. Please check the superclass for the - appropriate documentation alongside usage examples. - """ - - config_class = VecoConfig - - -@add_start_docstrings( - """ - Veco Model with a multiple choice classification head on top (a linear layer on top of the pooled output and - a softmax) e.g. for RocStories/SWAG tasks. - """, - VECO_START_DOCSTRING, -) -class VecoForMultipleChoice(RobertaForMultipleChoice): - """ - This class overrides [`RobertaForMultipleChoice`]. Please check the superclass for the - appropriate documentation alongside usage examples. - """ - - config_class = VecoConfig - - -@add_start_docstrings( - """ - Veco Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. - for Named-Entity-Recognition (NER) tasks. - """, - VECO_START_DOCSTRING, -) -class VecoForTokenClassification(RobertaForTokenClassification): - """ - This class overrides [`RobertaForTokenClassification`]. 
Please check the superclass for the - appropriate documentation alongside usage examples. - """ - - config_class = VecoConfig - - -@add_start_docstrings( - """ - Veco Model with a span classification head on top for extractive question-answering tasks like SQuAD (a - linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - VECO_START_DOCSTRING, -) -class VecoForQuestionAnswering(RobertaForQuestionAnswering): - """ - This class overrides [`RobertaForQuestionAnswering`]. Please check the superclass for the - appropriate documentation alongside usage examples. - """ - - config_class = VecoConfig diff --git a/modelscope/models/nlp/veco/text_classification.py b/modelscope/models/nlp/veco/text_classification.py new file mode 100644 index 00000000..e4e74d8f --- /dev/null +++ b/modelscope/models/nlp/veco/text_classification.py @@ -0,0 +1,150 @@ +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import RobertaForSequenceClassification + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionTextClassificationModelOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping +from .configuration import VecoConfig + + +@MODELS.register_module(Tasks.nli, module_name=Models.veco) +@MODELS.register_module( + Tasks.sentiment_classification, module_name=Models.veco) +@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.veco) +@MODELS.register_module(Tasks.text_classification, module_name=Models.veco) +class VecoForSequenceClassification(TorchModel, + RobertaForSequenceClassification): + """Veco Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the text classification model of Veco, the preprocessor of this model + is `modelscope.preprocessors.SequenceClassificationPreprocessor`. + + Trainer: + This model should be trained by dataset which has mixed languages, + and evaluated by datasets of languages one by one. 
+ For example, if the training dataset is xnli (which has sub datasets of multiple languages), then you + should mix the sub-datasets with the languages you want to train to one training dataset, and evaluate + the model one sub-dataset by one sub-dataset of different languages. + This procedure can be done by custom code. If you are using trainer of ModelScope, + the `VecoTrainer` is suggested to use to train this model. This trainer overrides the basic evaluation + loop, and will call the evaluation dataset one by one. Besides, this trainer will use the `VecoTaskDataset` + to mix the input datasets to one, you can check the API Doc for the details. + + To check the complete example please + view the unittest `test_veco_xnli` in `tests.trainers.test_finetune_sequence_classification.py` + + Parameters: + config ([`VecoConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + + This class overrides [`RobertaForSequenceClassification`]. Please check the superclass for the + appropriate documentation alongside usage examples. + """ + + config_class = VecoConfig + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def forward(self, *args, **kwargs): + """ + Returns: + Returns `modelscope.outputs.AttentionTextClassificationModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_veco_fill-mask-large', + >>> task='text-classification', num_labels=2) + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_veco_fill-mask-large', + >>> label2id={'0': 0, '1': 1}) + >>> # Call the model, return some tensors + >>> print(model(**preprocessor('这是个测试'))) + >>> # Call the pipeline, the result may be incorrect + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('text-classification', pipeline_name='text-classification', + >>> model=model, preprocessor=preprocessor) + >>> print(pipeline_ins('这是个测试')) + """ + + kwargs['return_dict'] = True + outputs = super(Model, self).forward(*args, **kwargs) + return AttentionTextClassificationModelOutput( + loss=outputs.loss, + logits=outputs.logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels is not input. + label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists). 
+ + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + config = VecoConfig(**kwargs) + model = cls(config) + else: + model_kwargs = {} + label2id = kwargs.get('label2id', parse_label_mapping(model_dir)) + id2label = kwargs.get( + 'id2label', None if label2id is None else + {id: label + for label, id in label2id.items()}) + if id2label is not None and label2id is None: + label2id = {label: id for id, label in id2label.items()} + + num_labels = kwargs.get( + 'num_labels', None if label2id is None else len(label2id)) + if num_labels is not None: + model_kwargs['num_labels'] = num_labels + if label2id is not None: + model_kwargs['label2id'] = label2id + if id2label is not None: + model_kwargs['id2label'] = id2label + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + return model diff --git a/modelscope/models/nlp/veco/token_classification.py b/modelscope/models/nlp/veco/token_classification.py new file mode 100644 index 00000000..f6252209 --- /dev/null +++ b/modelscope/models/nlp/veco/token_classification.py @@ -0,0 +1,107 @@ +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import RobertaForTokenClassification + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionTokenClassificationModelOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping +from .configuration import VecoConfig + + +@MODELS.register_module(Tasks.token_classification, module_name=Models.veco) +class VecoForTokenClassification(TorchModel, RobertaForTokenClassification): + """Veco Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config ([`VecoConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + + This class overrides [`RobertaForTokenClassification`]. 
Please check the superclass for the + appropriate documentation alongside usage examples. + """ + + config_class = VecoConfig + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def forward(self, *args, **kwargs): + kwargs['return_dict'] = True + outputs = super(Model, self).forward(*args, **kwargs) + return AttentionTokenClassificationModelOutput( + loss=outputs.loss, + logits=outputs.logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels is not input. + label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists). + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + config = VecoConfig(**kwargs) + model = cls(config) + else: + model_kwargs = {} + label2id = kwargs.get('label2id', parse_label_mapping(model_dir)) + id2label = kwargs.get( + 'id2label', None if label2id is None else + {id: label + for label, id in label2id.items()}) + if id2label is not None and label2id is None: + label2id = {label: id for id, label in id2label.items()} + + num_labels = kwargs.get( + 'num_labels', None if label2id is None else len(label2id)) + if num_labels is not None: + model_kwargs['num_labels'] = num_labels + if label2id is not None: + model_kwargs['label2id'] = label2id + if id2label is not None: + model_kwargs['id2label'] = id2label + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + return model diff --git a/modelscope/models/nlp/veco/tokenization_veco.py b/modelscope/models/nlp/veco/tokenization.py similarity index 100% rename from modelscope/models/nlp/veco/tokenization_veco.py rename to modelscope/models/nlp/veco/tokenization.py diff --git a/modelscope/models/nlp/veco/tokenization_veco_fast.py b/modelscope/models/nlp/veco/tokenization_fast.py similarity index 99% rename from modelscope/models/nlp/veco/tokenization_veco_fast.py rename to modelscope/models/nlp/veco/tokenization_fast.py index 3edae0e7..b41a5c3b 100644 --- a/modelscope/models/nlp/veco/tokenization_veco_fast.py +++ b/modelscope/models/nlp/veco/tokenization_fast.py @@ -27,7 +27,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast from modelscope.utils import logger as logging if is_sentencepiece_available(): - from .tokenization_veco import VecoTokenizer + from .tokenization import VecoTokenizer else: VecoTokenizer = None diff --git a/modelscope/models/science/__init__.py b/modelscope/models/science/__init__.py new file mode 100644 index 00000000..50ab55d7 --- /dev/null +++ b/modelscope/models/science/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
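# --- Editor's aside (illustrative sketch, not part of this patch) ---
# The `_instantiate` docstrings of the Veco heads above describe how
# `label2id`, `id2label` and `num_labels` are reconciled: whichever pieces are
# missing get derived from the ones provided. A minimal standalone sketch of
# that logic (the helper name `_resolve_labels` is hypothetical):
def _resolve_labels(label2id=None, id2label=None):
    if id2label is None and label2id is not None:
        id2label = {i: label for label, i in label2id.items()}
    if label2id is None and id2label is not None:
        label2id = {label: i for i, label in id2label.items()}
    num_labels = None if label2id is None else len(label2id)
    return label2id, id2label, num_labels

assert _resolve_labels(label2id={'negative': 0, 'positive': 1}) == (
    {'negative': 0, 'positive': 1}, {0: 'negative', 1: 'positive'}, 2)
# --- end of editor's aside ---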
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .unifold import UnifoldForProteinStructrue + +else: + _import_structure = {'unifold': ['UnifoldForProteinStructrue']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/science/unifold/__init__.py b/modelscope/models/science/unifold/__init__.py new file mode 100644 index 00000000..75435fed --- /dev/null +++ b/modelscope/models/science/unifold/__init__.py @@ -0,0 +1 @@ +from .model import UnifoldForProteinStructrue diff --git a/modelscope/models/science/unifold/config.py b/modelscope/models/science/unifold/config.py new file mode 100644 index 00000000..e760fbf9 --- /dev/null +++ b/modelscope/models/science/unifold/config.py @@ -0,0 +1,636 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import copy +from typing import Any + +import ml_collections as mlc + +N_RES = 'number of residues' +N_MSA = 'number of MSA sequences' +N_EXTRA_MSA = 'number of extra MSA sequences' +N_TPL = 'number of templates' + +d_pair = mlc.FieldReference(128, field_type=int) +d_msa = mlc.FieldReference(256, field_type=int) +d_template = mlc.FieldReference(64, field_type=int) +d_extra_msa = mlc.FieldReference(64, field_type=int) +d_single = mlc.FieldReference(384, field_type=int) +max_recycling_iters = mlc.FieldReference(3, field_type=int) +chunk_size = mlc.FieldReference(4, field_type=int) +aux_distogram_bins = mlc.FieldReference(64, field_type=int) +eps = mlc.FieldReference(1e-8, field_type=float) +inf = mlc.FieldReference(3e4, field_type=float) +use_templates = mlc.FieldReference(True, field_type=bool) +is_multimer = mlc.FieldReference(False, field_type=bool) + + +def base_config(): + return mlc.ConfigDict({ + 'data': { + 'common': { + 'features': { + 'aatype': [N_RES], + 'all_atom_mask': [N_RES, None], + 'all_atom_positions': [N_RES, None, None], + 'alt_chi_angles': [N_RES, None], + 'atom14_alt_gt_exists': [N_RES, None], + 'atom14_alt_gt_positions': [N_RES, None, None], + 'atom14_atom_exists': [N_RES, None], + 'atom14_atom_is_ambiguous': [N_RES, None], + 'atom14_gt_exists': [N_RES, None], + 'atom14_gt_positions': [N_RES, None, None], + 'atom37_atom_exists': [N_RES, None], + 'frame_mask': [N_RES], + 'true_frame_tensor': [N_RES, None, None], + 'bert_mask': [N_MSA, N_RES], + 'chi_angles_sin_cos': [N_RES, None, None], + 'chi_mask': [N_RES, None], + 'extra_msa_deletion_value': [N_EXTRA_MSA, N_RES], + 'extra_msa_has_deletion': [N_EXTRA_MSA, N_RES], + 'extra_msa': [N_EXTRA_MSA, N_RES], + 'extra_msa_mask': [N_EXTRA_MSA, N_RES], + 'extra_msa_row_mask': [N_EXTRA_MSA], + 'is_distillation': [], + 'msa_feat': [N_MSA, N_RES, None], + 'msa_mask': [N_MSA, N_RES], + 'msa_chains': [N_MSA, None], + 'msa_row_mask': [N_MSA], + 'num_recycling_iters': [], + 'pseudo_beta': [N_RES, None], + 'pseudo_beta_mask': [N_RES], + 'residue_index': [N_RES], + 'residx_atom14_to_atom37': [N_RES, None], + 'residx_atom37_to_atom14': [N_RES, None], + 'resolution': [], + 'rigidgroups_alt_gt_frames': [N_RES, None, None, None], + 'rigidgroups_group_exists': [N_RES, None], + 'rigidgroups_group_is_ambiguous': [N_RES, None], + 'rigidgroups_gt_exists': [N_RES, None], + 'rigidgroups_gt_frames': [N_RES, None, None, None], + 'seq_length': [], + 'seq_mask': [N_RES], + 
'target_feat': [N_RES, None], + 'template_aatype': [N_TPL, N_RES], + 'template_all_atom_mask': [N_TPL, N_RES, None], + 'template_all_atom_positions': [N_TPL, N_RES, None, None], + 'template_alt_torsion_angles_sin_cos': [ + N_TPL, + N_RES, + None, + None, + ], + 'template_frame_mask': [N_TPL, N_RES], + 'template_frame_tensor': [N_TPL, N_RES, None, None], + 'template_mask': [N_TPL], + 'template_pseudo_beta': [N_TPL, N_RES, None], + 'template_pseudo_beta_mask': [N_TPL, N_RES], + 'template_sum_probs': [N_TPL, None], + 'template_torsion_angles_mask': [N_TPL, N_RES, None], + 'template_torsion_angles_sin_cos': + [N_TPL, N_RES, None, None], + 'true_msa': [N_MSA, N_RES], + 'use_clamped_fape': [], + 'assembly_num_chains': [1], + 'asym_id': [N_RES], + 'sym_id': [N_RES], + 'entity_id': [N_RES], + 'num_sym': [N_RES], + 'asym_len': [None], + 'cluster_bias_mask': [N_MSA], + }, + 'masked_msa': { + 'profile_prob': 0.1, + 'same_prob': 0.1, + 'uniform_prob': 0.1, + }, + 'block_delete_msa': { + 'msa_fraction_per_block': 0.3, + 'randomize_num_blocks': False, + 'num_blocks': 5, + 'min_num_msa': 16, + }, + 'random_delete_msa': { + 'max_msa_entry': 1 << 25, # := 33554432 + }, + 'v2_feature': + False, + 'gumbel_sample': + False, + 'max_extra_msa': + 1024, + 'msa_cluster_features': + True, + 'reduce_msa_clusters_by_max_templates': + True, + 'resample_msa_in_recycling': + True, + 'template_features': [ + 'template_all_atom_positions', + 'template_sum_probs', + 'template_aatype', + 'template_all_atom_mask', + ], + 'unsupervised_features': [ + 'aatype', + 'residue_index', + 'msa', + 'msa_chains', + 'num_alignments', + 'seq_length', + 'between_segment_residues', + 'deletion_matrix', + 'num_recycling_iters', + 'crop_and_fix_size_seed', + ], + 'recycling_features': [ + 'msa_chains', + 'msa_mask', + 'msa_row_mask', + 'bert_mask', + 'true_msa', + 'msa_feat', + 'extra_msa_deletion_value', + 'extra_msa_has_deletion', + 'extra_msa', + 'extra_msa_mask', + 'extra_msa_row_mask', + 'is_distillation', + ], + 'multimer_features': [ + 'assembly_num_chains', + 'asym_id', + 'sym_id', + 'num_sym', + 'entity_id', + 'asym_len', + 'cluster_bias_mask', + ], + 'use_templates': + use_templates, + 'is_multimer': + is_multimer, + 'use_template_torsion_angles': + use_templates, + 'max_recycling_iters': + max_recycling_iters, + }, + 'supervised': { + 'use_clamped_fape_prob': + 1.0, + 'supervised_features': [ + 'all_atom_mask', + 'all_atom_positions', + 'resolution', + 'use_clamped_fape', + 'is_distillation', + ], + }, + 'predict': { + 'fixed_size': True, + 'subsample_templates': False, + 'block_delete_msa': False, + 'random_delete_msa': True, + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 128, + 'max_templates': 4, + 'num_ensembles': 2, + 'crop': False, + 'crop_size': None, + 'supervised': False, + 'biased_msa_by_chain': False, + 'share_mask': False, + }, + 'eval': { + 'fixed_size': True, + 'subsample_templates': False, + 'block_delete_msa': False, + 'random_delete_msa': True, + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 128, + 'max_templates': 4, + 'num_ensembles': 1, + 'crop': False, + 'crop_size': None, + 'spatial_crop_prob': 0.5, + 'ca_ca_threshold': 10.0, + 'supervised': True, + 'biased_msa_by_chain': False, + 'share_mask': False, + }, + 'train': { + 'fixed_size': True, + 'subsample_templates': True, + 'block_delete_msa': True, + 'random_delete_msa': True, + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 128, + 'max_templates': 4, + 'num_ensembles': 1, + 'crop': True, + 'crop_size': 256, + 
'spatial_crop_prob': 0.5, + 'ca_ca_threshold': 10.0, + 'supervised': True, + 'use_clamped_fape_prob': 1.0, + 'max_distillation_msa_clusters': 1000, + 'biased_msa_by_chain': True, + 'share_mask': True, + }, + }, + 'globals': { + 'chunk_size': chunk_size, + 'block_size': None, + 'd_pair': d_pair, + 'd_msa': d_msa, + 'd_template': d_template, + 'd_extra_msa': d_extra_msa, + 'd_single': d_single, + 'eps': eps, + 'inf': inf, + 'max_recycling_iters': max_recycling_iters, + 'alphafold_original_mode': False, + }, + 'model': { + 'is_multimer': is_multimer, + 'input_embedder': { + 'tf_dim': 22, + 'msa_dim': 49, + 'd_pair': d_pair, + 'd_msa': d_msa, + 'relpos_k': 32, + 'max_relative_chain': 2, + }, + 'recycling_embedder': { + 'd_pair': d_pair, + 'd_msa': d_msa, + 'min_bin': 3.25, + 'max_bin': 20.75, + 'num_bins': 15, + 'inf': 1e8, + }, + 'template': { + 'distogram': { + 'min_bin': 3.25, + 'max_bin': 50.75, + 'num_bins': 39, + }, + 'template_angle_embedder': { + 'd_in': 57, + 'd_out': d_msa, + }, + 'template_pair_embedder': { + 'd_in': 88, + 'v2_d_in': [39, 1, 22, 22, 1, 1, 1, 1], + 'd_pair': d_pair, + 'd_out': d_template, + 'v2_feature': False, + }, + 'template_pair_stack': { + 'd_template': d_template, + 'd_hid_tri_att': 16, + 'd_hid_tri_mul': 64, + 'num_blocks': 2, + 'num_heads': 4, + 'pair_transition_n': 2, + 'dropout_rate': 0.25, + 'inf': 1e9, + 'tri_attn_first': True, + }, + 'template_pointwise_attention': { + 'enabled': True, + 'd_template': d_template, + 'd_pair': d_pair, + 'd_hid': 16, + 'num_heads': 4, + 'inf': 1e5, + }, + 'inf': 1e5, + 'eps': 1e-6, + 'enabled': use_templates, + 'embed_angles': use_templates, + }, + 'extra_msa': { + 'extra_msa_embedder': { + 'd_in': 25, + 'd_out': d_extra_msa, + }, + 'extra_msa_stack': { + 'd_msa': d_extra_msa, + 'd_pair': d_pair, + 'd_hid_msa_att': 8, + 'd_hid_opm': 32, + 'd_hid_mul': 128, + 'd_hid_pair_att': 32, + 'num_heads_msa': 8, + 'num_heads_pair': 4, + 'num_blocks': 4, + 'transition_n': 4, + 'msa_dropout': 0.15, + 'pair_dropout': 0.25, + 'inf': 1e9, + 'eps': 1e-10, + 'outer_product_mean_first': False, + }, + 'enabled': True, + }, + 'evoformer_stack': { + 'd_msa': d_msa, + 'd_pair': d_pair, + 'd_hid_msa_att': 32, + 'd_hid_opm': 32, + 'd_hid_mul': 128, + 'd_hid_pair_att': 32, + 'd_single': d_single, + 'num_heads_msa': 8, + 'num_heads_pair': 4, + 'num_blocks': 48, + 'transition_n': 4, + 'msa_dropout': 0.15, + 'pair_dropout': 0.25, + 'inf': 1e9, + 'eps': 1e-10, + 'outer_product_mean_first': False, + }, + 'structure_module': { + 'd_single': d_single, + 'd_pair': d_pair, + 'd_ipa': 16, + 'd_angle': 128, + 'num_heads_ipa': 12, + 'num_qk_points': 4, + 'num_v_points': 8, + 'dropout_rate': 0.1, + 'num_blocks': 8, + 'no_transition_layers': 1, + 'num_resnet_blocks': 2, + 'num_angles': 7, + 'trans_scale_factor': 10, + 'epsilon': 1e-12, + 'inf': 1e5, + 'separate_kv': False, + 'ipa_bias': True, + }, + 'heads': { + 'plddt': { + 'num_bins': 50, + 'd_in': d_single, + 'd_hid': 128, + }, + 'distogram': { + 'd_pair': d_pair, + 'num_bins': aux_distogram_bins, + 'disable_enhance_head': False, + }, + 'pae': { + 'd_pair': d_pair, + 'num_bins': aux_distogram_bins, + 'enabled': False, + 'iptm_weight': 0.8, + 'disable_enhance_head': False, + }, + 'masked_msa': { + 'd_msa': d_msa, + 'd_out': 23, + 'disable_enhance_head': False, + }, + 'experimentally_resolved': { + 'd_single': d_single, + 'd_out': 37, + 'enabled': False, + 'disable_enhance_head': False, + }, + }, + }, + 'loss': { + 'distogram': { + 'min_bin': 2.3125, + 'max_bin': 21.6875, + 'num_bins': 64, + 'eps': 1e-6, + 
'weight': 0.3, + }, + 'experimentally_resolved': { + 'eps': 1e-8, + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'weight': 0.0, + }, + 'fape': { + 'backbone': { + 'clamp_distance': 10.0, + 'clamp_distance_between_chains': 30.0, + 'loss_unit_distance': 10.0, + 'loss_unit_distance_between_chains': 20.0, + 'weight': 0.5, + 'eps': 1e-4, + }, + 'sidechain': { + 'clamp_distance': 10.0, + 'length_scale': 10.0, + 'weight': 0.5, + 'eps': 1e-4, + }, + 'weight': 1.0, + }, + 'plddt': { + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'cutoff': 15.0, + 'num_bins': 50, + 'eps': 1e-10, + 'weight': 0.01, + }, + 'masked_msa': { + 'eps': 1e-8, + 'weight': 2.0, + }, + 'supervised_chi': { + 'chi_weight': 0.5, + 'angle_norm_weight': 0.01, + 'eps': 1e-6, + 'weight': 1.0, + }, + 'violation': { + 'violation_tolerance_factor': 12.0, + 'clash_overlap_tolerance': 1.5, + 'bond_angle_loss_weight': 0.3, + 'eps': 1e-6, + 'weight': 0.0, + }, + 'pae': { + 'max_bin': 31, + 'num_bins': 64, + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'eps': 1e-8, + 'weight': 0.0, + }, + 'repr_norm': { + 'weight': 0.01, + 'tolerance': 1.0, + }, + 'chain_centre_mass': { + 'weight': 0.0, + 'eps': 1e-8, + }, + }, + }) + + +def recursive_set(c: mlc.ConfigDict, key: str, value: Any, ignore: str = None): + with c.unlocked(): + for k, v in c.items(): + if ignore is not None and k == ignore: + continue + if isinstance(v, mlc.ConfigDict): + recursive_set(v, key, value) + elif k == key: + c[k] = value + + +def model_config(name, train=False): + c = copy.deepcopy(base_config()) + + def model_2_v2(c): + recursive_set(c, 'v2_feature', True) + recursive_set(c, 'gumbel_sample', True) + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.template.template_angle_embedder.d_in = 34 + return c + + def multimer(c): + recursive_set(c, 'is_multimer', True) + recursive_set(c, 'max_extra_msa', 1152) + recursive_set(c, 'max_msa_clusters', 128) + recursive_set(c, 'v2_feature', True) + recursive_set(c, 'gumbel_sample', True) + c.model.template.template_angle_embedder.d_in = 34 + c.model.template.template_pair_stack.tri_attn_first = False + c.model.template.template_pointwise_attention.enabled = False + c.model.heads.pae.enabled = True + # we forget to enable it in our training, so disable it here + c.model.heads.pae.disable_enhance_head = True + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.structure_module.trans_scale_factor = 20 + c.loss.pae.weight = 0.1 + c.model.input_embedder.tf_dim = 21 + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.chain_centre_mass.weight = 1.0 + return c + + if name == 'model_1': + pass + elif name == 'model_1_ft': + recursive_set(c, 'max_extra_msa', 5120) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == 'model_1_af2': + recursive_set(c, 'max_extra_msa', 5120) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + elif name == 'model_2': + pass + elif name == 'model_init': + pass + elif name == 'model_init_af2': + c.globals.alphafold_original_mode = True + pass + elif name == 'model_2_ft': + recursive_set(c, 'max_extra_msa', 1024) + 
recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == 'model_2_af2': + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + elif name == 'model_2_v2': + c = model_2_v2(c) + elif name == 'model_2_v2_ft': + c = model_2_v2(c) + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == 'model_3_af2' or name == 'model_4_af2': + recursive_set(c, 'max_extra_msa', 5120) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + c.model.template.enabled = False + c.model.template.embed_angles = False + recursive_set(c, 'use_templates', False) + recursive_set(c, 'use_template_torsion_angles', False) + elif name == 'model_5_af2': + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + c.model.template.enabled = False + c.model.template.embed_angles = False + recursive_set(c, 'use_templates', False) + recursive_set(c, 'use_template_torsion_angles', False) + elif name == 'multimer': + c = multimer(c) + elif name == 'multimer_ft': + c = multimer(c) + recursive_set(c, 'max_extra_msa', 1152) + recursive_set(c, 'max_msa_clusters', 256) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.5 + elif name == 'multimer_af2': + recursive_set(c, 'max_extra_msa', 1152) + recursive_set(c, 'max_msa_clusters', 256) + recursive_set(c, 'is_multimer', True) + recursive_set(c, 'v2_feature', True) + recursive_set(c, 'gumbel_sample', True) + c.model.template.template_angle_embedder.d_in = 34 + c.model.template.template_pair_stack.tri_attn_first = False + c.model.template.template_pointwise_attention.enabled = False + c.model.heads.pae.enabled = True + c.model.heads.experimentally_resolved.enabled = True + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.structure_module.trans_scale_factor = 20 + c.loss.pae.weight = 0.1 + c.loss.violation.weight = 0.5 + c.loss.experimentally_resolved.weight = 0.01 + c.model.input_embedder.tf_dim = 21 + c.globals.alphafold_original_mode = True + c.data.train.crop_size = 384 + c.loss.repr_norm.weight = 0 + c.loss.chain_centre_mass.weight = 1.0 + recursive_set(c, 'outer_product_mean_first', True) + else: + raise ValueError(f'invalid --model-name: {name}.') + if train: + c.globals.chunk_size = None + recursive_set(c, 'inf', 3e4) + recursive_set(c, 'eps', 1e-5, 'loss') + return c diff --git a/modelscope/models/science/unifold/data/__init__.py b/modelscope/models/science/unifold/data/__init__.py new file mode 100644 index 00000000..9821d212 --- /dev/null +++ b/modelscope/models/science/unifold/data/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data pipeline for model features.""" diff --git a/modelscope/models/science/unifold/data/data_ops.py b/modelscope/models/science/unifold/data/data_ops.py new file mode 100644 index 00000000..637aa0cd --- /dev/null +++ b/modelscope/models/science/unifold/data/data_ops.py @@ -0,0 +1,1397 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import itertools +from functools import reduce, wraps +from operator import add +from typing import List, MutableMapping, Optional + +import numpy as np +import torch +from unicore.data import data_utils +from unicore.utils import batched_gather, one_hot, tensor_tree_map, tree_map + +from modelscope.models.science.unifold.config import (N_EXTRA_MSA, N_MSA, + N_RES, N_TPL) +from modelscope.models.science.unifold.data import residue_constants as rc +from modelscope.models.science.unifold.modules.frame import Frame, Rotation + +NumpyDict = MutableMapping[str, np.ndarray] +TorchDict = MutableMapping[str, np.ndarray] + +protein: TorchDict + +MSA_FEATURE_NAMES = [ + 'msa', + 'deletion_matrix', + 'msa_mask', + 'msa_row_mask', + 'bert_mask', + 'true_msa', + 'msa_chains', +] + + +def cast_to_64bit_ints(protein): + # We keep all ints as int64 + for k, v in protein.items(): + if k.endswith('_mask'): + protein[k] = v.type(torch.float32) + elif v.dtype in (torch.int32, torch.uint8, torch.int8): + protein[k] = v.type(torch.int64) + + return protein + + +def make_seq_mask(protein): + protein['seq_mask'] = torch.ones( + protein['aatype'].shape, dtype=torch.float32) + return protein + + +def make_template_mask(protein): + protein['template_mask'] = torch.ones( + protein['template_aatype'].shape[0], dtype=torch.float32) + return protein + + +def curry1(f): + """Supply all arguments but the first.""" + + @wraps(f) + def fc(*args, **kwargs): + return lambda x: f(x, *args, **kwargs) + + return fc + + +def correct_msa_restypes(protein): + """Correct MSA restype to have the same order as rc.""" + protein['msa'] = protein['msa'].long() + new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = ( + torch.tensor(new_order_list, dtype=torch.int8).unsqueeze(-1).expand( + -1, protein['msa'].shape[1])) + protein['msa'] = torch.gather(new_order, 0, protein['msa']).long() + + return protein + + +def squeeze_features(protein): + """Remove singleton and repeated dimensions in protein features.""" + if len(protein['aatype'].shape) == 2: + protein['aatype'] = torch.argmax(protein['aatype'], dim=-1) + if 'resolution' in protein and len(protein['resolution'].shape) == 1: + # use tensor for resolution + protein['resolution'] = protein['resolution'][0] + for k in [ + 'domain_name', + 'msa', + 'num_alignments', + 'seq_length', + 'sequence', + 'superfamily', + 'deletion_matrix', + 'between_segment_residues', + 'residue_index', + 'template_all_atom_mask', + ]: + if k in protein and len(protein[k].shape): + final_dim = 
protein[k].shape[-1] + if isinstance(final_dim, int) and final_dim == 1: + if torch.is_tensor(protein[k]): + protein[k] = torch.squeeze(protein[k], dim=-1) + else: + protein[k] = np.squeeze(protein[k], axis=-1) + + for k in ['seq_length', 'num_alignments']: + if k in protein and len(protein[k].shape): + protein[k] = protein[k][0] + + return protein + + +@curry1 +def randomly_replace_msa_with_unknown(protein, replace_proportion): + """Replace a portion of the MSA with 'X'.""" + if replace_proportion > 0.0: + msa_mask = np.random.rand(protein['msa'].shape) < replace_proportion + x_idx = 20 + gap_idx = 21 + msa_mask = torch.logical_and(msa_mask, protein['msa'] != gap_idx) + protein['msa'] = torch.where(msa_mask, + torch.ones_like(protein['msa']) * x_idx, + protein['msa']) + aatype_mask = np.random.rand( + protein['aatype'].shape) < replace_proportion + + protein['aatype'] = torch.where( + aatype_mask, + torch.ones_like(protein['aatype']) * x_idx, + protein['aatype'], + ) + return protein + + +def gumbel_noise(shape): + """Generate Gumbel Noise of given Shape. + This generates samples from Gumbel(0, 1). + Args: + shape: Shape of noise to return. + Returns: + Gumbel noise of given shape. + """ + epsilon = 1e-6 + uniform_noise = torch.from_numpy(np.random.uniform(0, 1, shape)) + gumbel = -torch.log(-torch.log(uniform_noise + epsilon) + epsilon) + return gumbel + + +def gumbel_max_sample(logits): + """Samples from a probability distribution given by 'logits'. + This uses Gumbel-max trick to implement the sampling in an efficient manner. + Args: + logits: Logarithm of probabilities to sample from, probabilities can be + unnormalized. + Returns: + Sample from logprobs in one-hot form. + """ + z = gumbel_noise(logits.shape) + return torch.argmax(logits + z, dim=-1) + + +def gumbel_argsort_sample_idx(logits): + """Samples with replacement from a distribution given by 'logits'. + This uses Gumbel trick to implement the sampling an efficient manner. For a + distribution over k items this samples k times without replacement, so this + is effectively sampling a random permutation with probabilities over the + permutations derived from the logprobs. + Args: + logits: Logarithm of probabilities to sample from, probabilities can be + unnormalized. 
+ Returns: + Sample from logprobs in index + """ + z = gumbel_noise(logits.shape) + return torch.argsort(logits + z, dim=-1, descending=True) + + +def uniform_permutation(num_seq): + shuffled = torch.from_numpy(np.random.permutation(num_seq - 1) + 1) + return torch.cat((torch.tensor([0]), shuffled), dim=0) + + +def gumbel_permutation(msa_mask, msa_chains=None): + has_msa = torch.sum(msa_mask.long(), dim=-1) > 0 + # default logits is zero + logits = torch.zeros_like(has_msa, dtype=torch.float32) + logits[~has_msa] = -1e6 + # one sample only + assert len(logits.shape) == 1 + # skip first row + logits = logits[1:] + has_msa = has_msa[1:] + if logits.shape[0] == 0: + return torch.tensor([0]) + if msa_chains is not None: + # skip first row + msa_chains = msa_chains[1:].reshape(-1) + msa_chains[~has_msa] = 0 + keys, counts = np.unique(msa_chains, return_counts=True) + num_has_msa = has_msa.sum() + num_pair = (msa_chains == 1).sum() + num_unpair = num_has_msa - num_pair + num_chains = (keys > 1).sum() + logits[has_msa] = 1.0 / (num_has_msa + 1e-6) + logits[~has_msa] = 0 + for k in keys: + if k > 1: + cur_mask = msa_chains == k + cur_cnt = cur_mask.sum() + if cur_cnt > 0: + logits[cur_mask] *= num_unpair / (num_chains * cur_cnt) + logits = torch.log(logits + 1e-6) + shuffled = gumbel_argsort_sample_idx(logits) + 1 + return torch.cat((torch.tensor([0]), shuffled), dim=0) + + +@curry1 +def sample_msa(protein, + max_seq, + keep_extra, + gumbel_sample=False, + biased_msa_by_chain=False): + """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.""" + num_seq = protein['msa'].shape[0] + num_sel = min(max_seq, num_seq) + if not gumbel_sample: + index_order = uniform_permutation(num_seq) + else: + msa_chains = ( + protein['msa_chains'] if + (biased_msa_by_chain and 'msa_chains' in protein) else None) + index_order = gumbel_permutation(protein['msa_mask'], msa_chains) + num_sel = min(max_seq, num_seq) + sel_seq, not_sel_seq = torch.split(index_order, + [num_sel, num_seq - num_sel]) + + for k in MSA_FEATURE_NAMES: + if k in protein: + if keep_extra: + protein['extra_' + k] = torch.index_select( + protein[k], 0, not_sel_seq) + protein[k] = torch.index_select(protein[k], 0, sel_seq) + + return protein + + +@curry1 +def sample_msa_distillation(protein, max_seq): + if 'is_distillation' in protein and protein['is_distillation'] == 1: + protein = sample_msa(max_seq, keep_extra=False)(protein) + return protein + + +@curry1 +def random_delete_msa(protein, config): + # to reduce the cost of msa features + num_seq = protein['msa'].shape[0] + seq_len = protein['msa'].shape[1] + max_seq = config.max_msa_entry // seq_len + if num_seq > max_seq: + keep_index = ( + torch.from_numpy( + np.random.choice(num_seq - 1, max_seq - 1, + replace=False)).long() + 1) + keep_index = torch.sort(keep_index)[0] + keep_index = torch.cat((torch.tensor([0]), keep_index), dim=0) + for k in MSA_FEATURE_NAMES: + if k in protein: + protein[k] = torch.index_select(protein[k], 0, keep_index) + return protein + + +@curry1 +def crop_extra_msa(protein, max_extra_msa): + num_seq = protein['extra_msa'].shape[0] + num_sel = min(max_extra_msa, num_seq) + select_indices = torch.from_numpy(np.random.permutation(num_seq)[:num_sel]) + for k in MSA_FEATURE_NAMES: + if 'extra_' + k in protein: + protein['extra_' + k] = torch.index_select(protein['extra_' + k], + 0, select_indices) + + return protein + + +def delete_extra_msa(protein): + for k in MSA_FEATURE_NAMES: + if 'extra_' + k in protein: + del protein['extra_' + k] + return 
protein + + +@curry1 +def block_delete_msa(protein, config): + if 'is_distillation' in protein and protein['is_distillation'] == 1: + return protein + num_seq = protein['msa'].shape[0] + if num_seq <= config.min_num_msa: + return protein + block_num_seq = torch.floor( + torch.tensor(num_seq, dtype=torch.float32) + * config.msa_fraction_per_block).to(torch.int32) + + if config.randomize_num_blocks: + nb = np.random.randint(0, config.num_blocks + 1) + else: + nb = config.num_blocks + + del_block_starts = torch.from_numpy(np.random.randint(0, num_seq, [nb])) + del_blocks = del_block_starts[:, None] + torch.arange(0, block_num_seq) + del_blocks = torch.clip(del_blocks, 0, num_seq - 1) + del_indices = torch.unique(del_blocks.view(-1)) + # add zeros to ensure cnt_zero > 1 + combined = torch.hstack((torch.arange(0, num_seq)[None], del_indices[None], + torch.zeros(2)[None])).long() + uniques, counts = combined.unique(return_counts=True) + difference = uniques[counts == 1] + # intersection = uniques[counts > 1] + keep_indices = difference.view(-1) + keep_indices = torch.hstack( + [torch.zeros(1).long()[None], keep_indices[None]]).view(-1) + assert int(keep_indices[0]) == 0 + for k in MSA_FEATURE_NAMES: + if k in protein: + protein[k] = torch.index_select(protein[k], 0, index=keep_indices) + return protein + + +@curry1 +def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0): + weights = torch.cat( + [torch.ones(21), gap_agreement_weight * torch.ones(1), + torch.zeros(1)], + 0, + ) + + msa_one_hot = one_hot(protein['msa'], 23) + sample_one_hot = protein['msa_mask'][:, :, None] * msa_one_hot + extra_msa_one_hot = one_hot(protein['extra_msa'], 23) + extra_one_hot = protein['extra_msa_mask'][:, :, None] * extra_msa_one_hot + + num_seq, num_res, _ = sample_one_hot.shape + extra_num_seq, _, _ = extra_one_hot.shape + + # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights) + # in an optimized fashion to avoid possible memory or computation blowup. 
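# --- Editor's aside (illustrative check, not part of this patch) ---
# The reshape/matmul that follows computes the same agreement scores as the
# einsum mentioned in the comment above, with the output axes swapped
# (i.e. 'erc,mrc,c->em'). A self-contained check with random stand-in tensors:
import torch
E, M, R, C = 5, 3, 7, 23
extra_one_hot = torch.rand(E, R, C)
sample_one_hot = torch.rand(M, R, C)
weights = torch.rand(C)
a = extra_one_hot.reshape(E, R * C)
b = (sample_one_hot * weights).reshape(M, R * C).transpose(0, 1)
reference = torch.einsum('erc,mrc,c->em', extra_one_hot, sample_one_hot, weights)
assert torch.allclose(a @ b, reference, atol=1e-5)
# --- end of editor's aside ---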
+ a = extra_one_hot.view(extra_num_seq, num_res * 23) + b = (sample_one_hot * weights).view(num_seq, num_res * 23).transpose(0, 1) + agreement = a @ b + # Assign each sequence in the extra sequences to the closest MSA sample + protein['extra_cluster_assignment'] = torch.argmax(agreement, dim=1).long() + + return protein + + +def unsorted_segment_sum(data, segment_ids, num_segments): + assert len( + segment_ids.shape) == 1 and segment_ids.shape[0] == data.shape[0] + segment_ids = segment_ids.view(segment_ids.shape[0], + *((1, ) * len(data.shape[1:]))) + segment_ids = segment_ids.expand(data.shape) + shape = [num_segments] + list(data.shape[1:]) + tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float()) + tensor = tensor.type(data.dtype) + return tensor + + +def summarize_clusters(protein): + """Produce profile and deletion_matrix_mean within each cluster.""" + num_seq = protein['msa'].shape[0] + + def csum(x): + return unsorted_segment_sum(x, protein['extra_cluster_assignment'], + num_seq) + + mask = protein['extra_msa_mask'] + mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center + + # TODO: this line is very slow + msa_sum = csum(mask[:, :, None] * one_hot(protein['extra_msa'], 23)) + msa_sum += one_hot(protein['msa'], 23) # Original sequence + protein['cluster_profile'] = msa_sum / mask_counts[:, :, None] + del msa_sum + + del_sum = csum(mask * protein['extra_deletion_matrix']) + del_sum += protein['deletion_matrix'] # Original sequence + protein['cluster_deletion_mean'] = del_sum / mask_counts + del del_sum + + return protein + + +@curry1 +def nearest_neighbor_clusters_v2(batch, gap_agreement_weight=0.0): + """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" + + # Determine how much weight we assign to each agreement. In theory, we could + # use a full blosum matrix here, but right now let's just down-weight gap + # agreement because it could be spurious. + # Never put weight on agreeing on BERT mask. + + weights = torch.tensor( + [1.0] * 21 + [gap_agreement_weight] + [0.0], dtype=torch.float32) + + msa_mask = batch['msa_mask'] + extra_mask = batch['extra_msa_mask'] + msa_one_hot = one_hot(batch['msa'], 23) + extra_one_hot = one_hot(batch['extra_msa'], 23) + + msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot + extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot + + t1 = weights * msa_one_hot_masked + t1 = t1.view(t1.shape[0], t1.shape[1] * t1.shape[2]) + t2 = extra_one_hot_masked.view( + extra_one_hot.shape[0], + extra_one_hot.shape[1] * extra_one_hot.shape[2]) + agreement = t1 @ t2.T + + cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0) + cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask) + + cluster_count = torch.sum(cluster_assignment, dim=-1) + cluster_count += 1.0 # We always include the sequence itself. + + msa_sum = torch.einsum('nm, mrc->nrc', cluster_assignment, + extra_one_hot_masked) + msa_sum += msa_one_hot_masked + + cluster_profile = msa_sum / cluster_count[:, None, None] + + deletion_matrix = batch['deletion_matrix'] + extra_deletion_matrix = batch['extra_deletion_matrix'] + + del_sum = torch.einsum('nm, mc->nc', cluster_assignment, + extra_mask * extra_deletion_matrix) + del_sum += deletion_matrix # Original sequence. 
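# --- Editor's aside (illustrative toy check, not part of this patch) ---
# summarize_clusters above leans on unsorted_segment_sum; its contract is that
# row i of `data` is accumulated into output row segment_ids[i]. A minimal
# check mirroring the scatter_add_-based implementation:
import torch
data = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
segment_ids = torch.tensor([0, 1, 0])
ids = segment_ids.view(-1, 1).expand_as(data)
out = torch.zeros(2, 2).scatter_add_(0, ids, data)
assert torch.equal(out, torch.tensor([[6., 8.], [3., 4.]]))
# --- end of editor's aside ---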
+ cluster_deletion_mean = del_sum / cluster_count[:, None] + batch['cluster_profile'] = cluster_profile + batch['cluster_deletion_mean'] = cluster_deletion_mean + + return batch + + +def make_msa_mask(protein): + """Mask features are all ones, but will later be zero-padded.""" + if 'msa_mask' not in protein: + protein['msa_mask'] = torch.ones( + protein['msa'].shape, dtype=torch.float32) + protein['msa_row_mask'] = torch.ones((protein['msa'].shape[0]), + dtype=torch.float32) + return protein + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask): + """Create pseudo beta features.""" + if aatype.shape[0] > 0: + is_gly = torch.eq(aatype, rc.restype_order['G']) + ca_idx = rc.atom_order['CA'] + cb_idx = rc.atom_order['CB'] + pseudo_beta = torch.where( + torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + else: + pseudo_beta = all_atom_positions.new_zeros(*aatype.shape, 3) + if all_atom_mask is not None: + if aatype.shape[0] > 0: + pseudo_beta_mask = torch.where(is_gly, all_atom_mask[..., ca_idx], + all_atom_mask[..., cb_idx]) + else: + pseudo_beta_mask = torch.zeros_like(aatype).float() + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +@curry1 +def make_pseudo_beta(protein, prefix=''): + """Create pseudo-beta (alpha for glycine) position and mask.""" + assert prefix in ['', 'template_'] + ( + protein[prefix + 'pseudo_beta'], + protein[prefix + 'pseudo_beta_mask'], + ) = pseudo_beta_fn( + protein['template_aatype' if prefix else 'aatype'], + protein[prefix + 'all_atom_positions'], + protein['template_all_atom_mask' if prefix else 'all_atom_mask'], + ) + return protein + + +@curry1 +def add_constant_field(protein, key, value): + protein[key] = torch.tensor(value) + return protein + + +def shaped_categorical(probs, epsilon=1e-10): + ds = probs.shape + num_classes = ds[-1] + probs = torch.reshape(probs + epsilon, [-1, num_classes]) + gen = torch.Generator() + gen.manual_seed(np.random.randint(65535)) + counts = torch.multinomial(probs, 1, generator=gen) + return torch.reshape(counts, ds[:-1]) + + +def make_hhblits_profile(protein): + """Compute the HHblits MSA profile if not already present.""" + if 'hhblits_profile' in protein: + return protein + + # Compute the profile for every residue (over all MSA sequences). + msa_one_hot = one_hot(protein['msa'], 22) + + protein['hhblits_profile'] = torch.mean(msa_one_hot, dim=0) + return protein + + +def make_msa_profile(batch): + """Compute the MSA profile.""" + # Compute the profile for every residue (over all MSA sequences). 
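# --- Editor's aside (illustrative sketch, not part of this patch) ---
# The profile computed below is the per-column distribution over residue types,
# weighted by msa_mask (make_hhblits_profile above uses a plain mean). Toy
# example of the unmasked case, with torch.nn.functional.one_hot standing in
# for unicore's one_hot:
import torch
import torch.nn.functional as F
toy_msa = torch.tensor([[0], [7]])                        # 2 sequences, 1 column
toy_profile = F.one_hot(toy_msa, 22).float().mean(dim=0)  # shape (1, 22)
assert toy_profile[0, 0].item() == 0.5
assert toy_profile[0, 7].item() == 0.5
# --- end of editor's aside ---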
+ oh = one_hot(batch['msa'], 22) + mask = batch['msa_mask'][:, :, None] + oh *= mask + return oh.sum(dim=0) / (mask.sum(dim=0) + 1e-10) + + +def make_hhblits_profile_v2(protein): + """Compute the HHblits MSA profile if not already present.""" + if 'hhblits_profile' in protein: + return protein + protein['hhblits_profile'] = make_msa_profile(protein) + return protein + + +def share_mask_by_entity(mask_position, protein): # new in unifold + if 'num_sym' not in protein: + return mask_position + entity_id = protein['entity_id'] + sym_id = protein['sym_id'] + num_sym = protein['num_sym'] + unique_entity_ids = entity_id.unique() + first_sym_mask = sym_id == 1 + for cur_entity_id in unique_entity_ids: + cur_entity_mask = entity_id == cur_entity_id + cur_num_sym = int(num_sym[cur_entity_mask][0]) + if cur_num_sym > 1: + cur_sym_mask = first_sym_mask & cur_entity_mask + cur_sym_bert_mask = mask_position[:, cur_sym_mask] + mask_position[:, cur_entity_mask] = cur_sym_bert_mask.repeat( + 1, cur_num_sym) + return mask_position + + +@curry1 +def make_masked_msa(protein, + config, + replace_fraction, + gumbel_sample=False, + share_mask=False): + """Create data for BERT on raw MSA.""" + # Add a random amino acid uniformly. + random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32) + + categorical_probs = ( + config.uniform_prob * random_aa + + config.profile_prob * protein['hhblits_profile'] + + config.same_prob * one_hot(protein['msa'], 22)) + + # Put all remaining probability on [MASK] which is a new column + pad_shapes = list( + reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])) + pad_shapes[1] = 1 + mask_prob = 1.0 - config.profile_prob - config.same_prob - config.uniform_prob + assert mask_prob >= 0.0 + categorical_probs = torch.nn.functional.pad( + categorical_probs, pad_shapes, value=mask_prob) + sh = protein['msa'].shape + mask_position = torch.from_numpy(np.random.rand(*sh) < replace_fraction) + mask_position &= protein['msa_mask'].bool() + + if 'bert_mask' in protein: + mask_position &= protein['bert_mask'].bool() + + if share_mask: + mask_position = share_mask_by_entity(mask_position, protein) + if gumbel_sample: + logits = torch.log(categorical_probs + 1e-6) + bert_msa = gumbel_max_sample(logits) + else: + bert_msa = shaped_categorical(categorical_probs) + bert_msa = torch.where(mask_position, bert_msa, protein['msa']) + bert_msa *= protein['msa_mask'].long() + + # Mix real and masked MSA + protein['bert_mask'] = mask_position.to(torch.float32) + protein['true_msa'] = protein['msa'] + protein['msa'] = bert_msa + + return protein + + +@curry1 +def make_fixed_size( + protein, + shape_schema, + msa_cluster_size, + extra_msa_size, + num_res=0, + num_templates=0, +): + """Guess at the MSA and sequence dimension to make fixed size.""" + + def get_pad_size(cur_size, multiplier=4): + return max(multiplier, + ((cur_size + multiplier - 1) // multiplier) * multiplier) + + if num_res is not None: + input_num_res = ( + protein['aatype'].shape[0] + if 'aatype' in protein else protein['msa_mask'].shape[1]) + if input_num_res != num_res: + num_res = get_pad_size(input_num_res, 4) + if 'extra_msa_mask' in protein: + input_extra_msa_size = protein['extra_msa_mask'].shape[0] + if input_extra_msa_size != extra_msa_size: + extra_msa_size = get_pad_size(input_extra_msa_size, 8) + pad_size_map = { + N_RES: num_res, + N_MSA: msa_cluster_size, + N_EXTRA_MSA: extra_msa_size, + N_TPL: num_templates, + } + + for k, v in protein.items(): + # Don't transfer this to the accelerator. 
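# --- Editor's aside (illustrative worked example, not part of this patch) ---
# get_pad_size above rounds a dimension up to the next multiple of `multiplier`
# (with `multiplier` as the floor), so padded shapes fall into a small set of
# buckets. Mirrored here only to show a few concrete values:
def _pad_size(cur_size, multiplier=4):
    return max(multiplier, ((cur_size + multiplier - 1) // multiplier) * multiplier)

assert _pad_size(253, 4) == 256    # rounded up to the next multiple of 4
assert _pad_size(3, 4) == 4        # never smaller than the multiplier
assert _pad_size(1000, 8) == 1000  # already a multiple: unchanged
# --- end of editor's aside ---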
+ if k == 'extra_cluster_assignment': + continue + shape = list(v.shape) + schema = shape_schema[k] + msg = 'Rank mismatch between shape and shape schema for' + assert len(shape) == len(schema), f'{msg} {k}: {shape} vs {schema}' + pad_size = [ + pad_size_map.get(s2, None) or s1 + for (s1, s2) in zip(shape, schema) + ] + + padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] + padding.reverse() + padding = list(itertools.chain(*padding)) + if padding: + protein[k] = torch.nn.functional.pad(v, padding) + protein[k] = torch.reshape(protein[k], pad_size) + + return protein + + +def make_target_feat(protein): + """Create and concatenate MSA features.""" + protein['aatype'] = protein['aatype'].long() + + if 'between_segment_residues' in protein: + has_break = torch.clip( + protein['between_segment_residues'].to(torch.float32), 0, 1) + else: + has_break = torch.zeros_like(protein['aatype'], dtype=torch.float32) + if 'asym_len' in protein: + asym_len = protein['asym_len'] + entity_ends = torch.cumsum(asym_len, dim=-1)[:-1] + has_break[entity_ends] = 1.0 + has_break = has_break.float() + aatype_1hot = one_hot(protein['aatype'], 21) + target_feat = [ + torch.unsqueeze(has_break, dim=-1), + aatype_1hot, # Everyone gets the original sequence. + ] + protein['target_feat'] = torch.cat(target_feat, dim=-1) + return protein + + +def make_msa_feat(protein): + """Create and concatenate MSA features.""" + msa_1hot = one_hot(protein['msa'], 23) + has_deletion = torch.clip(protein['deletion_matrix'], 0.0, 1.0) + deletion_value = torch.atan( + protein['deletion_matrix'] / 3.0) * (2.0 / np.pi) + msa_feat = [ + msa_1hot, + torch.unsqueeze(has_deletion, dim=-1), + torch.unsqueeze(deletion_value, dim=-1), + ] + if 'cluster_profile' in protein: + deletion_mean_value = torch.atan( + protein['cluster_deletion_mean'] / 3.0) * (2.0 / np.pi) + msa_feat.extend([ + protein['cluster_profile'], + torch.unsqueeze(deletion_mean_value, dim=-1), + ]) + + if 'extra_deletion_matrix' in protein: + protein['extra_msa_has_deletion'] = torch.clip( + protein['extra_deletion_matrix'], 0.0, 1.0) + protein['extra_msa_deletion_value'] = torch.atan( + protein['extra_deletion_matrix'] / 3.0) * (2.0 / np.pi) + + protein['msa_feat'] = torch.cat(msa_feat, dim=-1) + return protein + + +def make_msa_feat_v2(batch): + """Create and concatenate MSA features.""" + msa_1hot = one_hot(batch['msa'], 23) + deletion_matrix = batch['deletion_matrix'] + has_deletion = torch.clip(deletion_matrix, 0.0, 1.0)[..., None] + deletion_value = (torch.atan(deletion_matrix / 3.0) * (2.0 / np.pi))[..., + None] + + deletion_mean_value = ( + torch.arctan(batch['cluster_deletion_mean'] / 3.0) * # noqa W504 + (2.0 / np.pi))[..., None] + + msa_feat = [ + msa_1hot, + has_deletion, + deletion_value, + batch['cluster_profile'], + deletion_mean_value, + ] + batch['msa_feat'] = torch.concat(msa_feat, dim=-1) + return batch + + +@curry1 +def make_extra_msa_feat(batch, num_extra_msa): + # 23 = 20 amino acids + 'X' for unknown + gap + bert mask + extra_msa = batch['extra_msa'][:num_extra_msa] + deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa] + has_deletion = torch.clip(deletion_matrix, 0.0, 1.0) + deletion_value = torch.atan(deletion_matrix / 3.0) * (2.0 / np.pi) + extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa] + batch['extra_msa'] = extra_msa + batch['extra_msa_mask'] = extra_msa_mask + batch['extra_msa_has_deletion'] = has_deletion + batch['extra_msa_deletion_value'] = deletion_value + return batch + + +@curry1 +def select_feat(protein, 
feature_list): + return {k: v for k, v in protein.items() if k in feature_list} + + +def make_atom14_masks(protein): + """Construct denser atom positions (14 dimensions instead of 37).""" + + if 'atom14_atom_exists' in protein: # lazy move + return protein + + restype_atom14_to_atom37 = torch.tensor( + rc.restype_atom14_to_atom37, + dtype=torch.int64, + device=protein['aatype'].device, + ) + restype_atom37_to_atom14 = torch.tensor( + rc.restype_atom37_to_atom14, + dtype=torch.int64, + device=protein['aatype'].device, + ) + restype_atom14_mask = torch.tensor( + rc.restype_atom14_mask, + dtype=torch.float32, + device=protein['aatype'].device, + ) + restype_atom37_mask = torch.tensor( + rc.restype_atom37_mask, + dtype=torch.float32, + device=protein['aatype'].device) + + protein_aatype = protein['aatype'].long() + protein['residx_atom14_to_atom37'] = restype_atom14_to_atom37[ + protein_aatype].long() + protein['residx_atom37_to_atom14'] = restype_atom37_to_atom14[ + protein_aatype].long() + protein['atom14_atom_exists'] = restype_atom14_mask[protein_aatype] + protein['atom37_atom_exists'] = restype_atom37_mask[protein_aatype] + + return protein + + +def make_atom14_masks_np(batch): + batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray) + out = make_atom14_masks(batch) + out = tensor_tree_map(lambda t: np.array(t), out) + return out + + +def make_atom14_positions(protein): + """Constructs denser atom positions (14 dimensions instead of 37).""" + protein['aatype'] = protein['aatype'].long() + protein['all_atom_mask'] = protein['all_atom_mask'].float() + protein['all_atom_positions'] = protein['all_atom_positions'].float() + residx_atom14_mask = protein['atom14_atom_exists'] + residx_atom14_to_atom37 = protein['residx_atom14_to_atom37'] + + # Create a mask for known ground truth positions. + residx_atom14_gt_mask = residx_atom14_mask * batched_gather( + protein['all_atom_mask'], + residx_atom14_to_atom37, + dim=-1, + num_batch_dims=len(protein['all_atom_mask'].shape[:-1]), + ) + + # Gather the ground truth positions. + residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * ( + batched_gather( + protein['all_atom_positions'], + residx_atom14_to_atom37, + dim=-2, + num_batch_dims=len(protein['all_atom_positions'].shape[:-2]), + )) + + protein['atom14_atom_exists'] = residx_atom14_mask + protein['atom14_gt_exists'] = residx_atom14_gt_mask + protein['atom14_gt_positions'] = residx_atom14_gt_positions + + renaming_matrices = torch.tensor( + rc.renaming_matrices, + dtype=protein['all_atom_mask'].dtype, + device=protein['all_atom_mask'].device, + ) + + # Pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14). + renaming_transform = renaming_matrices[protein['aatype']] + + # Apply it to the ground truth positions. shape (num_res, 14, 3). + alternative_gt_positions = torch.einsum('...rac,...rab->...rbc', + residx_atom14_gt_positions, + renaming_transform) + protein['atom14_alt_gt_positions'] = alternative_gt_positions + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position). 
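# --- Editor's aside (illustrative toy check, not part of this patch) ---
# The renaming_transform used in the einsums above and below is, per residue,
# a 14x14 permutation matrix (identity except that the two atoms of an
# ambiguous pair are swapped), so applying it just exchanges those atoms'
# entries. Toy 3-"atom" version swapping the first two entries:
import torch
swap = torch.tensor([[0., 1., 0.],
                     [1., 0., 0.],
                     [0., 0., 1.]])
mask = torch.tensor([1., 0., 1.])
assert torch.equal(torch.einsum('a,ab->b', mask, swap), torch.tensor([0., 1., 1.]))
# --- end of editor's aside ---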
+ alternative_gt_mask = torch.einsum('...ra,...rab->...rb', + residx_atom14_gt_mask, + renaming_transform) + protein['atom14_alt_gt_exists'] = alternative_gt_mask + + restype_atom14_is_ambiguous = torch.tensor( + rc.restype_atom14_is_ambiguous, + dtype=protein['all_atom_mask'].dtype, + device=protein['all_atom_mask'].device, + ) + # From this create an ambiguous_mask for the given sequence. + protein['atom14_atom_is_ambiguous'] = restype_atom14_is_ambiguous[ + protein['aatype']] + + return protein + + +def atom37_to_frames(protein, eps=1e-8): + # TODO: extract common part and put them into residue constants. + aatype = protein['aatype'] + all_atom_positions = protein['all_atom_positions'] + all_atom_mask = protein['all_atom_mask'] + + batch_dims = len(aatype.shape[:-1]) + + restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object) + restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N'] + restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O'] + + for restype, restype_letter in enumerate(rc.restypes): + resname = rc.restype_1to3[restype_letter] + for chi_idx in range(4): + if rc.chi_angles_mask[restype][chi_idx]: + names = rc.chi_angles_atoms[resname][chi_idx] + restype_rigidgroup_base_atom_names[restype, + chi_idx + 4, :] = names[1:] + + restype_rigidgroup_mask = all_atom_mask.new_zeros( + (*aatype.shape[:-1], 21, 8), ) + restype_rigidgroup_mask[..., 0] = 1 + restype_rigidgroup_mask[..., 3] = 1 + restype_rigidgroup_mask[..., :20, + 4:] = all_atom_mask.new_tensor(rc.chi_angles_mask) + + lookuptable = rc.atom_order.copy() + lookuptable[''] = 0 + lookup = np.vectorize(lambda x: lookuptable[x]) + restype_rigidgroup_base_atom37_idx = lookup( + restype_rigidgroup_base_atom_names, ) + restype_rigidgroup_base_atom37_idx = aatype.new_tensor( + restype_rigidgroup_base_atom37_idx, ) + restype_rigidgroup_base_atom37_idx = restype_rigidgroup_base_atom37_idx.view( + *((1, ) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape) + + residx_rigidgroup_base_atom37_idx = batched_gather( + restype_rigidgroup_base_atom37_idx, + aatype, + dim=-3, + num_batch_dims=batch_dims, + ) + + base_atom_pos = batched_gather( + all_atom_positions, + residx_rigidgroup_base_atom37_idx, + dim=-2, + num_batch_dims=len(all_atom_positions.shape[:-2]), + ) + + gt_frames = Frame.from_3_points( + p_neg_x_axis=base_atom_pos[..., 0, :], + origin=base_atom_pos[..., 1, :], + p_xy_plane=base_atom_pos[..., 2, :], + eps=eps, + ) + + group_exists = batched_gather( + restype_rigidgroup_mask, + aatype, + dim=-2, + num_batch_dims=batch_dims, + ) + + gt_atoms_exist = batched_gather( + all_atom_mask, + residx_rigidgroup_base_atom37_idx, + dim=-1, + num_batch_dims=len(all_atom_mask.shape[:-1]), + ) + gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists + + rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device) + rots = torch.tile(rots, (*((1, ) * batch_dims), 8, 1, 1)) + rots[..., 0, 0, 0] = -1 + rots[..., 0, 2, 2] = -1 + rots = Rotation(mat=rots) + + gt_frames = gt_frames.compose(Frame(rots, None)) + + restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( + *((1, ) * batch_dims), 21, 8) + restype_rigidgroup_rots = torch.eye( + 3, dtype=all_atom_mask.dtype, device=aatype.device) + restype_rigidgroup_rots = torch.tile( + restype_rigidgroup_rots, + (*((1, ) * batch_dims), 21, 8, 1, 1), + ) + + for resname, _ in rc.residue_atom_renaming_swaps.items(): + restype = rc.restype_order[rc.restype_3to1[resname]] + chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1) + 
restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1 + + residx_rigidgroup_is_ambiguous = batched_gather( + restype_rigidgroup_is_ambiguous, + aatype, + dim=-2, + num_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = batched_gather( + restype_rigidgroup_rots, + aatype, + dim=-4, + num_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = Rotation( + mat=residx_rigidgroup_ambiguity_rot) + alt_gt_frames = gt_frames.compose( + Frame(residx_rigidgroup_ambiguity_rot, None)) + + gt_frames_tensor = gt_frames.to_tensor_4x4() + alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4() + + protein['rigidgroups_gt_frames'] = gt_frames_tensor + protein['rigidgroups_gt_exists'] = gt_exists + protein['rigidgroups_group_exists'] = group_exists + protein['rigidgroups_group_is_ambiguous'] = residx_rigidgroup_is_ambiguous + protein['rigidgroups_alt_gt_frames'] = alt_gt_frames_tensor + + return protein + + +@curry1 +def atom37_to_torsion_angles( + protein, + prefix='', +): + aatype = protein[prefix + 'aatype'] + all_atom_positions = protein[prefix + 'all_atom_positions'] + all_atom_mask = protein[prefix + 'all_atom_mask'] + if aatype.shape[-1] == 0: + base_shape = aatype.shape + protein[prefix + + 'torsion_angles_sin_cos'] = all_atom_positions.new_zeros( + *base_shape, 7, 2) + protein[prefix + + 'alt_torsion_angles_sin_cos'] = all_atom_positions.new_zeros( + *base_shape, 7, 2) + protein[prefix + 'torsion_angles_mask'] = all_atom_positions.new_zeros( + *base_shape, 7) + return protein + + aatype = torch.clamp(aatype, max=20) + + pad = all_atom_positions.new_zeros( + [*all_atom_positions.shape[:-3], 1, 37, 3]) + prev_all_atom_positions = torch.cat( + [pad, all_atom_positions[..., :-1, :, :]], dim=-3) + + pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37]) + prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2) + + pre_omega_atom_pos = torch.cat( + [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]], + dim=-2, + ) + phi_atom_pos = torch.cat( + [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]], + dim=-2, + ) + psi_atom_pos = torch.cat( + [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]], + dim=-2, + ) + + pre_omega_mask = torch.prod( + prev_all_atom_mask[..., 1:3], dim=-1) * torch.prod( + all_atom_mask[..., :2], dim=-1) + phi_mask = prev_all_atom_mask[..., 2] * torch.prod( + all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) + psi_mask = ( + torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) + * all_atom_mask[..., 4]) + + chi_atom_indices = torch.as_tensor( + rc.chi_atom_indices, device=aatype.device) + + atom_indices = chi_atom_indices[..., aatype, :, :] + chis_atom_pos = batched_gather(all_atom_positions, atom_indices, -2, + len(atom_indices.shape[:-2])) + + chi_angles_mask = list(rc.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask) + + chis_mask = chi_angles_mask[aatype, :] + + chi_angle_atoms_mask = batched_gather( + all_atom_mask, + atom_indices, + dim=-1, + num_batch_dims=len(atom_indices.shape[:-2]), + ) + chi_angle_atoms_mask = torch.prod( + chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype) + chis_mask = chis_mask * chi_angle_atoms_mask + + torsions_atom_pos = torch.cat( + [ + pre_omega_atom_pos[..., None, :, :], + phi_atom_pos[..., None, :, :], 
+ psi_atom_pos[..., None, :, :], + chis_atom_pos, + ], + dim=-3, + ) + + torsion_angles_mask = torch.cat( + [ + pre_omega_mask[..., None], + phi_mask[..., None], + psi_mask[..., None], + chis_mask, + ], + dim=-1, + ) + + torsion_frames = Frame.from_3_points( + torsions_atom_pos[..., 1, :], + torsions_atom_pos[..., 2, :], + torsions_atom_pos[..., 0, :], + eps=1e-8, + ) + + fourth_atom_rel_pos = torsion_frames.invert().apply( + torsions_atom_pos[..., 3, :]) + + torsion_angles_sin_cos = torch.stack( + [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1) + + denom = torch.sqrt( + torch.sum( + torch.square(torsion_angles_sin_cos), + dim=-1, + dtype=torsion_angles_sin_cos.dtype, + keepdims=True, + ) + 1e-8) + torsion_angles_sin_cos = torsion_angles_sin_cos / denom + + torsion_angles_sin_cos = ( + torsion_angles_sin_cos + * all_atom_mask.new_tensor([1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0], )[ + ((None, ) * len(torsion_angles_sin_cos.shape[:-2])) + + (slice(None), None)]) + + chi_is_ambiguous = torsion_angles_sin_cos.new_tensor( + rc.chi_pi_periodic, )[aatype, ...] + + mirror_torsion_angles = torch.cat( + [ + all_atom_mask.new_ones(*aatype.shape, 3), + 1.0 - 2.0 * chi_is_ambiguous, + ], + dim=-1, + ) + + alt_torsion_angles_sin_cos = ( + torsion_angles_sin_cos * mirror_torsion_angles[..., None]) + + if prefix == '': + # consistent to uni-fold. use [1, 0] placeholder + placeholder_torsions = torch.stack( + [ + torch.ones(torsion_angles_sin_cos.shape[:-1]), + torch.zeros(torsion_angles_sin_cos.shape[:-1]), + ], + dim=-1, + ) + torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[ + ..., + None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[ + ..., + None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + + protein[prefix + 'torsion_angles_sin_cos'] = torsion_angles_sin_cos + protein[prefix + 'alt_torsion_angles_sin_cos'] = alt_torsion_angles_sin_cos + protein[prefix + 'torsion_angles_mask'] = torsion_angles_mask + + return protein + + +def get_backbone_frames(protein): + protein['true_frame_tensor'] = protein['rigidgroups_gt_frames'][..., + 0, :, :] + protein['frame_mask'] = protein['rigidgroups_gt_exists'][..., 0] + + return protein + + +def get_chi_angles(protein): + dtype = protein['all_atom_mask'].dtype + protein['chi_angles_sin_cos'] = ( + protein['torsion_angles_sin_cos'][..., 3:, :]).to(dtype) + protein['chi_mask'] = protein['torsion_angles_mask'][..., 3:].to(dtype) + + return protein + + +@curry1 +def crop_templates( + protein, + max_templates, + subsample_templates=False, +): + if 'template_mask' in protein: + num_templates = protein['template_mask'].shape[-1] + else: + num_templates = 0 + + # don't sample when there are no templates + if num_templates > 0: + if subsample_templates: + # af2's sampling, min(4, uniform[0, n]) + max_templates = min(max_templates, + np.random.randint(0, num_templates + 1)) + template_idx = torch.tensor( + np.random.choice(num_templates, max_templates, replace=False), + dtype=torch.int64, + ) + else: + # use top templates + template_idx = torch.arange( + min(num_templates, max_templates), dtype=torch.int64) + for k, v in protein.items(): + if k.startswith('template'): + try: + v = v[template_idx] + except Exception as ex: + print(ex.__class__, ex) + print('num_templates', num_templates) + print(k, v.shape) + print('protein:', protein) + print( + 'protein_shape:', + { + k: v.shape + for k, v in protein.items() if 'shape' in 
dir(v) + }, + ) + protein[k] = v + + return protein + + +@curry1 +def crop_to_size_single(protein, crop_size, shape_schema, seed): + """crop to size.""" + num_res = ( + protein['aatype'].shape[0] + if 'aatype' in protein else protein['msa_mask'].shape[1]) + crop_idx = get_single_crop_idx(num_res, crop_size, seed) + protein = apply_crop_idx(protein, shape_schema, crop_idx) + return protein + + +@curry1 +def crop_to_size_multimer(protein, crop_size, shape_schema, seed, + spatial_crop_prob, ca_ca_threshold): + """crop to size.""" + with data_utils.numpy_seed(seed, key='multimer_crop'): + use_spatial_crop = np.random.rand() < spatial_crop_prob + is_distillation = 'is_distillation' in protein and protein[ + 'is_distillation'] == 1 + if is_distillation: + return crop_to_size_single( + crop_size=crop_size, shape_schema=shape_schema, seed=seed)( + protein) + elif use_spatial_crop: + crop_idx = get_spatial_crop_idx(protein, crop_size, seed, + ca_ca_threshold) + else: + crop_idx = get_contiguous_crop_idx(protein, crop_size, seed) + return apply_crop_idx(protein, shape_schema, crop_idx) + + +def get_single_crop_idx(num_res: NumpyDict, crop_size: int, + random_seed: Optional[int]) -> torch.Tensor: + + if num_res < crop_size: + return torch.arange(num_res) + with data_utils.numpy_seed(random_seed): + crop_start = int(np.random.randint(0, num_res - crop_size + 1)) + return torch.arange(crop_start, crop_start + crop_size) + + +def get_crop_sizes_each_chain( + asym_len: torch.Tensor, + crop_size: int, + random_seed: Optional[int] = None, + use_multinomial: bool = False, +) -> torch.Tensor: + """get crop sizes for contiguous crop""" + if not use_multinomial: + with data_utils.numpy_seed( + random_seed, key='multimer_contiguous_perm'): + shuffle_idx = np.random.permutation(len(asym_len)) + num_left = asym_len.sum() + num_budget = torch.tensor(crop_size) + crop_sizes = [0 for _ in asym_len] + for j, idx in enumerate(shuffle_idx): + this_len = asym_len[idx] + num_left -= this_len + # num res at most we can keep in this ent + max_size = min(num_budget, this_len) + # num res at least we shall keep in this ent + min_size = min(this_len, max(0, num_budget - num_left)) + with data_utils.numpy_seed( + random_seed, j, key='multimer_contiguous_crop_size'): + this_crop_size = int( + np.random.randint( + low=int(min_size), high=int(max_size) + 1)) + num_budget -= this_crop_size + crop_sizes[idx] = this_crop_size + crop_sizes = torch.tensor(crop_sizes) + else: # use multinomial + # TODO: better multimer + entity_probs = asym_len / torch.sum(asym_len) + crop_sizes = torch.from_numpy( + np.random.multinomial(crop_size, pvals=entity_probs)) + crop_sizes = torch.min(crop_sizes, asym_len) + return crop_sizes + + +def get_contiguous_crop_idx( + protein: NumpyDict, + crop_size: int, + random_seed: Optional[int] = None, + use_multinomial: bool = False, +) -> torch.Tensor: + + num_res = protein['aatype'].shape[0] + if num_res <= crop_size: + return torch.arange(num_res) + + assert 'asym_len' in protein + asym_len = protein['asym_len'] + + crop_sizes = get_crop_sizes_each_chain(asym_len, crop_size, random_seed, + use_multinomial) + crop_idxs = [] + asym_offset = torch.tensor(0, dtype=torch.int64) + with data_utils.numpy_seed( + random_seed, key='multimer_contiguous_crop_start_idx'): + for ll, csz in zip(asym_len, crop_sizes): + this_start = np.random.randint(0, int(ll - csz) + 1) + crop_idxs.append( + torch.arange(asym_offset + this_start, + asym_offset + this_start + csz)) + asym_offset += ll + + return torch.concat(crop_idxs) + 
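To make the chain-budget arithmetic in get_crop_sizes_each_chain easier to follow, here is a rough standalone sketch of the same idea in plain NumPy (the helper name and the use of numpy.random.Generator are illustrative, not part of this patch): chains are visited in a shuffled order, and each keeps between min_size and max_size residues, so the total never exceeds the crop budget and the budget is exhausted whenever the complex is long enough.

import numpy as np

def toy_chain_crop_sizes(asym_len, crop_size, rng):
    """Distribute a residue budget of crop_size over chains of length asym_len."""
    sizes = [0] * len(asym_len)
    num_left = int(np.sum(asym_len))   # residues in chains not yet visited
    budget = crop_size
    for idx in rng.permutation(len(asym_len)):
        this_len = int(asym_len[idx])
        num_left -= this_len
        max_size = min(budget, this_len)                     # at most the whole chain
        min_size = min(this_len, max(0, budget - num_left))  # at least what later chains cannot cover
        take = int(rng.integers(min_size, max_size + 1))
        budget -= take
        sizes[idx] = take
    return sizes

rng = np.random.default_rng(0)
sizes = toy_chain_crop_sizes([120, 80, 200], 256, rng)
print(sizes, sum(sizes))  # the sizes always sum to min(crop_size, total length), here 256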
+ +def get_spatial_crop_idx( + protein: NumpyDict, + crop_size: int, + random_seed: int, + ca_ca_threshold: float, + inf: float = 3e4, +) -> List[int]: + + ca_idx = rc.atom_order['CA'] + ca_coords = protein['all_atom_positions'][..., ca_idx, :] + ca_mask = protein['all_atom_mask'][..., ca_idx].bool() + # if there are not enough atoms to construct interface, use contiguous crop + if (ca_mask.sum(dim=-1) <= 1).all(): + return get_contiguous_crop_idx(protein, crop_size, random_seed) + + pair_mask = ca_mask[..., None] * ca_mask[..., None, :] + ca_distances = get_pairwise_distances(ca_coords) + + interface_candidates = get_interface_candidates(ca_distances, + protein['asym_id'], + pair_mask, ca_ca_threshold) + + if torch.any(interface_candidates): + with data_utils.numpy_seed(random_seed, key='multimer_spatial_crop'): + target_res = int(np.random.choice(interface_candidates)) + else: + return get_contiguous_crop_idx(protein, crop_size, random_seed) + + to_target_distances = ca_distances[target_res] + # set inf to non-position residues + to_target_distances[~ca_mask] = inf + break_tie = ( + torch.arange( + 0, + to_target_distances.shape[-1], + device=to_target_distances.device).float() * 1e-3) + to_target_distances += break_tie + ret = torch.argsort(to_target_distances)[:crop_size] + return ret.sort().values + + +def get_pairwise_distances(coords: torch.Tensor) -> torch.Tensor: + coord_diff = coords.unsqueeze(-2) - coords.unsqueeze(-3) + return torch.sqrt(torch.sum(coord_diff**2, dim=-1)) + + +def get_interface_candidates( + ca_distances: torch.Tensor, + asym_id: torch.Tensor, + pair_mask: torch.Tensor, + ca_ca_threshold, +) -> torch.Tensor: + + in_same_asym = asym_id[..., None] == asym_id[..., None, :] + # set distance in the same entity to zero + ca_distances = ca_distances * (1.0 - in_same_asym.float()) * pair_mask + cnt_interfaces = torch.sum( + (ca_distances > 0) & (ca_distances < ca_ca_threshold), dim=-1) + interface_candidates = cnt_interfaces.nonzero(as_tuple=True)[0] + return interface_candidates + + +def apply_crop_idx(protein, shape_schema, crop_idx): + cropped_protein = {} + for k, v in protein.items(): + if k not in shape_schema: # skip items with unknown shape schema + continue + for i, dim_size in enumerate(shape_schema[k]): + if dim_size == N_RES: + v = torch.index_select(v, i, crop_idx) + cropped_protein[k] = v + return cropped_protein diff --git a/modelscope/models/science/unifold/data/msa_pairing.py b/modelscope/models/science/unifold/data/msa_pairing.py new file mode 100644 index 00000000..cc65962c --- /dev/null +++ b/modelscope/models/science/unifold/data/msa_pairing.py @@ -0,0 +1,526 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Pairing logic for multimer data """ + +import collections +from typing import Dict, Iterable, List, Sequence + +import numpy as np +import pandas as pd +import scipy.linalg + +from .data_ops import NumpyDict +from .residue_constants import restypes_with_x_and_gap + +MSA_GAP_IDX = restypes_with_x_and_gap.index('-') +SEQUENCE_GAP_CUTOFF = 0.5 +SEQUENCE_SIMILARITY_CUTOFF = 0.9 + +MSA_PAD_VALUES = { + 'msa_all_seq': MSA_GAP_IDX, + 'msa_mask_all_seq': 1, + 'deletion_matrix_all_seq': 0, + 'deletion_matrix_int_all_seq': 0, + 'msa': MSA_GAP_IDX, + 'msa_mask': 1, + 'deletion_matrix': 0, + 'deletion_matrix_int': 0, +} + +MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int') +SEQ_FEATURES = ( + 'residue_index', + 'aatype', + 'all_atom_positions', + 'all_atom_mask', + 'seq_mask', + 'between_segment_residues', + 'has_alt_locations', + 'has_hetatoms', + 'asym_id', + 'entity_id', + 'sym_id', + 'entity_mask', + 'deletion_mean', + 'prediction_atom_mask', + 'literature_positions', + 'atom_indices_to_group_indices', + 'rigid_group_default_frame', + # zy + 'num_sym', +) +TEMPLATE_FEATURES = ( + 'template_aatype', + 'template_all_atom_positions', + 'template_all_atom_mask', +) +CHAIN_FEATURES = ('num_alignments', 'seq_length') + + +def create_paired_features(chains: Iterable[NumpyDict], ) -> List[NumpyDict]: + """Returns the original chains with paired NUM_SEQ features. + + Args: + chains: A list of feature dictionaries for each chain. + + Returns: + A list of feature dictionaries with sequence features including only + rows to be paired. + """ + chains = list(chains) + chain_keys = chains[0].keys() + + if len(chains) < 2: + return chains + else: + updated_chains = [] + paired_chains_to_paired_row_indices = pair_sequences(chains) + paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices) + + for chain_num, chain in enumerate(chains): + new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k} + for feature_name in chain_keys: + if feature_name.endswith('_all_seq'): + feats_padded = pad_features(chain[feature_name], + feature_name) + new_chain[feature_name] = feats_padded[ + paired_rows[:, chain_num]] + new_chain['num_alignments_all_seq'] = np.asarray( + len(paired_rows[:, chain_num])) + updated_chains.append(new_chain) + return updated_chains + + +def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray: + """Add a 'padding' row at the end of the features list. + + The padding row will be selected as a 'paired' row in the case of partial + alignment - for the chain that doesn't have paired alignment. + + Args: + feature: The feature to be padded. + feature_name: The name of the feature to be padded. + + Returns: + The feature with an additional padding row. 
+ """ + assert feature.dtype != np.dtype(np.string_) + if feature_name in ( + 'msa_all_seq', + 'msa_mask_all_seq', + 'deletion_matrix_all_seq', + 'deletion_matrix_int_all_seq', + ): + num_res = feature.shape[1] + padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res], + feature.dtype) + elif feature_name == 'msa_species_identifiers_all_seq': + padding = [b''] + else: + return feature + feats_padded = np.concatenate([feature, padding], axis=0) + return feats_padded + + +def _make_msa_df(chain_features: NumpyDict) -> pd.DataFrame: + """Makes dataframe with msa features needed for msa pairing.""" + chain_msa = chain_features['msa_all_seq'] + query_seq = chain_msa[0] + per_seq_similarity = np.sum( + query_seq[None] == chain_msa, axis=-1) / float(len(query_seq)) + per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq)) + msa_df = pd.DataFrame({ + 'msa_species_identifiers': + chain_features['msa_species_identifiers_all_seq'], + 'msa_row': + np.arange(len(chain_features['msa_species_identifiers_all_seq'])), + 'msa_similarity': + per_seq_similarity, + 'gap': + per_seq_gap, + }) + return msa_df + + +def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]: + """Creates mapping from species to msa dataframe of that species.""" + species_lookup = {} + for species, species_df in msa_df.groupby('msa_species_identifiers'): + species_lookup[species] = species_df + return species_lookup + + +def _match_rows_by_sequence_similarity( + this_species_msa_dfs: List[pd.DataFrame], ) -> List[List[int]]: # noqa + """Finds MSA sequence pairings across chains based on sequence similarity. + + Each chain's MSA sequences are first sorted by their sequence similarity to + their respective target sequence. The sequences are then paired, starting + from the sequences most similar to their target sequence. + + Args: + this_species_msa_dfs: a list of dataframes containing MSA features for + sequences for a specific species. + + Returns: + A list of lists, each containing M indices corresponding to paired MSA rows, + where M is the number of chains. + """ + all_paired_msa_rows = [] + + num_seqs = [ + len(species_df) for species_df in this_species_msa_dfs + if species_df is not None + ] + take_num_seqs = np.min(num_seqs) + + # sort_by_similarity = lambda x: x.sort_values( + # 'msa_similarity', axis=0, ascending=False) + + def sort_by_similarity(x): + return x.sort_values('msa_similarity', axis=0, ascending=False) + + for species_df in this_species_msa_dfs: + if species_df is not None: + species_df_sorted = sort_by_similarity(species_df) + msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values + else: + msa_rows = [-1] * take_num_seqs # take the last 'padding' row + all_paired_msa_rows.append(msa_rows) + all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose()) + return all_paired_msa_rows + + +def pair_sequences(examples: List[NumpyDict]) -> Dict[int, np.ndarray]: + """Returns indices for paired MSA sequences across chains.""" + + num_examples = len(examples) + + all_chain_species_dict = [] + common_species = set() + for chain_features in examples: + msa_df = _make_msa_df(chain_features) + species_dict = _create_species_dict(msa_df) + all_chain_species_dict.append(species_dict) + common_species.update(set(species_dict)) + + common_species = sorted(common_species) + common_species.remove(b'') # Remove target sequence species. 
+ + all_paired_msa_rows = [np.zeros(len(examples), int)] + all_paired_msa_rows_dict = {k: [] for k in range(num_examples)} + all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)] + + for species in common_species: + if not species: + continue + this_species_msa_dfs = [] + species_dfs_present = 0 + for species_dict in all_chain_species_dict: + if species in species_dict: + this_species_msa_dfs.append(species_dict[species]) + species_dfs_present += 1 + else: + this_species_msa_dfs.append(None) + + # Skip species that are present in only one chain. + if species_dfs_present <= 1: + continue + + if np.any( + np.array([ + len(species_df) for species_df in this_species_msa_dfs + if isinstance(species_df, pd.DataFrame) + ]) > 600): + continue + + paired_msa_rows = _match_rows_by_sequence_similarity( + this_species_msa_dfs) + all_paired_msa_rows.extend(paired_msa_rows) + all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows) + all_paired_msa_rows_dict = { + num_examples: np.array(paired_msa_rows) + for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items() + } + return all_paired_msa_rows_dict + + +def reorder_paired_rows( + all_paired_msa_rows_dict: Dict[int, np.ndarray]) -> np.ndarray: + """Creates a list of indices of paired MSA rows across chains. + + Args: + all_paired_msa_rows_dict: a mapping from the number of paired chains to the + paired indices. + + Returns: + a list of lists, each containing indices of paired MSA rows across chains. + The paired-index lists are ordered by: + 1) the number of chains in the paired alignment, i.e, all-chain pairings + will come first. + 2) e-values + """ + all_paired_msa_rows = [] + + for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True): + paired_rows = all_paired_msa_rows_dict[num_pairings] + paired_rows_product = np.abs( + np.array( + [np.prod(rows.astype(np.float64)) for rows in paired_rows])) + paired_rows_sort_index = np.argsort(paired_rows_product) + all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index]) + + return np.array(all_paired_msa_rows) + + +def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray: + """Like scipy.linalg.block_diag but with an optional padding value.""" + ones_arrs = [np.ones_like(x) for x in arrs] + off_diag_mask = 1 - scipy.linalg.block_diag(*ones_arrs) + diag = scipy.linalg.block_diag(*arrs) + diag += (off_diag_mask * pad_value).astype(diag.dtype) + return diag + + +def _correct_post_merged_feats(np_example: NumpyDict, + np_chains_list: Sequence[NumpyDict], + pair_msa_sequences: bool) -> NumpyDict: + """Adds features that need to be computed/recomputed post merging.""" + + np_example['seq_length'] = np.asarray( + np_example['aatype'].shape[0], dtype=np.int32) + np_example['num_alignments'] = np.asarray( + np_example['msa'].shape[0], dtype=np.int32) + + if not pair_msa_sequences: + # Generate a bias that is 1 for the first row of every block in the + # block diagonal MSA - i.e. make sure the cluster stack always includes + # the query sequences for each chain (since the first row is the query + # sequence). + cluster_bias_masks = [] + for chain in np_chains_list: + mask = np.zeros(chain['msa'].shape[0]) + mask[0] = 1 + cluster_bias_masks.append(mask) + np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks) + + # Initialize Bert mask with masked out off diagonals. 
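+        # Each chain contributes an all-ones mask over its own (num_seq, num_res)
+        # MSA block; block_diag with pad_value=0 zeroes every entry that would
+        # pair rows of one chain with residue columns of another.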
+ msa_masks = [ + np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list + ] + + np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0) + else: + np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0]) + np_example['cluster_bias_mask'][0] = 1 + + # Initialize Bert mask with masked out off diagonals. + msa_masks = [ + np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list + ] + msa_masks_all_seq = [ + np.ones(x['msa_all_seq'].shape, dtype=np.int8) + for x in np_chains_list + ] + + msa_mask_block_diag = block_diag(*msa_masks, pad_value=0) + msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1) + np_example['bert_mask'] = np.concatenate( + [msa_mask_all_seq, msa_mask_block_diag], axis=0) + return np_example + + +def _pad_templates(chains: Sequence[NumpyDict], + max_templates: int) -> Sequence[NumpyDict]: + """For each chain pad the number of templates to a fixed size. + + Args: + chains: A list of protein chains. + max_templates: Each chain will be padded to have this many templates. + + Returns: + The list of chains, updated to have template features padded to + max_templates. + """ + for chain in chains: + for k, v in chain.items(): + if k in TEMPLATE_FEATURES: + padding = np.zeros_like(v.shape) + padding[0] = max_templates - v.shape[0] + padding = [(0, p) for p in padding] + chain[k] = np.pad(v, padding, mode='constant') + return chains + + +def _merge_features_from_multiple_chains( + chains: Sequence[NumpyDict], pair_msa_sequences: bool) -> NumpyDict: + """Merge features from multiple chains. + + Args: + chains: A list of feature dictionaries that we want to merge. + pair_msa_sequences: Whether to concatenate MSA features along the + num_res dimension (if True), or to block diagonalize them (if False). + + Returns: + A feature dictionary for the merged example. + """ + merged_example = {} + for feature_name in chains[0]: + feats = [x[feature_name] for x in chains] + feature_name_split = feature_name.split('_all_seq')[0] + if feature_name_split in MSA_FEATURES: + if pair_msa_sequences or '_all_seq' in feature_name: + merged_example[feature_name] = np.concatenate(feats, axis=1) + if feature_name_split == 'msa': + merged_example['msa_chains_all_seq'] = np.ones( + merged_example[feature_name].shape[0]).reshape(-1, 1) + else: + merged_example[feature_name] = block_diag( + *feats, pad_value=MSA_PAD_VALUES[feature_name]) + if feature_name_split == 'msa': + msa_chains = [] + for i, feat in enumerate(feats): + cur_shape = feat.shape[0] + vals = np.ones(cur_shape) * (i + 2) + msa_chains.append(vals) + merged_example['msa_chains'] = np.concatenate( + msa_chains).reshape(-1, 1) + elif feature_name_split in SEQ_FEATURES: + merged_example[feature_name] = np.concatenate(feats, axis=0) + elif feature_name_split in TEMPLATE_FEATURES: + merged_example[feature_name] = np.concatenate(feats, axis=1) + elif feature_name_split in CHAIN_FEATURES: + merged_example[feature_name] = np.sum(feats).astype(np.int32) + else: + merged_example[feature_name] = feats[0] + return merged_example + + +def _merge_homomers_dense_msa( + chains: Iterable[NumpyDict]) -> Sequence[NumpyDict]: + """Merge all identical chains, making the resulting MSA dense. + + Args: + chains: An iterable of features for each chain. + + Returns: + A list of feature dictionaries. All features with the same entity_id + will be merged - MSA features will be concatenated along the num_res + dimension - making them dense. 
+ """ + entity_chains = collections.defaultdict(list) + for chain in chains: + entity_id = chain['entity_id'][0] + entity_chains[entity_id].append(chain) + + grouped_chains = [] + for entity_id in sorted(entity_chains): + chains = entity_chains[entity_id] + grouped_chains.append(chains) + chains = [ + _merge_features_from_multiple_chains(chains, pair_msa_sequences=True) + for chains in grouped_chains + ] + return chains + + +def _concatenate_paired_and_unpaired_features(example: NumpyDict) -> NumpyDict: + """Merges paired and block-diagonalised features.""" + features = MSA_FEATURES + ('msa_chains', ) + for feature_name in features: + if feature_name in example: + feat = example[feature_name] + feat_all_seq = example[feature_name + '_all_seq'] + try: + merged_feat = np.concatenate([feat_all_seq, feat], axis=0) + except Exception as ex: + raise Exception( + 'concat failed.', + feature_name, + feat_all_seq.shape, + feat.shape, + ex.__class__, + ex, + ) + example[feature_name] = merged_feat + example['num_alignments'] = np.array( + example['msa'].shape[0], dtype=np.int32) + return example + + +def merge_chain_features(np_chains_list: List[NumpyDict], + pair_msa_sequences: bool, + max_templates: int) -> NumpyDict: + """Merges features for multiple chains to single FeatureDict. + + Args: + np_chains_list: List of FeatureDicts for each chain. + pair_msa_sequences: Whether to merge paired MSAs. + max_templates: The maximum number of templates to include. + + Returns: + Single FeatureDict for entire complex. + """ + np_chains_list = _pad_templates( + np_chains_list, max_templates=max_templates) + np_chains_list = _merge_homomers_dense_msa(np_chains_list) + # Unpaired MSA features will be always block-diagonalised; paired MSA + # features will be concatenated. + np_example = _merge_features_from_multiple_chains( + np_chains_list, pair_msa_sequences=False) + if pair_msa_sequences: + np_example = _concatenate_paired_and_unpaired_features(np_example) + np_example = _correct_post_merged_feats( + np_example=np_example, + np_chains_list=np_chains_list, + pair_msa_sequences=pair_msa_sequences, + ) + + return np_example + + +def deduplicate_unpaired_sequences( + np_chains: List[NumpyDict]) -> List[NumpyDict]: + """Removes unpaired sequences which duplicate a paired sequence.""" + + feature_names = np_chains[0].keys() + msa_features = MSA_FEATURES + cache_msa_features = {} + for chain in np_chains: + entity_id = int(chain['entity_id'][0]) + if entity_id not in cache_msa_features: + sequence_set = set(s.tobytes() for s in chain['msa_all_seq']) + keep_rows = [] + # Go through unpaired MSA seqs and remove any rows that correspond to the + # sequences that are already present in the paired MSA. 
+ for row_num, seq in enumerate(chain['msa']): + if seq.tobytes() not in sequence_set: + keep_rows.append(row_num) + new_msa_features = {} + for feature_name in feature_names: + if feature_name in msa_features: + if keep_rows: + new_msa_features[feature_name] = chain[feature_name][ + keep_rows] + else: + new_shape = list(chain[feature_name].shape) + new_shape[0] = 0 + new_msa_features[feature_name] = np.zeros( + new_shape, dtype=chain[feature_name].dtype) + cache_msa_features[entity_id] = new_msa_features + for feature_name in cache_msa_features[entity_id]: + chain[feature_name] = cache_msa_features[entity_id][feature_name] + chain['num_alignments'] = np.array( + chain['msa'].shape[0], dtype=np.int32) + return np_chains diff --git a/modelscope/models/science/unifold/data/process.py b/modelscope/models/science/unifold/data/process.py new file mode 100644 index 00000000..3987cb1c --- /dev/null +++ b/modelscope/models/science/unifold/data/process.py @@ -0,0 +1,264 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from typing import Optional + +import numpy as np +import torch + +from modelscope.models.science.unifold.data import data_ops + + +def nonensembled_fns(common_cfg, mode_cfg): + """Input pipeline data transformers that are not ensembled.""" + v2_feature = common_cfg.v2_feature + operators = [] + if mode_cfg.random_delete_msa: + operators.append( + data_ops.random_delete_msa(common_cfg.random_delete_msa)) + operators.extend([ + data_ops.cast_to_64bit_ints, + data_ops.correct_msa_restypes, + data_ops.squeeze_features, + data_ops.randomly_replace_msa_with_unknown(0.0), + data_ops.make_seq_mask, + data_ops.make_msa_mask, + ]) + operators.append(data_ops.make_hhblits_profile_v2 + if v2_feature else data_ops.make_hhblits_profile) + if common_cfg.use_templates: + operators.extend([ + data_ops.make_template_mask, + data_ops.make_pseudo_beta('template_'), + ]) + operators.append( + data_ops.crop_templates( + max_templates=mode_cfg.max_templates, + subsample_templates=mode_cfg.subsample_templates, + )) + + if common_cfg.use_template_torsion_angles: + operators.extend([ + data_ops.atom37_to_torsion_angles('template_'), + ]) + + operators.append(data_ops.make_atom14_masks) + operators.append(data_ops.make_target_feat) + + return operators + + +def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed): + operators = [] + if common_cfg.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates + else: + pad_msa_clusters = mode_cfg.max_msa_clusters + crop_feats = dict(common_cfg.features) + if mode_cfg.fixed_size: + if mode_cfg.crop: + if common_cfg.is_multimer: + crop_fn = data_ops.crop_to_size_multimer( + crop_size=mode_cfg.crop_size, + shape_schema=crop_feats, + seed=crop_and_fix_size_seed, + spatial_crop_prob=mode_cfg.spatial_crop_prob, + ca_ca_threshold=mode_cfg.ca_ca_threshold, + ) + else: + crop_fn = data_ops.crop_to_size_single( + crop_size=mode_cfg.crop_size, + shape_schema=crop_feats, + seed=crop_and_fix_size_seed, + ) + operators.append(crop_fn) + + operators.append(data_ops.select_feat(crop_feats)) + + operators.append( + data_ops.make_fixed_size( + crop_feats, + pad_msa_clusters, + common_cfg.max_extra_msa, + mode_cfg.crop_size, + mode_cfg.max_templates, + )) + return operators + + +def ensembled_fns(common_cfg, mode_cfg): + """Input pipeline data transformers that can be ensembled and averaged.""" + 
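+    # These transforms are re-applied for every recycling iteration and
+    # ensemble member (see process_features below), so all stochastic MSA
+    # sampling, masking and clustering steps belong here.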
operators = [] + multimer_mode = common_cfg.is_multimer + v2_feature = common_cfg.v2_feature + # multimer don't use block delete msa + if mode_cfg.block_delete_msa and not multimer_mode: + operators.append( + data_ops.block_delete_msa(common_cfg.block_delete_msa)) + if 'max_distillation_msa_clusters' in mode_cfg: + operators.append( + data_ops.sample_msa_distillation( + mode_cfg.max_distillation_msa_clusters)) + + if common_cfg.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates + else: + pad_msa_clusters = mode_cfg.max_msa_clusters + + max_msa_clusters = pad_msa_clusters + max_extra_msa = common_cfg.max_extra_msa + + assert common_cfg.resample_msa_in_recycling + gumbel_sample = common_cfg.gumbel_sample + operators.append( + data_ops.sample_msa( + max_msa_clusters, + keep_extra=True, + gumbel_sample=gumbel_sample, + biased_msa_by_chain=mode_cfg.biased_msa_by_chain, + )) + + if 'masked_msa' in common_cfg: + # Masked MSA should come *before* MSA clustering so that + # the clustering and full MSA profile do not leak information about + # the masked locations and secret corrupted locations. + operators.append( + data_ops.make_masked_msa( + common_cfg.masked_msa, + mode_cfg.masked_msa_replace_fraction, + gumbel_sample=gumbel_sample, + share_mask=mode_cfg.share_mask, + )) + + if common_cfg.msa_cluster_features: + if v2_feature: + operators.append(data_ops.nearest_neighbor_clusters_v2()) + else: + operators.append(data_ops.nearest_neighbor_clusters()) + operators.append(data_ops.summarize_clusters) + + if v2_feature: + operators.append(data_ops.make_msa_feat_v2) + else: + operators.append(data_ops.make_msa_feat) + # Crop after creating the cluster profiles. + if max_extra_msa: + if v2_feature: + operators.append(data_ops.make_extra_msa_feat(max_extra_msa)) + else: + operators.append(data_ops.crop_extra_msa(max_extra_msa)) + else: + operators.append(data_ops.delete_extra_msa) + # operators.append(data_operators.select_feat(common_cfg.recycling_features)) + return operators + + +def process_features(tensors, common_cfg, mode_cfg): + """Based on the config, apply filters and transformations to the data.""" + is_distillation = bool(tensors.get('is_distillation', 0)) + multimer_mode = common_cfg.is_multimer + crop_and_fix_size_seed = int(tensors['crop_and_fix_size_seed']) + crop_fn = crop_and_fix_size_fns( + common_cfg, + mode_cfg, + crop_and_fix_size_seed, + ) + + def wrap_ensemble_fn(data, i): + """Function to be mapped over the ensemble dimension.""" + d = data.copy() + fns = ensembled_fns( + common_cfg, + mode_cfg, + ) + new_d = compose(fns)(d) + if not multimer_mode or is_distillation: + new_d = data_ops.select_feat(common_cfg.recycling_features)(new_d) + return compose(crop_fn)(new_d) + else: # select after crop for spatial cropping + d = compose(crop_fn)(d) + d = data_ops.select_feat(common_cfg.recycling_features)(d) + return d + + nonensembled = nonensembled_fns(common_cfg, mode_cfg) + + if mode_cfg.supervised and (not multimer_mode or is_distillation): + nonensembled.extend(label_transform_fn()) + + tensors = compose(nonensembled)(tensors) + + num_recycling = int(tensors['num_recycling_iters']) + 1 + num_ensembles = mode_cfg.num_ensembles + + ensemble_tensors = map_fn( + lambda x: wrap_ensemble_fn(tensors, x), + torch.arange(num_recycling * num_ensembles), + ) + tensors = compose(crop_fn)(tensors) + # add a dummy dim to align with recycling features + tensors = {k: torch.stack([tensors[k]], dim=0) for k in tensors} + 
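+    # After the update, recycling features carry a leading dimension of
+    # num_recycling * num_ensembles, while the remaining features keep the
+    # dummy leading dimension of 1 added above.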
tensors.update(ensemble_tensors) + return tensors + + +@data_ops.curry1 +def compose(x, fs): + for f in fs: + x = f(x) + return x + + +def pad_then_stack(values, ): + if len(values[0].shape) >= 1: + size = max(v.shape[0] for v in values) + new_values = [] + for v in values: + if v.shape[0] < size: + res = values[0].new_zeros(size, *v.shape[1:]) + res[:v.shape[0], ...] = v + else: + res = v + new_values.append(res) + else: + new_values = values + return torch.stack(new_values, dim=0) + + +def map_fn(fun, x): + ensembles = [fun(elem) for elem in x] + features = ensembles[0].keys() + ensembled_dict = {} + for feat in features: + ensembled_dict[feat] = pad_then_stack( + [dict_i[feat] for dict_i in ensembles]) + return ensembled_dict + + +def process_single_label(label: dict, + num_ensemble: Optional[int] = None) -> dict: + assert 'aatype' in label + assert 'all_atom_positions' in label + assert 'all_atom_mask' in label + label = compose(label_transform_fn())(label) + if num_ensemble is not None: + label = { + k: torch.stack([v for _ in range(num_ensemble)]) + for k, v in label.items() + } + return label + + +def process_labels(labels_list, num_ensemble: Optional[int] = None): + return [process_single_label(ll, num_ensemble) for ll in labels_list] + + +def label_transform_fn(): + return [ + data_ops.make_atom14_masks, + data_ops.make_atom14_positions, + data_ops.atom37_to_frames, + data_ops.atom37_to_torsion_angles(''), + data_ops.make_pseudo_beta(''), + data_ops.get_backbone_frames, + data_ops.get_chi_angles, + ] diff --git a/modelscope/models/science/unifold/data/process_multimer.py b/modelscope/models/science/unifold/data/process_multimer.py new file mode 100644 index 00000000..04572d2d --- /dev/null +++ b/modelscope/models/science/unifold/data/process_multimer.py @@ -0,0 +1,417 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Feature processing logic for multimer data """ + +import collections +from typing import Iterable, List, MutableMapping + +import numpy as np + +from modelscope.models.science.unifold.data import (msa_pairing, + residue_constants) +from .utils import correct_template_restypes + +FeatureDict = MutableMapping[str, np.ndarray] + +REQUIRED_FEATURES = frozenset({ + 'aatype', + 'all_atom_mask', + 'all_atom_positions', + 'all_chains_entity_ids', + 'all_crops_all_chains_mask', + 'all_crops_all_chains_positions', + 'all_crops_all_chains_residue_ids', + 'assembly_num_chains', + 'asym_id', + 'bert_mask', + 'cluster_bias_mask', + 'deletion_matrix', + 'deletion_mean', + 'entity_id', + 'entity_mask', + 'mem_peak', + 'msa', + 'msa_mask', + 'num_alignments', + 'num_templates', + 'queue_size', + 'residue_index', + 'resolution', + 'seq_length', + 'seq_mask', + 'sym_id', + 'template_aatype', + 'template_all_atom_mask', + 'template_all_atom_positions', + # zy added: + 'asym_len', + 'template_sum_probs', + 'num_sym', + 'msa_chains', +}) + +MAX_TEMPLATES = 4 +MSA_CROP_SIZE = 2048 + + +def _is_homomer_or_monomer(chains: Iterable[FeatureDict]) -> bool: + """Checks if a list of chains represents a homomer/monomer example.""" + # Note that an entity_id of 0 indicates padding. + num_unique_chains = len( + np.unique( + np.concatenate([ + np.unique(chain['entity_id'][chain['entity_id'] > 0]) + for chain in chains + ]))) + return num_unique_chains == 1 + + +def pair_and_merge( + all_chain_features: MutableMapping[str, FeatureDict]) -> FeatureDict: + """Runs processing on features to augment, pair and merge. + + Args: + all_chain_features: A MutableMap of dictionaries of features for each chain. + + Returns: + A dictionary of features. + """ + + process_unmerged_features(all_chain_features) + + np_chains_list = all_chain_features + + pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list) + + if pair_msa_sequences: + np_chains_list = msa_pairing.create_paired_features( + chains=np_chains_list) + np_chains_list = msa_pairing.deduplicate_unpaired_sequences( + np_chains_list) + np_chains_list = crop_chains( + np_chains_list, + msa_crop_size=MSA_CROP_SIZE, + pair_msa_sequences=pair_msa_sequences, + max_templates=MAX_TEMPLATES, + ) + np_example = msa_pairing.merge_chain_features( + np_chains_list=np_chains_list, + pair_msa_sequences=pair_msa_sequences, + max_templates=MAX_TEMPLATES, + ) + np_example = process_final(np_example) + return np_example + + +def crop_chains( + chains_list: List[FeatureDict], + msa_crop_size: int, + pair_msa_sequences: bool, + max_templates: int, +) -> List[FeatureDict]: + """Crops the MSAs for a set of chains. + + Args: + chains_list: A list of chains to be cropped. + msa_crop_size: The total number of sequences to crop from the MSA. + pair_msa_sequences: Whether we are operating in sequence-pairing mode. + max_templates: The maximum templates to use per chain. + + Returns: + The chains cropped. + """ + + # Apply the cropping. 
+ cropped_chains = [] + for chain in chains_list: + cropped_chain = _crop_single_chain( + chain, + msa_crop_size=msa_crop_size, + pair_msa_sequences=pair_msa_sequences, + max_templates=max_templates, + ) + cropped_chains.append(cropped_chain) + + return cropped_chains + + +def _crop_single_chain(chain: FeatureDict, msa_crop_size: int, + pair_msa_sequences: bool, + max_templates: int) -> FeatureDict: + """Crops msa sequences to `msa_crop_size`.""" + msa_size = chain['num_alignments'] + + if pair_msa_sequences: + msa_size_all_seq = chain['num_alignments_all_seq'] + msa_crop_size_all_seq = np.minimum(msa_size_all_seq, + msa_crop_size // 2) + + # We reduce the number of un-paired sequences, by the number of times a + # sequence from this chain's MSA is included in the paired MSA. This keeps + # the MSA size for each chain roughly constant. + msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :] + num_non_gapped_pairs = np.sum( + np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1)) + num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, + msa_crop_size_all_seq) + + # Restrict the unpaired crop size so that paired+unpaired sequences do not + # exceed msa_seqs_per_chain for each chain. + max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0) + msa_crop_size = np.minimum(msa_size, max_msa_crop_size) + else: + msa_crop_size = np.minimum(msa_size, msa_crop_size) + + include_templates = 'template_aatype' in chain and max_templates + if include_templates: + num_templates = chain['template_aatype'].shape[0] + templates_crop_size = np.minimum(num_templates, max_templates) + + for k in chain: + k_split = k.split('_all_seq')[0] + if k_split in msa_pairing.TEMPLATE_FEATURES: + chain[k] = chain[k][:templates_crop_size, :] + elif k_split in msa_pairing.MSA_FEATURES: + if '_all_seq' in k and pair_msa_sequences: + chain[k] = chain[k][:msa_crop_size_all_seq, :] + else: + chain[k] = chain[k][:msa_crop_size, :] + + chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32) + if include_templates: + chain['num_templates'] = np.asarray( + templates_crop_size, dtype=np.int32) + if pair_msa_sequences: + chain['num_alignments_all_seq'] = np.asarray( + msa_crop_size_all_seq, dtype=np.int32) + return chain + + +def process_final(np_example: FeatureDict) -> FeatureDict: + """Final processing steps in data pipeline, after merging and pairing.""" + np_example = _make_seq_mask(np_example) + np_example = _make_msa_mask(np_example) + np_example = _filter_features(np_example) + return np_example + + +def _make_seq_mask(np_example): + np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32) + return np_example + + +def _make_msa_mask(np_example): + """Mask features are all ones, but will later be zero-padded.""" + + np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.int8) + + seq_mask = (np_example['entity_id'] > 0).astype(np.int8) + np_example['msa_mask'] *= seq_mask[None] + + return np_example + + +def _filter_features(np_example: FeatureDict) -> FeatureDict: + """Filters features of example to only those requested.""" + return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES} + + +def process_unmerged_features(all_chain_features: MutableMapping[str, + FeatureDict]): + """Postprocessing stage for per-chain features before merging.""" + num_chains = len(all_chain_features) + for chain_features in all_chain_features: + # Convert deletion matrices to float. 
+ if 'deletion_matrix_int' in chain_features: + chain_features['deletion_matrix'] = np.asarray( + chain_features.pop('deletion_matrix_int'), dtype=np.float32) + if 'deletion_matrix_int_all_seq' in chain_features: + chain_features['deletion_matrix_all_seq'] = np.asarray( + chain_features.pop('deletion_matrix_int_all_seq'), + dtype=np.float32) + + chain_features['deletion_mean'] = np.mean( + chain_features['deletion_matrix'], axis=0) + + if 'all_atom_positions' not in chain_features: + # Add all_atom_mask and dummy all_atom_positions based on aatype. + all_atom_mask = residue_constants.STANDARD_ATOM_MASK[ + chain_features['aatype']] + chain_features['all_atom_mask'] = all_atom_mask + chain_features['all_atom_positions'] = np.zeros( + list(all_atom_mask.shape) + [3]) + + # Add assembly_num_chains. + chain_features['assembly_num_chains'] = np.asarray(num_chains) + + # Add entity_mask. + for chain_features in all_chain_features: + chain_features['entity_mask'] = ( + chain_features['entity_id'] != # noqa W504 + 0).astype(np.int32) + + +def empty_template_feats(n_res): + return { + 'template_aatype': + np.zeros((0, n_res)).astype(np.int64), + 'template_all_atom_positions': + np.zeros((0, n_res, 37, 3)).astype(np.float32), + 'template_sum_probs': + np.zeros((0, 1)).astype(np.float32), + 'template_all_atom_mask': + np.zeros((0, n_res, 37)).astype(np.float32), + } + + +def convert_monomer_features(monomer_features: FeatureDict) -> FeatureDict: + """Reshapes and modifies monomer features for multimer models.""" + if monomer_features['template_aatype'].shape[0] == 0: + monomer_features.update( + empty_template_feats(monomer_features['aatype'].shape[0])) + converted = {} + unnecessary_leading_dim_feats = { + 'sequence', + 'domain_name', + 'num_alignments', + 'seq_length', + } + for feature_name, feature in monomer_features.items(): + if feature_name in unnecessary_leading_dim_feats: + # asarray ensures it's a np.ndarray. + feature = np.asarray(feature[0], dtype=feature.dtype) + elif feature_name == 'aatype': + # The multimer model performs the one-hot operation itself. + feature = np.argmax(feature, axis=-1).astype(np.int32) + elif feature_name == 'template_aatype': + if feature.shape[0] > 0: + feature = correct_template_restypes(feature) + elif feature_name == 'template_all_atom_masks': + feature_name = 'template_all_atom_mask' + elif feature_name == 'msa': + feature = feature.astype(np.uint8) + + if feature_name.endswith('_mask'): + feature = feature.astype(np.float32) + + converted[feature_name] = feature + + if 'deletion_matrix_int' in monomer_features: + monomer_features['deletion_matrix'] = monomer_features.pop( + 'deletion_matrix_int').astype(np.float32) + + converted.pop( + 'template_sum_probs' + ) # zy: this input is checked to be dirty in shape. TODO: figure out why and make it right. + return converted + + +def int_id_to_str_id(num: int) -> str: + """Encodes a number as a string, using reverse spreadsheet style naming. + + Args: + num: A positive integer. + + Returns: + A string that encodes the positive integer using reverse spreadsheet style, + naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the + usual way to encode chain IDs in mmCIF files. + """ + if num <= 0: + raise ValueError(f'Only positive integers allowed, got {num}.') + + num = num - 1 # 1-based indexing. 
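+    # Emit the least-significant base-26 letter first (hence "reverse"
+    # spreadsheet style); the -1 steps make A..Z act as digits 0..25 with no
+    # zero placeholder, so 27 -> 'AA' and 28 -> 'BA'.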
+ output = [] + while num >= 0: + output.append(chr(num % 26 + ord('A'))) + num = num // 26 - 1 + return ''.join(output) + + +def add_assembly_features(all_chain_features, ): + """Add features to distinguish between chains. + + Args: + all_chain_features: A dictionary which maps chain_id to a dictionary of + features for each chain. + + Returns: + all_chain_features: A dictionary which maps strings of the form + `_` to the corresponding chain features. E.g. two + chains from a homodimer would have keys A_1 and A_2. Two chains from a + heterodimer would have keys A_1 and B_1. + """ + # Group the chains by sequence + seq_to_entity_id = {} + grouped_chains = collections.defaultdict(list) + for chain_features in all_chain_features: + assert 'sequence' in chain_features + seq = str(chain_features['sequence']) + if seq not in seq_to_entity_id: + seq_to_entity_id[seq] = len(seq_to_entity_id) + 1 + grouped_chains[seq_to_entity_id[seq]].append(chain_features) + + new_all_chain_features = [] + chain_id = 1 + for entity_id, group_chain_features in grouped_chains.items(): + num_sym = len(group_chain_features) # zy + for sym_id, chain_features in enumerate(group_chain_features, start=1): + seq_length = chain_features['seq_length'] + chain_features['asym_id'] = chain_id * np.ones(seq_length) + chain_features['sym_id'] = sym_id * np.ones(seq_length) + chain_features['entity_id'] = entity_id * np.ones(seq_length) + chain_features['num_sym'] = num_sym * np.ones(seq_length) + chain_id += 1 + new_all_chain_features.append(chain_features) + + return new_all_chain_features + + +def pad_msa(np_example, min_num_seq): + np_example = dict(np_example) + num_seq = np_example['msa'].shape[0] + if num_seq < min_num_seq: + for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask', + 'msa_chains'): + np_example[feat] = np.pad(np_example[feat], + ((0, min_num_seq - num_seq), (0, 0))) + np_example['cluster_bias_mask'] = np.pad( + np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq), )) + return np_example + + +def post_process(np_example): + np_example = pad_msa(np_example, 512) + no_dim_keys = [ + 'num_alignments', + 'assembly_num_chains', + 'num_templates', + 'seq_length', + 'resolution', + ] + for k in no_dim_keys: + if k in np_example: + np_example[k] = np_example[k].reshape(-1) + return np_example + + +def merge_msas(msa, del_mat, new_msa, new_del_mat): + cur_msa_set = set([tuple(m) for m in msa]) + new_rows = [] + for i, s in enumerate(new_msa): + if tuple(s) not in cur_msa_set: + new_rows.append(i) + ret_msa = np.concatenate([msa, new_msa[new_rows]], axis=0) + ret_del_mat = np.concatenate([del_mat, new_del_mat[new_rows]], axis=0) + return ret_msa, ret_del_mat diff --git a/modelscope/models/science/unifold/data/protein.py b/modelscope/models/science/unifold/data/protein.py new file mode 100644 index 00000000..42308d04 --- /dev/null +++ b/modelscope/models/science/unifold/data/protein.py @@ -0,0 +1,322 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Protein data type.""" +import dataclasses +import io +from typing import Any, Mapping, Optional + +import numpy as np +from Bio.PDB import PDBParser + +from modelscope.models.science.unifold.data import residue_constants + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. + +# Complete sequence of chain IDs supported by the PDB format. +PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' +PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # 0-indexed number corresponding to the chain in the protein that this residue + # belongs to. + chain_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + def __post_init__(self): + if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: + raise ValueError( + f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains ' + 'because these cannot be written to PDB format.') + + +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a Protein object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + chain_id: If chain_id is specified (e.g. A), then only that chain + is parsed. Otherwise all chains are parsed. + + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser(QUIET=True) + structure = parser.get_structure('none', pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f'Only single model PDBs are supported. Found {len(models)} models.' + ) + model = models[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + chain_ids = [] + b_factors = [] + + for chain in model: + if chain_id is not None and chain.id != chain_id: + continue + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB contains an insertion code at chain {chain.id} and residue ' + f'index {res.id[1]}. 
These are not supported.') + res_shortname = residue_constants.restype_3to1.get( + res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num, )) + res_b_factors = np.zeros((residue_constants.atom_type_num, )) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1.0 + res_b_factors[residue_constants.atom_order[ + atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + chain_ids.append(chain.id) + b_factors.append(res_b_factors) + + # Chain IDs are usually characters so map these to ints. + unique_chain_ids = np.unique(chain_ids) + chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} + chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + chain_index=chain_index, + b_factors=np.array(b_factors), + ) + + +def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: + chain_end = 'TER' + return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' + f'{chain_name:>1}{residue_index:>4}') + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ['X'] + + # res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') + def res_1to3(r): + return residue_constants.restype_1to3.get(restypes[r], 'UNK') + + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + chain_index = prot.chain_index.astype(np.int32) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError('Invalid aatypes.') + + # Construct a mapping from chain integer indices to chain ID strings. + chain_ids = {} + for i in np.unique(chain_index): # np.unique gives sorted output. + if i >= PDB_MAX_CHAINS: + raise ValueError( + f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') + chain_ids[i] = PDB_CHAIN_IDS[i] + + pdb_lines.append('MODEL 1') + atom_index = 1 + last_chain_index = chain_index[0] + # Add all atom sites. + for i in range(aatype.shape[0]): + # Close the previous chain if in a multichain PDB. + if last_chain_index != chain_index[i]: + pdb_lines.append( + _chain_end( + atom_index, + res_1to3(aatype[i - 1]), + chain_ids[chain_index[i - 1]], + residue_index[i - 1], + )) + last_chain_index = chain_index[i] + atom_index += 1 # Atom index increases at the TER symbol. + + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip(atom_types, + atom_positions[i], + atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = 'ATOM' + name = atom_name if len(atom_name) == 4 else f' {atom_name}' + alt_loc = '' + insertion_code = '' + occupancy = 1.00 + element = atom_name[ + 0] # Protein supports only C, N, O, S, this works. 
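+            # alt_loc, insertion_code and charge stay empty but still occupy
+            # their fixed-width columns in the ATOM record below.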
+ charge = '' + # PDB is a columnar format, every space matters here! + atom_line = ( + f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' + f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' + f'{residue_index[i]:>4}{insertion_code:>1} ' + f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' + f'{occupancy:>6.2f}{b_factor:>6.2f} ' + f'{element:>2}{charge:>2}') + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the final chain. + pdb_lines.append( + _chain_end( + atom_index, + res_1to3(aatype[-1]), + chain_ids[chain_index[-1]], + residue_index[-1], + )) + pdb_lines.append('ENDMDL') + pdb_lines.append('END') + + # Pad all lines to 80 characters. + pdb_lines = [line.ljust(80) for line in pdb_lines] + return '\n'.join(pdb_lines) + '\n' # Add terminating newline. + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction(features: FeatureDict, + result: ModelOutput, + b_factors: Optional[np.ndarray] = None) -> Protein: + """Assembles a protein from a prediction. + + Args: + features: Dictionary holding model inputs. + fold_output: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + + Returns: + A protein instance. + """ + + if 'asym_id' in features: + chain_index = features['asym_id'] - 1 + else: + chain_index = np.zeros_like((features['aatype'])) + + if b_factors is None: + b_factors = np.zeros_like(result['final_atom_mask']) + + return Protein( + aatype=features['aatype'], + atom_positions=result['final_atom_positions'], + atom_mask=result['final_atom_mask'], + residue_index=features['residue_index'] + 1, + chain_index=chain_index, + b_factors=b_factors, + ) + + +def from_feature(features: FeatureDict, + b_factors: Optional[np.ndarray] = None) -> Protein: + """Assembles a standard pdb from input atom positions & mask. + + Args: + features: Dictionary holding model inputs. + b_factors: (Optional) B-factors to use for the protein. + + Returns: + A protein instance. + """ + + if 'asym_id' in features: + chain_index = features['asym_id'] - 1 + else: + chain_index = np.zeros_like((features['aatype'])) + + if b_factors is None: + b_factors = np.zeros_like(features['all_atom_mask']) + + return Protein( + aatype=features['aatype'], + atom_positions=features['all_atom_positions'], + atom_mask=features['all_atom_mask'], + residue_index=features['residue_index'] + 1, + chain_index=chain_index, + b_factors=b_factors, + ) diff --git a/modelscope/models/science/unifold/data/residue_constants.py b/modelscope/models/science/unifold/data/residue_constants.py new file mode 100644 index 00000000..beebfe89 --- /dev/null +++ b/modelscope/models/science/unifold/data/residue_constants.py @@ -0,0 +1,1212 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Constants used in AlphaFold.""" + +import collections +import functools +import os +from typing import List, Mapping, Tuple + +import numpy as np +from unicore.utils import tree_map + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + 'ARG': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'NE'], + ['CG', 'CD', 'NE', 'CZ'], + ], + 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'CYS': [['N', 'CA', 'CB', 'SG']], + 'GLN': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1'], + ], + 'GLU': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1'], + ], + 'GLY': [], + 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], + 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], + 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'LYS': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'CE'], + ['CG', 'CD', 'CE', 'NZ'], + ], + 'MET': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'SD'], + ['CB', 'CG', 'SD', 'CE'], + ], + 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], + 'SER': [['N', 'CA', 'CB', 'OG']], + 'THR': [['N', 'CA', 'CB', 'OG1']], + 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'VAL': [['N', 'CA', 'CB', 'CG1']], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. 
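+# For example (illustrative): the ASP row below is [0.0, 1.0, 0.0, 0.0], i.e.
+# only chi2 is pi-periodic, because the OD1/OD2 carboxylate oxygens are
+# chemically equivalent and a pi rotation of chi2 merely swaps their names.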
+chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + 'ALA': [ + ['N', 0, (-0.525, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.529, -0.774, -1.205)], + ['O', 3, (0.627, 1.062, 0.000)], + ], + 'ARG': [ + ['N', 0, (-0.524, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.524, -0.778, -1.209)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.616, 1.390, -0.000)], + ['CD', 5, (0.564, 1.414, 0.000)], + ['NE', 6, (0.539, 1.357, -0.000)], + ['NH1', 7, (0.206, 2.301, 0.000)], + ['NH2', 7, (2.078, 0.978, -0.000)], + ['CZ', 7, (0.758, 1.093, -0.000)], + ], + 'ASN': [ + ['N', 0, (-0.536, 1.357, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.531, -0.787, -1.200)], + ['O', 3, (0.625, 1.062, 0.000)], + ['CG', 4, (0.584, 1.399, 0.000)], + ['ND2', 5, (0.593, -1.188, 0.001)], + ['OD1', 5, (0.633, 1.059, 0.000)], + ], + 'ASP': [ + ['N', 0, (-0.525, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, 0.000, -0.000)], + ['CB', 0, (-0.526, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.593, 1.398, -0.000)], + ['OD1', 5, (0.610, 1.091, 0.000)], + ['OD2', 5, (0.592, -1.101, -0.003)], + ], + 'CYS': [ + ['N', 0, (-0.522, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, 0.000)], + ['CB', 0, (-0.519, -0.773, -1.212)], + ['O', 3, (0.625, 1.062, -0.000)], + ['SG', 4, (0.728, 1.653, 0.000)], + ], + 'GLN': [ + ['N', 0, (-0.526, 1.361, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.779, -1.207)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.615, 1.393, 0.000)], + ['CD', 5, (0.587, 1.399, -0.000)], + ['NE2', 6, (0.593, -1.189, -0.001)], + ['OE1', 6, (0.634, 1.060, 0.000)], + ], + 'GLU': [ + ['N', 0, (-0.528, 1.361, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.526, -0.781, -1.207)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.615, 1.392, 0.000)], + ['CD', 5, (0.600, 1.397, 0.000)], + ['OE1', 6, (0.607, 1.095, -0.000)], + ['OE2', 6, (0.589, -1.104, -0.001)], + ], + 'GLY': [ + ['N', 0, (-0.572, 1.337, 
0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.517, -0.000, -0.000)], + ['O', 3, (0.626, 1.062, -0.000)], + ], + 'HIS': [ + ['N', 0, (-0.527, 1.360, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.778, -1.208)], + ['O', 3, (0.625, 1.063, 0.000)], + ['CG', 4, (0.600, 1.370, -0.000)], + ['CD2', 5, (0.889, -1.021, 0.003)], + ['ND1', 5, (0.744, 1.160, -0.000)], + ['CE1', 5, (2.030, 0.851, 0.002)], + ['NE2', 5, (2.145, -0.466, 0.004)], + ], + 'ILE': [ + ['N', 0, (-0.493, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.536, -0.793, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.534, 1.437, -0.000)], + ['CG2', 4, (0.540, -0.785, -1.199)], + ['CD1', 5, (0.619, 1.391, 0.000)], + ], + 'LEU': [ + ['N', 0, (-0.520, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.773, -1.214)], + ['O', 3, (0.625, 1.063, -0.000)], + ['CG', 4, (0.678, 1.371, 0.000)], + ['CD1', 5, (0.530, 1.430, -0.000)], + ['CD2', 5, (0.535, -0.774, 1.200)], + ], + 'LYS': [ + ['N', 0, (-0.526, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.524, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.619, 1.390, 0.000)], + ['CD', 5, (0.559, 1.417, 0.000)], + ['CE', 6, (0.560, 1.416, 0.000)], + ['NZ', 7, (0.554, 1.387, 0.000)], + ], + 'MET': [ + ['N', 0, (-0.521, 1.364, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.210)], + ['O', 3, (0.625, 1.062, -0.000)], + ['CG', 4, (0.613, 1.391, -0.000)], + ['SD', 5, (0.703, 1.695, 0.000)], + ['CE', 6, (0.320, 1.786, -0.000)], + ], + 'PHE': [ + ['N', 0, (-0.518, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, -0.000)], + ['CB', 0, (-0.525, -0.776, -1.212)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.607, 1.377, 0.000)], + ['CD1', 5, (0.709, 1.195, -0.000)], + ['CD2', 5, (0.706, -1.196, 0.000)], + ['CE1', 5, (2.102, 1.198, -0.000)], + ['CE2', 5, (2.098, -1.201, -0.000)], + ['CZ', 5, (2.794, -0.003, -0.001)], + ], + 'PRO': [ + ['N', 0, (-0.566, 1.351, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, 0.000)], + ['CB', 0, (-0.546, -0.611, -1.293)], + ['O', 3, (0.621, 1.066, 0.000)], + ['CG', 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + 'SER': [ + ['N', 0, (-0.529, 1.360, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.518, -0.777, -1.211)], + ['O', 3, (0.626, 1.062, -0.000)], + ['OG', 4, (0.503, 1.325, 0.000)], + ], + 'THR': [ + ['N', 0, (-0.517, 1.364, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, -0.000)], + ['CB', 0, (-0.516, -0.793, -1.215)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG2', 4, (0.550, -0.718, -1.228)], + ['OG1', 4, (0.472, 1.353, 0.000)], + ], + 'TRP': [ + ['N', 0, (-0.521, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.212)], + ['O', 3, (0.627, 1.062, 0.000)], + ['CG', 4, (0.609, 1.370, -0.000)], + ['CD1', 5, (0.824, 1.091, 0.000)], + ['CD2', 5, (0.854, -1.148, -0.005)], + ['CE2', 5, (2.186, -0.678, -0.007)], + ['CE3', 5, (0.622, -2.530, -0.007)], + ['NE1', 5, (2.140, 0.690, -0.004)], + ['CH2', 5, (3.028, -2.890, -0.013)], + ['CZ2', 5, (3.283, 
-1.543, -0.011)], + ['CZ3', 5, (1.715, -3.389, -0.011)], + ], + 'TYR': [ + ['N', 0, (-0.522, 1.362, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.776, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG', 4, (0.607, 1.382, -0.000)], + ['CD1', 5, (0.716, 1.195, -0.000)], + ['CD2', 5, (0.713, -1.194, -0.001)], + ['CE1', 5, (2.107, 1.200, -0.002)], + ['CE2', 5, (2.104, -1.201, -0.003)], + ['OH', 5, (4.168, -0.002, -0.005)], + ['CZ', 5, (2.791, -0.001, -0.003)], + ], + 'VAL': [ + ['N', 0, (-0.494, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.533, -0.795, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.540, 1.429, -0.000)], + ['CG2', 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + 'ALA': ['C', 'CA', 'CB', 'N', 'O'], + 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'], + 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'], + 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'], + 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'], + 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'], + 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'], + 'GLY': ['C', 'CA', 'N', 'O'], + 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'], + 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'], + 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'], + 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'], + 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'], + 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'], + 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'], + 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'], + 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'], + 'TRP': [ + 'C', + 'CA', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'CE2', + 'CE3', + 'CZ2', + 'CZ3', + 'CH2', + 'N', + 'NE1', + 'O', + ], + 'TYR': + ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O', 'OH'], + 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O'], +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +residue_atom_renaming_swaps = { + 'ASP': { + 'OD1': 'OD2' + }, + 'GLU': { + 'OE1': 'OE2' + }, + 'PHE': { + 'CD1': 'CD2', + 'CE1': 'CE2' + }, + 'TYR': { + 'CD1': 'CD2', + 'CE1': 'CE2' + }, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + 'C': 1.7, + 'N': 1.55, + 'O': 1.52, + 'S': 1.8, +} + +Bond = collections.namedtuple('Bond', + ['atom1_name', 'atom2_name', 'length', 'stddev']) +BondAngle = collections.namedtuple( + 'BondAngle', + ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev']) + + +@functools.lru_cache(maxsize=None) +# def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]], Mapping[ #noqa +# str, List[Bond]], Mapping[str, List[BondAngle]]]: +def load_stereo_chemical_props(): + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). 
+ + Returns: + residue_bonds: Dict that maps resname -> list of Bond tuples. + residue_virtual_bonds: Dict that maps resname -> list of Bond tuples. + residue_bond_angles: Dict that maps resname -> list of BondAngle tuples. + """ + stereo_chemical_props_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'stereo_chemical_props.txt') + with open(stereo_chemical_props_path, 'rt') as f: + stereo_chemical_props = f.read() + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split('-') + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds['UNK'] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split('-') + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle( + atom1, + atom2, + atom3, + float(angle_degree) / 180.0 * np.pi, + float(stddev_degree) / 180.0 * np.pi, + )) + residue_bond_angles['UNK'] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return '-'.join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 + - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length + * np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length + - 2 * bond2.length * np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length + - 2 * bond1.length * np.cos(gamma)) * dl_outer + stddev = np.sqrt((dl_dgamma * ba.stddev)**2 + + (dl_db1 * bond1.stddev)**2 + + (dl_db2 * bond2.stddev)**2) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return (residue_bonds, residue_virtual_bonds, residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). 
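+# For example (illustrative): in the 37-atom layout defined just below,
+# atom_order['N'] == 0, atom_order['CA'] == 1 and atom_type_num == 37, so a
+# residue's coordinates are stored as a fixed [37, 3] array in which the slots
+# for atoms the residue does not have are simply masked out.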
+atom_types = [ + 'N', + 'CA', + 'C', + 'CB', + 'O', + 'CG', + 'CG1', + 'CG2', + 'OG', + 'OG1', + 'SG', + 'CD', + 'CD1', + 'CD2', + 'ND1', + 'ND2', + 'OD1', + 'OD2', + 'SD', + 'CE', + 'CE1', + 'CE2', + 'CE3', + 'NE', + 'NE1', + 'NE2', + 'OE1', + 'OE2', + 'CH2', + 'NH1', + 'NH2', + 'OH', + 'CZ', + 'CZ2', + 'CZ3', + 'NZ', + 'OXT', +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD', + 'NE', + 'CZ', + 'NH1', + 'NH2', + '', + '', + '', + ], + 'ASN': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'ND1', + 'CD2', + 'CE1', + 'NE2', + '', + '', + '', + '', + ], + 'ILE': + ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'CE1', + 'CE2', + 'CZ', + '', + '', + '', + ], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': + ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'NE1', + 'CE2', + 'CE3', + 'CZ2', + 'CZ3', + 'CH2', + ], + 'TYR': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'CE1', + 'CE2', + 'CZ', + 'OH', + '', + '', + ], + 'VAL': + ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + 'A', + 'R', + 'N', + 'D', + 'C', + 'Q', + 'E', + 'G', + 'H', + 'I', + 'L', + 'K', + 'M', + 'F', + 'P', + 'S', + 'T', + 'W', + 'Y', + 'V', +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ['X'] +restype_order_with_x = { + restype: i + for i, restype in enumerate(restypes_with_x) +} + + +def sequence_to_onehot(sequence: str, + mapping: Mapping[str, int], + map_unknown_to_x: bool = False) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. 
+ map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. If False, any amino acid not in + the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. + + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError( + 'The mapping must have values from 0 to num_unique_aas-1 ' + 'without any gaps. Got: %s' % sorted(mapping.values())) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping['X']) + else: + raise ValueError( + f'Invalid character in the sequence: {aa_type}') + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'Q': 'GLN', + 'E': 'GLU', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', +} + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = 'UNK' + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID = { + 'A': 0, + 'B': 2, + 'C': 1, + 'D': 2, + 'E': 3, + 'F': 4, + 'G': 5, + 'H': 6, + 'I': 7, + 'J': 20, + 'K': 8, + 'L': 9, + 'M': 10, + 'N': 11, + 'O': 20, + 'P': 12, + 'Q': 13, + 'R': 14, + 'S': 15, + 'T': 16, + 'U': 1, + 'V': 17, + 'W': 18, + 'X': 20, + 'Y': 19, + 'Z': 3, + '-': 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: 'A', + 1: 'C', # Also U. + 2: 'D', # Also B. + 3: 'E', # Also Z. + 4: 'F', + 5: 'G', + 6: 'H', + 7: 'I', + 8: 'K', + 9: 'L', + 10: 'M', + 11: 'N', + 12: 'P', + 13: 'Q', + 14: 'R', + 15: 'S', + 16: 'T', + 17: 'V', + 18: 'W', + 19: 'Y', + 20: 'X', # Includes J and O. + 21: '-', +} + +restypes_with_x_and_gap = restypes + ['X', '-'] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) + for i in range(len(restypes_with_x_and_gap))) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). 
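+    # For instance (illustrative): residue_atoms['GLY'] is ['C', 'CA', 'N', 'O'],
+    # so the finished STANDARD_ATOM_MASK[restype_order['G']] row has ones only
+    # at atom37 slots 2, 1, 0 and 4; protein.ideal_atom_mask() is a direct
+    # lookup into this table.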
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. +def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1] * (4 - len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree_map( + lambda n: atom_order[n], chi_angles_atom_indices, leaf_type=str) +chi_angles_atom_indices = np.array([ + chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) + for chi_atoms in chi_angles_atom_indices +]) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. 
+ ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack([ex_normalized, ey_normalized, eznorm, + translation]).transpose() + m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int_) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int_) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[ + resname]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[restype, + atomtype, :] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[restype, + atom14idx, :] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = { + name: np.array(pos) + for name, _, pos in rigid_group_atom_positions[resname] + } + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['N'] - atom_positions['CA'], + ey=np.array([1.0, 0.0, 0.0]), + translation=atom_positions['N'], + ) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['C'] - atom_positions['CA'], + ey=atom_positions['CA'] - atom_positions['N'], + translation=atom_positions['C'], + ) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [ + atom_positions[name] for name in base_atom_names + ] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2], + ) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + 
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1.0, 0.0, 0.0]), + translation=axis_end_atom_position, + ) + restype_rigid_group_default_frame[restype, + 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds(overlap_tolerance=1.5, + bond_length_tolerance_factor=15): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, + atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, + atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, + atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, + atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, + atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, + atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, + atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, + atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, + atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, + atom1_idx] = b.stddev + return { + 'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14) + 'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14) + 'stddev': restype_atom14_bond_stddev, # shape (21,14,14) + } + + +def _make_atom14_and_atom37_constants(): + restype_atom14_to_atom37 = [] + restype_atom37_to_atom14 = [] + restype_atom14_mask = [] + + for rt in restypes: + atom_names = restype_name_to_atom14_names[restype_1to3[rt]] + restype_atom14_to_atom37.append([(atom_order[name] if name else 0) + for name in atom_names]) + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in atom_types + ]) + + restype_atom14_mask.append([(1.0 if name else 0.0) + for name in atom_names]) + + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.0] * 14) + + restype_atom14_to_atom37 = np.array( + restype_atom14_to_atom37, dtype=np.int32) + 
restype_atom37_to_atom14 = np.array( + restype_atom37_to_atom14, dtype=np.int32) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + + return restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask + + +( + restype_atom14_to_atom37, + restype_atom37_to_atom14, + restype_atom14_mask, +) = _make_atom14_and_atom37_constants() + + +def _make_renaming_matrices(): + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative ground truth coordinates where the naming is swapped + restype_3 = [restype_1to3[res] for res in restypes] + restype_3 += ['UNK'] + + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14) for res in restype_3} + for resname, swap in residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = restype_name_to_atom14_names[resname].index( + source_atom_swap) + target_index = restype_name_to_atom14_names[resname].index( + target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14)) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1.0 + all_matrices[resname] = renaming_matrix + renaming_matrices = np.stack( + [all_matrices[restype] for restype in restype_3]) + return renaming_matrices + + +renaming_matrices = _make_renaming_matrices() + + +def _make_atom14_is_ambiguous(): + # Create an ambiguous atoms mask. shape: (21, 14). + restype_atom14_is_ambiguous = np.zeros((21, 14)) + for resname, swap in residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = restype_order[restype_3to1[resname]] + atom_idx1 = restype_name_to_atom14_names[resname].index(atom_name1) + atom_idx2 = restype_name_to_atom14_names[resname].index(atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + return restype_atom14_is_ambiguous + + +restype_atom14_is_ambiguous = _make_atom14_is_ambiguous() + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in restypes: + residue_name = restype_1to3[residue_name] + residue_chi_angles = chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append([atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, + 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. 
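+    # Usage sketch (illustrative; `pos37` and `aatype` are assumed inputs, not
+    # names from this module): given atom37 positions `pos37` of shape
+    # [num_res, 37, 3] and integer residue types `aatype` of shape [num_res],
+    #     idx = np.asarray(chi_atom_indices)[aatype]        # [num_res, 4, 4]
+    #     chi_pos = pos37[np.arange(len(aatype))[:, None, None], idx]
+    # gathers the four defining atoms of each chi angle per residue,
+    # giving `chi_pos` of shape [num_res, 4, 4, 3].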
+ + return chi_atom_indices + + +chi_atom_indices = get_chi_atom_indices() diff --git a/modelscope/models/science/unifold/data/stereo_chemical_props.txt b/modelscope/models/science/unifold/data/stereo_chemical_props.txt new file mode 100644 index 00000000..25262efd --- /dev/null +++ b/modelscope/models/science/unifold/data/stereo_chemical_props.txt @@ -0,0 +1,345 @@ +Bond Residue Mean StdDev +CA-CB ALA 1.520 0.021 +N-CA ALA 1.459 0.020 +CA-C ALA 1.525 0.026 +C-O ALA 1.229 0.019 +CA-CB ARG 1.535 0.022 +CB-CG ARG 1.521 0.027 +CG-CD ARG 1.515 0.025 +CD-NE ARG 1.460 0.017 +NE-CZ ARG 1.326 0.013 +CZ-NH1 ARG 1.326 0.013 +CZ-NH2 ARG 1.326 0.013 +N-CA ARG 1.459 0.020 +CA-C ARG 1.525 0.026 +C-O ARG 1.229 0.019 +CA-CB ASN 1.527 0.026 +CB-CG ASN 1.506 0.023 +CG-OD1 ASN 1.235 0.022 +CG-ND2 ASN 1.324 0.025 +N-CA ASN 1.459 0.020 +CA-C ASN 1.525 0.026 +C-O ASN 1.229 0.019 +CA-CB ASP 1.535 0.022 +CB-CG ASP 1.513 0.021 +CG-OD1 ASP 1.249 0.023 +CG-OD2 ASP 1.249 0.023 +N-CA ASP 1.459 0.020 +CA-C ASP 1.525 0.026 +C-O ASP 1.229 0.019 +CA-CB CYS 1.526 0.013 +CB-SG CYS 1.812 0.016 +N-CA CYS 1.459 0.020 +CA-C CYS 1.525 0.026 +C-O CYS 1.229 0.019 +CA-CB GLU 1.535 0.022 +CB-CG GLU 1.517 0.019 +CG-CD GLU 1.515 0.015 +CD-OE1 GLU 1.252 0.011 +CD-OE2 GLU 1.252 0.011 +N-CA GLU 1.459 0.020 +CA-C GLU 1.525 0.026 +C-O GLU 1.229 0.019 +CA-CB GLN 1.535 0.022 +CB-CG GLN 1.521 0.027 +CG-CD GLN 1.506 0.023 +CD-OE1 GLN 1.235 0.022 +CD-NE2 GLN 1.324 0.025 +N-CA GLN 1.459 0.020 +CA-C GLN 1.525 0.026 +C-O GLN 1.229 0.019 +N-CA GLY 1.456 0.015 +CA-C GLY 1.514 0.016 +C-O GLY 1.232 0.016 +CA-CB HIS 1.535 0.022 +CB-CG HIS 1.492 0.016 +CG-ND1 HIS 1.369 0.015 +CG-CD2 HIS 1.353 0.017 +ND1-CE1 HIS 1.343 0.025 +CD2-NE2 HIS 1.415 0.021 +CE1-NE2 HIS 1.322 0.023 +N-CA HIS 1.459 0.020 +CA-C HIS 1.525 0.026 +C-O HIS 1.229 0.019 +CA-CB ILE 1.544 0.023 +CB-CG1 ILE 1.536 0.028 +CB-CG2 ILE 1.524 0.031 +CG1-CD1 ILE 1.500 0.069 +N-CA ILE 1.459 0.020 +CA-C ILE 1.525 0.026 +C-O ILE 1.229 0.019 +CA-CB LEU 1.533 0.023 +CB-CG LEU 1.521 0.029 +CG-CD1 LEU 1.514 0.037 +CG-CD2 LEU 1.514 0.037 +N-CA LEU 1.459 0.020 +CA-C LEU 1.525 0.026 +C-O LEU 1.229 0.019 +CA-CB LYS 1.535 0.022 +CB-CG LYS 1.521 0.027 +CG-CD LYS 1.520 0.034 +CD-CE LYS 1.508 0.025 +CE-NZ LYS 1.486 0.025 +N-CA LYS 1.459 0.020 +CA-C LYS 1.525 0.026 +C-O LYS 1.229 0.019 +CA-CB MET 1.535 0.022 +CB-CG MET 1.509 0.032 +CG-SD MET 1.807 0.026 +SD-CE MET 1.774 0.056 +N-CA MET 1.459 0.020 +CA-C MET 1.525 0.026 +C-O MET 1.229 0.019 +CA-CB PHE 1.535 0.022 +CB-CG PHE 1.509 0.017 +CG-CD1 PHE 1.383 0.015 +CG-CD2 PHE 1.383 0.015 +CD1-CE1 PHE 1.388 0.020 +CD2-CE2 PHE 1.388 0.020 +CE1-CZ PHE 1.369 0.019 +CE2-CZ PHE 1.369 0.019 +N-CA PHE 1.459 0.020 +CA-C PHE 1.525 0.026 +C-O PHE 1.229 0.019 +CA-CB PRO 1.531 0.020 +CB-CG PRO 1.495 0.050 +CG-CD PRO 1.502 0.033 +CD-N PRO 1.474 0.014 +N-CA PRO 1.468 0.017 +CA-C PRO 1.524 0.020 +C-O PRO 1.228 0.020 +CA-CB SER 1.525 0.015 +CB-OG SER 1.418 0.013 +N-CA SER 1.459 0.020 +CA-C SER 1.525 0.026 +C-O SER 1.229 0.019 +CA-CB THR 1.529 0.026 +CB-OG1 THR 1.428 0.020 +CB-CG2 THR 1.519 0.033 +N-CA THR 1.459 0.020 +CA-C THR 1.525 0.026 +C-O THR 1.229 0.019 +CA-CB TRP 1.535 0.022 +CB-CG TRP 1.498 0.018 +CG-CD1 TRP 1.363 0.014 +CG-CD2 TRP 1.432 0.017 +CD1-NE1 TRP 1.375 0.017 +NE1-CE2 TRP 1.371 0.013 +CD2-CE2 TRP 1.409 0.012 +CD2-CE3 TRP 1.399 0.015 +CE2-CZ2 TRP 1.393 0.017 +CE3-CZ3 TRP 1.380 0.017 +CZ2-CH2 TRP 1.369 0.019 +CZ3-CH2 TRP 1.396 0.016 +N-CA TRP 1.459 0.020 +CA-C TRP 1.525 0.026 +C-O TRP 1.229 0.019 +CA-CB TYR 1.535 0.022 +CB-CG TYR 1.512 0.015 +CG-CD1 TYR 1.387 0.013 
+CG-CD2 TYR 1.387 0.013 +CD1-CE1 TYR 1.389 0.015 +CD2-CE2 TYR 1.389 0.015 +CE1-CZ TYR 1.381 0.013 +CE2-CZ TYR 1.381 0.013 +CZ-OH TYR 1.374 0.017 +N-CA TYR 1.459 0.020 +CA-C TYR 1.525 0.026 +C-O TYR 1.229 0.019 +CA-CB VAL 1.543 0.021 +CB-CG1 VAL 1.524 0.021 +CB-CG2 VAL 1.524 0.021 +N-CA VAL 1.459 0.020 +CA-C VAL 1.525 0.026 +C-O VAL 1.229 0.019 +- + +Angle Residue Mean StdDev +N-CA-CB ALA 110.1 1.4 +CB-CA-C ALA 110.1 1.5 +N-CA-C ALA 111.0 2.7 +CA-C-O ALA 120.1 2.1 +N-CA-CB ARG 110.6 1.8 +CB-CA-C ARG 110.4 2.0 +CA-CB-CG ARG 113.4 2.2 +CB-CG-CD ARG 111.6 2.6 +CG-CD-NE ARG 111.8 2.1 +CD-NE-CZ ARG 123.6 1.4 +NE-CZ-NH1 ARG 120.3 0.5 +NE-CZ-NH2 ARG 120.3 0.5 +NH1-CZ-NH2 ARG 119.4 1.1 +N-CA-C ARG 111.0 2.7 +CA-C-O ARG 120.1 2.1 +N-CA-CB ASN 110.6 1.8 +CB-CA-C ASN 110.4 2.0 +CA-CB-CG ASN 113.4 2.2 +CB-CG-ND2 ASN 116.7 2.4 +CB-CG-OD1 ASN 121.6 2.0 +ND2-CG-OD1 ASN 121.9 2.3 +N-CA-C ASN 111.0 2.7 +CA-C-O ASN 120.1 2.1 +N-CA-CB ASP 110.6 1.8 +CB-CA-C ASP 110.4 2.0 +CA-CB-CG ASP 113.4 2.2 +CB-CG-OD1 ASP 118.3 0.9 +CB-CG-OD2 ASP 118.3 0.9 +OD1-CG-OD2 ASP 123.3 1.9 +N-CA-C ASP 111.0 2.7 +CA-C-O ASP 120.1 2.1 +N-CA-CB CYS 110.8 1.5 +CB-CA-C CYS 111.5 1.2 +CA-CB-SG CYS 114.2 1.1 +N-CA-C CYS 111.0 2.7 +CA-C-O CYS 120.1 2.1 +N-CA-CB GLU 110.6 1.8 +CB-CA-C GLU 110.4 2.0 +CA-CB-CG GLU 113.4 2.2 +CB-CG-CD GLU 114.2 2.7 +CG-CD-OE1 GLU 118.3 2.0 +CG-CD-OE2 GLU 118.3 2.0 +OE1-CD-OE2 GLU 123.3 1.2 +N-CA-C GLU 111.0 2.7 +CA-C-O GLU 120.1 2.1 +N-CA-CB GLN 110.6 1.8 +CB-CA-C GLN 110.4 2.0 +CA-CB-CG GLN 113.4 2.2 +CB-CG-CD GLN 111.6 2.6 +CG-CD-OE1 GLN 121.6 2.0 +CG-CD-NE2 GLN 116.7 2.4 +OE1-CD-NE2 GLN 121.9 2.3 +N-CA-C GLN 111.0 2.7 +CA-C-O GLN 120.1 2.1 +N-CA-C GLY 113.1 2.5 +CA-C-O GLY 120.6 1.8 +N-CA-CB HIS 110.6 1.8 +CB-CA-C HIS 110.4 2.0 +CA-CB-CG HIS 113.6 1.7 +CB-CG-ND1 HIS 123.2 2.5 +CB-CG-CD2 HIS 130.8 3.1 +CG-ND1-CE1 HIS 108.2 1.4 +ND1-CE1-NE2 HIS 109.9 2.2 +CE1-NE2-CD2 HIS 106.6 2.5 +NE2-CD2-CG HIS 109.2 1.9 +CD2-CG-ND1 HIS 106.0 1.4 +N-CA-C HIS 111.0 2.7 +CA-C-O HIS 120.1 2.1 +N-CA-CB ILE 110.8 2.3 +CB-CA-C ILE 111.6 2.0 +CA-CB-CG1 ILE 111.0 1.9 +CB-CG1-CD1 ILE 113.9 2.8 +CA-CB-CG2 ILE 110.9 2.0 +CG1-CB-CG2 ILE 111.4 2.2 +N-CA-C ILE 111.0 2.7 +CA-C-O ILE 120.1 2.1 +N-CA-CB LEU 110.4 2.0 +CB-CA-C LEU 110.2 1.9 +CA-CB-CG LEU 115.3 2.3 +CB-CG-CD1 LEU 111.0 1.7 +CB-CG-CD2 LEU 111.0 1.7 +CD1-CG-CD2 LEU 110.5 3.0 +N-CA-C LEU 111.0 2.7 +CA-C-O LEU 120.1 2.1 +N-CA-CB LYS 110.6 1.8 +CB-CA-C LYS 110.4 2.0 +CA-CB-CG LYS 113.4 2.2 +CB-CG-CD LYS 111.6 2.6 +CG-CD-CE LYS 111.9 3.0 +CD-CE-NZ LYS 111.7 2.3 +N-CA-C LYS 111.0 2.7 +CA-C-O LYS 120.1 2.1 +N-CA-CB MET 110.6 1.8 +CB-CA-C MET 110.4 2.0 +CA-CB-CG MET 113.3 1.7 +CB-CG-SD MET 112.4 3.0 +CG-SD-CE MET 100.2 1.6 +N-CA-C MET 111.0 2.7 +CA-C-O MET 120.1 2.1 +N-CA-CB PHE 110.6 1.8 +CB-CA-C PHE 110.4 2.0 +CA-CB-CG PHE 113.9 2.4 +CB-CG-CD1 PHE 120.8 0.7 +CB-CG-CD2 PHE 120.8 0.7 +CD1-CG-CD2 PHE 118.3 1.3 +CG-CD1-CE1 PHE 120.8 1.1 +CG-CD2-CE2 PHE 120.8 1.1 +CD1-CE1-CZ PHE 120.1 1.2 +CD2-CE2-CZ PHE 120.1 1.2 +CE1-CZ-CE2 PHE 120.0 1.8 +N-CA-C PHE 111.0 2.7 +CA-C-O PHE 120.1 2.1 +N-CA-CB PRO 103.3 1.2 +CB-CA-C PRO 111.7 2.1 +CA-CB-CG PRO 104.8 1.9 +CB-CG-CD PRO 106.5 3.9 +CG-CD-N PRO 103.2 1.5 +CA-N-CD PRO 111.7 1.4 +N-CA-C PRO 112.1 2.6 +CA-C-O PRO 120.2 2.4 +N-CA-CB SER 110.5 1.5 +CB-CA-C SER 110.1 1.9 +CA-CB-OG SER 111.2 2.7 +N-CA-C SER 111.0 2.7 +CA-C-O SER 120.1 2.1 +N-CA-CB THR 110.3 1.9 +CB-CA-C THR 111.6 2.7 +CA-CB-OG1 THR 109.0 2.1 +CA-CB-CG2 THR 112.4 1.4 +OG1-CB-CG2 THR 110.0 2.3 +N-CA-C THR 111.0 2.7 +CA-C-O THR 120.1 2.1 +N-CA-CB TRP 110.6 1.8 +CB-CA-C TRP 110.4 2.0 
+CA-CB-CG TRP 113.7 1.9 +CB-CG-CD1 TRP 127.0 1.3 +CB-CG-CD2 TRP 126.6 1.3 +CD1-CG-CD2 TRP 106.3 0.8 +CG-CD1-NE1 TRP 110.1 1.0 +CD1-NE1-CE2 TRP 109.0 0.9 +NE1-CE2-CD2 TRP 107.3 1.0 +CE2-CD2-CG TRP 107.3 0.8 +CG-CD2-CE3 TRP 133.9 0.9 +NE1-CE2-CZ2 TRP 130.4 1.1 +CE3-CD2-CE2 TRP 118.7 1.2 +CD2-CE2-CZ2 TRP 122.3 1.2 +CE2-CZ2-CH2 TRP 117.4 1.0 +CZ2-CH2-CZ3 TRP 121.6 1.2 +CH2-CZ3-CE3 TRP 121.2 1.1 +CZ3-CE3-CD2 TRP 118.8 1.3 +N-CA-C TRP 111.0 2.7 +CA-C-O TRP 120.1 2.1 +N-CA-CB TYR 110.6 1.8 +CB-CA-C TYR 110.4 2.0 +CA-CB-CG TYR 113.4 1.9 +CB-CG-CD1 TYR 121.0 0.6 +CB-CG-CD2 TYR 121.0 0.6 +CD1-CG-CD2 TYR 117.9 1.1 +CG-CD1-CE1 TYR 121.3 0.8 +CG-CD2-CE2 TYR 121.3 0.8 +CD1-CE1-CZ TYR 119.8 0.9 +CD2-CE2-CZ TYR 119.8 0.9 +CE1-CZ-CE2 TYR 119.8 1.6 +CE1-CZ-OH TYR 120.1 2.7 +CE2-CZ-OH TYR 120.1 2.7 +N-CA-C TYR 111.0 2.7 +CA-C-O TYR 120.1 2.1 +N-CA-CB VAL 111.5 2.2 +CB-CA-C VAL 111.4 1.9 +CA-CB-CG1 VAL 110.9 1.5 +CA-CB-CG2 VAL 110.9 1.5 +CG1-CB-CG2 VAL 110.9 1.6 +N-CA-C VAL 111.0 2.7 +CA-C-O VAL 120.1 2.1 +- + +Non-bonded distance Minimum Dist Tolerance +C-C 3.4 1.5 +C-N 3.25 1.5 +C-S 3.5 1.5 +C-O 3.22 1.5 +N-N 3.1 1.5 +N-S 3.35 1.5 +N-O 3.07 1.5 +O-S 3.32 1.5 +O-O 3.04 1.5 +S-S 2.03 1.0 +- diff --git a/modelscope/models/science/unifold/data/utils.py b/modelscope/models/science/unifold/data/utils.py new file mode 100644 index 00000000..2be91ef0 --- /dev/null +++ b/modelscope/models/science/unifold/data/utils.py @@ -0,0 +1,161 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import copy as copy_lib +import functools +import gzip +import pickle +from typing import Any, Dict + +import json +import numpy as np +from scipy import sparse as sp + +from . import residue_constants as rc +from .data_ops import NumpyDict + +# from typing import * + + +def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False): + if deepcopy: + + def decorator(f): + cached_func = functools.lru_cache(maxsize, typed)(f) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return copy_lib.deepcopy(cached_func(*args, **kwargs)) + + return wrapper + + elif copy: + + def decorator(f): + cached_func = functools.lru_cache(maxsize, typed)(f) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return copy_lib.copy(cached_func(*args, **kwargs)) + + return wrapper + + else: + decorator = functools.lru_cache(maxsize, typed) + return decorator + + +@lru_cache(maxsize=8, deepcopy=True) +def load_pickle_safe(path: str) -> Dict[str, Any]: + + def load(path): + assert path.endswith('.pkl') or path.endswith( + '.pkl.gz'), f'bad suffix in {path} as pickle file.' + open_fn = gzip.open if path.endswith('.gz') else open + with open_fn(path, 'rb') as f: + return pickle.load(f) + + ret = load(path) + ret = uncompress_features(ret) + return ret + + +@lru_cache(maxsize=8, copy=True) +def load_pickle(path: str) -> Dict[str, Any]: + + def load(path): + assert path.endswith('.pkl') or path.endswith( + '.pkl.gz'), f'bad suffix in {path} as pickle file.' 
+ open_fn = gzip.open if path.endswith('.gz') else open + with open_fn(path, 'rb') as f: + return pickle.load(f) + + ret = load(path) + ret = uncompress_features(ret) + return ret + + +def correct_template_restypes(feature): + """Correct template restype to have the same order as residue_constants.""" + feature = np.argmax(feature, axis=-1).astype(np.int32) + new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + feature = np.take(new_order_list, feature.astype(np.int32), axis=0) + return feature + + +def convert_all_seq_feature(feature: NumpyDict) -> NumpyDict: + feature['msa'] = feature['msa'].astype(np.uint8) + if 'num_alignments' in feature: + feature.pop('num_alignments') + # make_all_seq_key = lambda k: f'{k}_all_seq' if not k.endswith('_all_seq') else k + + def make_all_seq_key(k): + if not k.endswith('_all_seq'): + return f'{k}_all_seq' + return k + + return {make_all_seq_key(k): v for k, v in feature.items()} + + +def to_dense_matrix(spmat_dict: NumpyDict): + spmat = sp.coo_matrix( + (spmat_dict['data'], (spmat_dict['row'], spmat_dict['col'])), + shape=spmat_dict['shape'], + dtype=np.float32, + ) + return spmat.toarray() + + +FEATS_DTYPE = {'msa': np.int32} + + +def uncompress_features(feats: NumpyDict) -> NumpyDict: + if 'sparse_deletion_matrix_int' in feats: + v = feats.pop('sparse_deletion_matrix_int') + v = to_dense_matrix(v) + feats['deletion_matrix'] = v + return feats + + +def filter(feature: NumpyDict, **kwargs) -> NumpyDict: + assert len(kwargs) == 1, f'wrong usage of filter with kwargs: {kwargs}' + if 'desired_keys' in kwargs: + feature = { + k: v + for k, v in feature.items() if k in kwargs['desired_keys'] + } + elif 'required_keys' in kwargs: + for k in kwargs['required_keys']: + assert k in feature, f'cannot find required key {k}.' 
+ elif 'ignored_keys' in kwargs: + feature = { + k: v + for k, v in feature.items() if k not in kwargs['ignored_keys'] + } + else: + raise AssertionError(f'wrong usage of filter with kwargs: {kwargs}') + return feature + + +def compress_features(features: NumpyDict): + change_dtype = { + 'msa': np.uint8, + } + sparse_keys = ['deletion_matrix_int'] + + compressed_features = {} + for k, v in features.items(): + if k in change_dtype: + v = v.astype(change_dtype[k]) + if k in sparse_keys: + v = sp.coo_matrix(v, dtype=v.dtype) + sp_v = { + 'shape': v.shape, + 'row': v.row, + 'col': v.col, + 'data': v.data + } + k = f'sparse_{k}' + v = sp_v + compressed_features[k] = v + return compressed_features diff --git a/modelscope/models/science/unifold/dataset.py b/modelscope/models/science/unifold/dataset.py new file mode 100644 index 00000000..05803f2c --- /dev/null +++ b/modelscope/models/science/unifold/dataset.py @@ -0,0 +1,514 @@ +import copy +import logging +import os +# from typing import * +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import json +import ml_collections as mlc +import numpy as np +import torch +from unicore.data import UnicoreDataset, data_utils +from unicore.distributed import utils as distributed_utils + +from .data import utils +from .data.data_ops import NumpyDict, TorchDict +from .data.process import process_features, process_labels +from .data.process_multimer import (add_assembly_features, + convert_monomer_features, merge_msas, + pair_and_merge, post_process) + +Rotation = Iterable[Iterable] +Translation = Iterable +Operation = Union[str, Tuple[Rotation, Translation]] +NumpyExample = Tuple[NumpyDict, Optional[List[NumpyDict]]] +TorchExample = Tuple[TorchDict, Optional[List[TorchDict]]] + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def make_data_config( + config: mlc.ConfigDict, + mode: str, + num_res: int, +) -> Tuple[mlc.ConfigDict, List[str]]: + cfg = copy.deepcopy(config) + mode_cfg = cfg[mode] + with cfg.unlocked(): + if mode_cfg.crop_size is None: + mode_cfg.crop_size = num_res + feature_names = cfg.common.unsupervised_features + cfg.common.recycling_features + if cfg.common.use_templates: + feature_names += cfg.common.template_features + if cfg.common.is_multimer: + feature_names += cfg.common.multimer_features + if cfg[mode].supervised: + feature_names += cfg.supervised.supervised_features + + return cfg, feature_names + + +def process_label(all_atom_positions: np.ndarray, + operation: Operation) -> np.ndarray: + if operation == 'I': + return all_atom_positions + rot, trans = operation + rot = np.array(rot).reshape(3, 3) + trans = np.array(trans).reshape(3) + return all_atom_positions @ rot.T + trans + + +@utils.lru_cache(maxsize=8, copy=True) +def load_single_feature( + sequence_id: str, + monomer_feature_dir: str, + uniprot_msa_dir: Optional[str] = None, + is_monomer: bool = False, +) -> NumpyDict: + + monomer_feature = utils.load_pickle( + os.path.join(monomer_feature_dir, f'{sequence_id}.feature.pkl.gz')) + monomer_feature = convert_monomer_features(monomer_feature) + chain_feature = {**monomer_feature} + + if uniprot_msa_dir is not None: + all_seq_feature = utils.load_pickle( + os.path.join(uniprot_msa_dir, f'{sequence_id}.uniprot.pkl.gz')) + if is_monomer: + chain_feature['msa'], chain_feature[ + 'deletion_matrix'] = merge_msas( + chain_feature['msa'], + chain_feature['deletion_matrix'], + all_seq_feature['msa'], + all_seq_feature['deletion_matrix'], + ) # noqa + else: + all_seq_feature = 
utils.convert_all_seq_feature(all_seq_feature) + for key in [ + 'msa_all_seq', + 'msa_species_identifiers_all_seq', + 'deletion_matrix_all_seq', + ]: + chain_feature[key] = all_seq_feature[key] + + return chain_feature + + +def load_single_label( + label_id: str, + label_dir: str, + symmetry_operation: Optional[Operation] = None, +) -> NumpyDict: + label = utils.load_pickle( + os.path.join(label_dir, f'{label_id}.label.pkl.gz')) + if symmetry_operation is not None: + label['all_atom_positions'] = process_label( + label['all_atom_positions'], symmetry_operation) + label = { + k: v + for k, v in label.items() if k in + ['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution'] + } + return label + + +def load( + sequence_ids: List[str], + monomer_feature_dir: str, + uniprot_msa_dir: Optional[str] = None, + label_ids: Optional[List[str]] = None, + label_dir: Optional[str] = None, + symmetry_operations: Optional[List[Operation]] = None, + is_monomer: bool = False, +) -> NumpyExample: + + all_chain_features = [ + load_single_feature(s, monomer_feature_dir, uniprot_msa_dir, + is_monomer) for s in sequence_ids + ] + + if label_ids is not None: + # load labels + assert len(label_ids) == len(sequence_ids) + assert label_dir is not None + if symmetry_operations is None: + symmetry_operations = ['I' for _ in label_ids] + all_chain_labels = [ + load_single_label(ll, label_dir, o) + for ll, o in zip(label_ids, symmetry_operations) + ] + # update labels into features to calculate spatial cropping etc. + [f.update(ll) for f, ll in zip(all_chain_features, all_chain_labels)] + + all_chain_features = add_assembly_features(all_chain_features) + + # get labels back from features, as add_assembly_features may alter the order of inputs. + if label_ids is not None: + all_chain_labels = [{ + k: f[k] + for k in + ['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution'] + } for f in all_chain_features] + else: + all_chain_labels = None + + asym_len = np.array([c['seq_length'] for c in all_chain_features], + dtype=np.int64) + if is_monomer: + all_chain_features = all_chain_features[0] + else: + all_chain_features = pair_and_merge(all_chain_features) + all_chain_features = post_process(all_chain_features) + all_chain_features['asym_len'] = asym_len + + return all_chain_features, all_chain_labels + + +def process( + config: mlc.ConfigDict, + mode: str, + features: NumpyDict, + labels: Optional[List[NumpyDict]] = None, + seed: int = 0, + batch_idx: Optional[int] = None, + data_idx: Optional[int] = None, + is_distillation: bool = False, +) -> TorchExample: + + if mode == 'train': + assert batch_idx is not None + with data_utils.numpy_seed(seed, batch_idx, key='recycling'): + num_iters = np.random.randint( + 0, config.common.max_recycling_iters + 1) + use_clamped_fape = np.random.rand( + ) < config[mode].use_clamped_fape_prob + else: + num_iters = config.common.max_recycling_iters + use_clamped_fape = 1 + + features['num_recycling_iters'] = int(num_iters) + features['use_clamped_fape'] = int(use_clamped_fape) + features['is_distillation'] = int(is_distillation) + if is_distillation and 'msa_chains' in features: + features.pop('msa_chains') + + num_res = int(features['seq_length']) + cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) + + if labels is not None: + features['resolution'] = labels[0]['resolution'].reshape(-1) + + with data_utils.numpy_seed(seed, data_idx, key='protein_feature'): + features['crop_and_fix_size_seed'] = np.random.randint(0, 63355) + features = 
utils.filter(features, desired_keys=feature_names) + features = {k: torch.tensor(v) for k, v in features.items()} + with torch.no_grad(): + features = process_features(features, cfg.common, cfg[mode]) + + if labels is not None: + labels = [{k: torch.tensor(v) for k, v in ll.items()} for ll in labels] + with torch.no_grad(): + labels = process_labels(labels) + + return features, labels + + +def load_and_process( + config: mlc.ConfigDict, + mode: str, + seed: int = 0, + batch_idx: Optional[int] = None, + data_idx: Optional[int] = None, + is_distillation: bool = False, + **load_kwargs, +): + is_monomer = ( + is_distillation + if 'is_monomer' not in load_kwargs else load_kwargs.pop('is_monomer')) + features, labels = load(**load_kwargs, is_monomer=is_monomer) + features, labels = process(config, mode, features, labels, seed, batch_idx, + data_idx, is_distillation) + return features, labels + + +class UnifoldDataset(UnicoreDataset): + + def __init__( + self, + args, + seed, + config, + data_path, + mode='train', + max_step=None, + disable_sd=False, + json_prefix='', + ): + self.path = data_path + + def load_json(filename): + return json.load(open(filename, 'r')) + + sample_weight = load_json( + os.path.join(self.path, + json_prefix + mode + '_sample_weight.json')) + self.multi_label = load_json( + os.path.join(self.path, json_prefix + mode + '_multi_label.json')) + self.inverse_multi_label = self._inverse_map(self.multi_label) + self.sample_weight = {} + for chain in self.inverse_multi_label: + entity = self.inverse_multi_label[chain] + self.sample_weight[chain] = sample_weight[entity] + self.seq_sample_weight = sample_weight + logger.info('load {} chains (unique {} sequences)'.format( + len(self.sample_weight), len(self.seq_sample_weight))) + self.feature_path = os.path.join(self.path, 'pdb_features') + self.label_path = os.path.join(self.path, 'pdb_labels') + sd_sample_weight_path = os.path.join( + self.path, json_prefix + 'sd_train_sample_weight.json') + if mode == 'train' and os.path.isfile( + sd_sample_weight_path) and not disable_sd: + self.sd_sample_weight = load_json(sd_sample_weight_path) + logger.info('load {} self-distillation samples.'.format( + len(self.sd_sample_weight))) + self.sd_feature_path = os.path.join(self.path, 'sd_features') + self.sd_label_path = os.path.join(self.path, 'sd_labels') + else: + self.sd_sample_weight = None + self.batch_size = ( + args.batch_size * distributed_utils.get_data_parallel_world_size() + * args.update_freq[0]) + self.data_len = ( + max_step * self.batch_size + if max_step is not None else len(self.sample_weight)) + self.mode = mode + self.num_seq, self.seq_keys, self.seq_sample_prob = self.cal_sample_weight( + self.seq_sample_weight) + self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight( + self.sample_weight) + if self.sd_sample_weight is not None: + ( + self.sd_num_chain, + self.sd_chain_keys, + self.sd_sample_prob, + ) = self.cal_sample_weight(self.sd_sample_weight) + self.config = config.data + self.seed = seed + self.sd_prob = args.sd_prob + + def cal_sample_weight(self, sample_weight): + prot_keys = list(sample_weight.keys()) + sum_weight = sum(sample_weight.values()) + sample_prob = [sample_weight[k] / sum_weight for k in prot_keys] + num_prot = len(prot_keys) + return num_prot, prot_keys, sample_prob + + def sample_chain(self, idx, sample_by_seq=False): + is_distillation = False + if self.mode == 'train': + with data_utils.numpy_seed(self.seed, idx, key='data_sample'): + is_distillation = ((np.random.rand(1)[0] < 
self.sd_prob) + if self.sd_sample_weight is not None else + False) + if is_distillation: + prot_idx = np.random.choice( + self.sd_num_chain, p=self.sd_sample_prob) + label_name = self.sd_chain_keys[prot_idx] + seq_name = label_name + else: + if not sample_by_seq: + prot_idx = np.random.choice( + self.num_chain, p=self.sample_prob) + label_name = self.chain_keys[prot_idx] + seq_name = self.inverse_multi_label[label_name] + else: + seq_idx = np.random.choice( + self.num_seq, p=self.seq_sample_prob) + seq_name = self.seq_keys[seq_idx] + label_name = np.random.choice( + self.multi_label[seq_name]) + else: + label_name = self.chain_keys[idx] + seq_name = self.inverse_multi_label[label_name] + return seq_name, label_name, is_distillation + + def __getitem__(self, idx): + sequence_id, label_id, is_distillation = self.sample_chain( + idx, sample_by_seq=True) + feature_dir, label_dir = ((self.feature_path, + self.label_path) if not is_distillation else + (self.sd_feature_path, self.sd_label_path)) + features, _ = load_and_process( + self.config, + self.mode, + self.seed, + batch_idx=(idx // self.batch_size), + data_idx=idx, + is_distillation=is_distillation, + sequence_ids=[sequence_id], + monomer_feature_dir=feature_dir, + uniprot_msa_dir=None, + label_ids=[label_id], + label_dir=label_dir, + symmetry_operations=None, + is_monomer=True, + ) + return features + + def __len__(self): + return self.data_len + + @staticmethod + def collater(samples): + # first dim is recyling. bsz is at the 2nd dim + return data_utils.collate_dict(samples, dim=1) + + @staticmethod + def _inverse_map(mapping: Dict[str, List[str]]): + inverse_mapping = {} + for ent, refs in mapping.items(): + for ref in refs: + if ref in inverse_mapping: # duplicated ent for this ref. + ent_2 = inverse_mapping[ref] + assert ( + ent == ent_2 + ), f'multiple entities ({ent_2}, {ent}) exist for reference {ref}.' 
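+                    # Note: multi_label maps an entity/sequence id to the chains that
+                    # share it, e.g. {'entity_1': ['1abc_A', '1abc_B']}; the inverse map
+                    # must send every chain back to exactly one entity, so the assert
+                    # above rejects a chain listed under two different entities.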
+ inverse_mapping[ref] = ent + return inverse_mapping + + +class UnifoldMultimerDataset(UnifoldDataset): + + def __init__( + self, + args: mlc.ConfigDict, + seed: int, + config: mlc.ConfigDict, + data_path: str, + mode: str = 'train', + max_step: Optional[int] = None, + disable_sd: bool = False, + json_prefix: str = '', + **kwargs, + ): + super().__init__(args, seed, config, data_path, mode, max_step, + disable_sd, json_prefix) + self.data_path = data_path + self.pdb_assembly = json.load( + open( + os.path.join(self.data_path, + json_prefix + 'pdb_assembly.json'))) + self.pdb_chains = self.get_chains(self.inverse_multi_label) + self.monomer_feature_path = os.path.join(self.data_path, + 'pdb_features') + self.uniprot_msa_path = os.path.join(self.data_path, 'pdb_uniprots') + self.label_path = os.path.join(self.data_path, 'pdb_labels') + self.max_chains = args.max_chains + if self.mode == 'train': + self.pdb_chains, self.sample_weight = self.filter_pdb_by_max_chains( + self.pdb_chains, self.pdb_assembly, self.sample_weight, + self.max_chains) + self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight( + self.sample_weight) + + def __getitem__(self, idx): + seq_id, label_id, is_distillation = self.sample_chain(idx) + if is_distillation: + label_ids = [label_id] + sequence_ids = [seq_id] + monomer_feature_path, uniprot_msa_path, label_path = ( + self.sd_feature_path, + None, + self.sd_label_path, + ) + symmetry_operations = None + else: + pdb_id = self.get_pdb_name(label_id) + if pdb_id in self.pdb_assembly and self.mode == 'train': + label_ids = [ + pdb_id + '_' + id + for id in self.pdb_assembly[pdb_id]['chains'] + ] + symmetry_operations = [ + t for t in self.pdb_assembly[pdb_id]['opers'] + ] + else: + label_ids = self.pdb_chains[pdb_id] + symmetry_operations = None + sequence_ids = [ + self.inverse_multi_label[chain_id] for chain_id in label_ids + ] + monomer_feature_path, uniprot_msa_path, label_path = ( + self.monomer_feature_path, + self.uniprot_msa_path, + self.label_path, + ) + + return load_and_process( + self.config, + self.mode, + self.seed, + batch_idx=(idx // self.batch_size), + data_idx=idx, + is_distillation=is_distillation, + sequence_ids=sequence_ids, + monomer_feature_dir=monomer_feature_path, + uniprot_msa_dir=uniprot_msa_path, + label_ids=label_ids, + label_dir=label_path, + symmetry_operations=symmetry_operations, + is_monomer=False, + ) + + @staticmethod + def collater(samples): + # first dim is recyling. 
bsz is at the 2nd dim + if len(samples) <= 0: # tackle empty batch + return None + feats = [s[0] for s in samples] + labs = [s[1] for s in samples if s[1] is not None] + try: + feats = data_utils.collate_dict(feats, dim=1) + except BaseException: + raise ValueError('cannot collate features', feats) + if not labs: + labs = None + return feats, labs + + @staticmethod + def get_pdb_name(chain): + return chain.split('_')[0] + + @staticmethod + def get_chains(canon_chain_map): + pdb_chains = {} + for chain in canon_chain_map: + pdb = UnifoldMultimerDataset.get_pdb_name(chain) + if pdb not in pdb_chains: + pdb_chains[pdb] = [] + pdb_chains[pdb].append(chain) + return pdb_chains + + @staticmethod + def filter_pdb_by_max_chains(pdb_chains, pdb_assembly, sample_weight, + max_chains): + new_pdb_chains = {} + for chain in pdb_chains: + if chain in pdb_assembly: + size = len(pdb_assembly[chain]['chains']) + if size <= max_chains: + new_pdb_chains[chain] = pdb_chains[chain] + else: + size = len(pdb_chains[chain]) + if size == 1: + new_pdb_chains[chain] = pdb_chains[chain] + new_sample_weight = { + k: sample_weight[k] + for k in sample_weight + if UnifoldMultimerDataset.get_pdb_name(k) in new_pdb_chains + } + logger.info( + f'filtered out {len(pdb_chains) - len(new_pdb_chains)} / {len(pdb_chains)} PDBs ' + f'({len(sample_weight) - len(new_sample_weight)} / {len(sample_weight)} chains) ' + f'by max_chains {max_chains}') + return new_pdb_chains, new_sample_weight diff --git a/modelscope/models/science/unifold/model.py b/modelscope/models/science/unifold/model.py new file mode 100644 index 00000000..6632751a --- /dev/null +++ b/modelscope/models/science/unifold/model.py @@ -0,0 +1,75 @@ +import argparse +import os +from typing import Any + +import torch + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from .config import model_config +from .modules.alphafold import AlphaFold + +__all__ = ['UnifoldForProteinStructrue'] + + +@MODELS.register_module(Tasks.protein_structure, module_name=Models.unifold) +class UnifoldForProteinStructrue(TorchModel): + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + '--model-name', + help='choose the model config', + ) + + def __init__(self, **kwargs): + super().__init__() + parser = argparse.ArgumentParser() + parse_comm = [] + for key in kwargs: + parser.add_argument(f'--{key}') + parse_comm.append(f'--{key}') + parse_comm.append(kwargs[key]) + args = parser.parse_args(parse_comm) + base_architecture(args) + self.args = args + config = model_config( + self.args.model_name, + train=True, + ) + self.model = AlphaFold(config) + self.config = config + + # load model state dict + param_path = os.path.join(kwargs['model_dir'], + ModelFile.TORCH_MODEL_BIN_FILE) + state_dict = torch.load(param_path)['ema']['params'] + state_dict = { + '.'.join(k.split('.')[1:]): v + for k, v in state_dict.items() + } + self.model.load_state_dict(state_dict) + + def half(self): + self.model = self.model.half() + return self + + def bfloat16(self): + self.model = self.model.bfloat16() + return self + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + return cls(args) + + def forward(self, batch, **kwargs): + outputs = self.model.forward(batch) + return outputs, self.config.loss + + +def base_architecture(args): + args.model_name = getattr(args, 'model_name', 
'model_2') diff --git a/modelscope/models/science/unifold/modules/alphafold.py b/modelscope/models/science/unifold/modules/alphafold.py new file mode 100644 index 00000000..71a1b310 --- /dev/null +++ b/modelscope/models/science/unifold/modules/alphafold.py @@ -0,0 +1,450 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import torch +import torch.nn as nn +from unicore.utils import tensor_tree_map + +from ..data import residue_constants +from .attentions import gen_msa_attn_mask, gen_tri_attn_mask +from .auxillary_heads import AuxiliaryHeads +from .common import residual +from .embedders import (ExtraMSAEmbedder, InputEmbedder, RecyclingEmbedder, + TemplateAngleEmbedder, TemplatePairEmbedder) +from .evoformer import EvoformerStack, ExtraMSAStack +from .featurization import (atom14_to_atom37, build_extra_msa_feat, + build_template_angle_feat, + build_template_pair_feat, + build_template_pair_feat_v2, pseudo_beta_fn) +from .structure_module import StructureModule +from .template import (TemplatePairStack, TemplatePointwiseAttention, + TemplateProjection) + + +class AlphaFold(nn.Module): + + def __init__(self, config): + super(AlphaFold, self).__init__() + + self.globals = config.globals + config = config.model + template_config = config.template + extra_msa_config = config.extra_msa + + self.input_embedder = InputEmbedder( + **config['input_embedder'], + use_chain_relative=config.is_multimer, + ) + self.recycling_embedder = RecyclingEmbedder( + **config['recycling_embedder'], ) + if config.template.enabled: + self.template_angle_embedder = TemplateAngleEmbedder( + **template_config['template_angle_embedder'], ) + self.template_pair_embedder = TemplatePairEmbedder( + **template_config['template_pair_embedder'], ) + self.template_pair_stack = TemplatePairStack( + **template_config['template_pair_stack'], ) + else: + self.template_pair_stack = None + self.enable_template_pointwise_attention = template_config[ + 'template_pointwise_attention'].enabled + if self.enable_template_pointwise_attention: + self.template_pointwise_att = TemplatePointwiseAttention( + **template_config['template_pointwise_attention'], ) + else: + self.template_proj = TemplateProjection( + **template_config['template_pointwise_attention'], ) + self.extra_msa_embedder = ExtraMSAEmbedder( + **extra_msa_config['extra_msa_embedder'], ) + self.extra_msa_stack = ExtraMSAStack( + **extra_msa_config['extra_msa_stack'], ) + self.evoformer = EvoformerStack(**config['evoformer_stack'], ) + self.structure_module = StructureModule(**config['structure_module'], ) + + self.aux_heads = AuxiliaryHeads(config['heads'], ) + + self.config = config + self.dtype = torch.float + self.inf = self.globals.inf + if self.globals.alphafold_original_mode: + self.alphafold_original_mode() + + def __make_input_float__(self): + self.input_embedder = self.input_embedder.float() + self.recycling_embedder = self.recycling_embedder.float() + + def half(self): + super().half() + if (not getattr(self, 'inference', False)): + self.__make_input_float__() + self.dtype = torch.half + return self + + def bfloat16(self): + super().bfloat16() + if (not getattr(self, 'inference', False)): + self.__make_input_float__() + self.dtype = torch.bfloat16 + return self + + def alphafold_original_mode(self): + + def set_alphafold_original_mode(module): + if hasattr(module, 'apply_alphafold_original_mode'): + module.apply_alphafold_original_mode() + if 
hasattr(module, 'act'): + module.act = nn.ReLU() + + self.apply(set_alphafold_original_mode) + + def inference_mode(self): + + def set_inference_mode(module): + setattr(module, 'inference', True) + + self.apply(set_inference_mode) + + def __convert_input_dtype__(self, batch): + for key in batch: + # only convert features with mask + if batch[key].dtype != self.dtype and 'mask' in key: + batch[key] = batch[key].type(self.dtype) + return batch + + def embed_templates_pair_core(self, batch, z, pair_mask, + tri_start_attn_mask, tri_end_attn_mask, + templ_dim, multichain_mask_2d): + if self.config.template.template_pair_embedder.v2_feature: + t = build_template_pair_feat_v2( + batch, + inf=self.config.template.inf, + eps=self.config.template.eps, + multichain_mask_2d=multichain_mask_2d, + **self.config.template.distogram, + ) + num_template = t[0].shape[-4] + single_templates = [ + self.template_pair_embedder([x[..., ti, :, :, :] + for x in t], z) + for ti in range(num_template) + ] + else: + t = build_template_pair_feat( + batch, + inf=self.config.template.inf, + eps=self.config.template.eps, + **self.config.template.distogram, + ) + single_templates = [ + self.template_pair_embedder(x, z) + for x in torch.unbind(t, dim=templ_dim) + ] + + t = self.template_pair_stack( + single_templates, + pair_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + templ_dim=templ_dim, + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + return_mean=not self.enable_template_pointwise_attention, + ) + return t + + def embed_templates_pair(self, batch, z, pair_mask, tri_start_attn_mask, + tri_end_attn_mask, templ_dim): + if self.config.template.template_pair_embedder.v2_feature and 'asym_id' in batch: + multichain_mask_2d = ( + batch['asym_id'][..., :, None] == batch['asym_id'][..., + None, :]) + multichain_mask_2d = multichain_mask_2d.unsqueeze(0) + else: + multichain_mask_2d = None + + if self.training or self.enable_template_pointwise_attention: + t = self.embed_templates_pair_core(batch, z, pair_mask, + tri_start_attn_mask, + tri_end_attn_mask, templ_dim, + multichain_mask_2d) + if self.enable_template_pointwise_attention: + t = self.template_pointwise_att( + t, + z, + template_mask=batch['template_mask'], + chunk_size=self.globals.chunk_size, + ) + t_mask = torch.sum( + batch['template_mask'], dim=-1, keepdims=True) > 0 + t_mask = t_mask[..., None, None].type(t.dtype) + t *= t_mask + else: + t = self.template_proj(t, z) + else: + template_aatype_shape = batch['template_aatype'].shape + # template_aatype is either [n_template, n_res] or [1, n_template_, n_res] + batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0 + n_templ = batch['template_aatype'].shape[batch_templ_dim] + + if n_templ <= 0: + t = None + else: + template_batch = { + k: v + for k, v in batch.items() if k.startswith('template_') + } + + def embed_one_template(i): + + def slice_template_tensor(t): + s = [slice(None) for _ in t.shape] + s[batch_templ_dim] = slice(i, i + 1) + return t[s] + + template_feats = tensor_tree_map( + slice_template_tensor, + template_batch, + ) + t = self.embed_templates_pair_core( + template_feats, z, pair_mask, tri_start_attn_mask, + tri_end_attn_mask, templ_dim, multichain_mask_2d) + return t + + t = embed_one_template(0) + # iterate templates one by one + for i in range(1, n_templ): + t += embed_one_template(i) + t /= n_templ + t = self.template_proj(t, z) + return t + + def embed_templates_angle(self, batch): + template_angle_feat, 
template_angle_mask = build_template_angle_feat( + batch, + v2_feature=self.config.template.template_pair_embedder.v2_feature) + t = self.template_angle_embedder(template_angle_feat) + return t, template_angle_mask + + def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev): + batch_dims = feats['target_feat'].shape[:-2] + n = feats['target_feat'].shape[-2] + seq_mask = feats['seq_mask'] + pair_mask = seq_mask[..., None] * seq_mask[..., None, :] + msa_mask = feats['msa_mask'] + + m, z = self.input_embedder( + feats['target_feat'], + feats['msa_feat'], + ) + + if m_1_prev is None: + m_1_prev = m.new_zeros( + (*batch_dims, n, self.config.input_embedder.d_msa), + requires_grad=False, + ) + if z_prev is None: + z_prev = z.new_zeros( + (*batch_dims, n, n, self.config.input_embedder.d_pair), + requires_grad=False, + ) + if x_prev is None: + x_prev = z.new_zeros( + (*batch_dims, n, residue_constants.atom_type_num, 3), + requires_grad=False, + ) + x_prev = pseudo_beta_fn(feats['aatype'], x_prev, None) + + z += self.recycling_embedder.recyle_pos(x_prev) + + m_1_prev_emb, z_prev_emb = self.recycling_embedder( + m_1_prev, + z_prev, + ) + + m[..., 0, :, :] += m_1_prev_emb + + z += z_prev_emb + + z += self.input_embedder.relpos_emb( + feats['residue_index'].long(), + feats.get('sym_id', None), + feats.get('asym_id', None), + feats.get('entity_id', None), + feats.get('num_sym', None), + ) + + m = m.type(self.dtype) + z = z.type(self.dtype) + tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask( + pair_mask, self.inf) + + if self.config.template.enabled: + template_mask = feats['template_mask'] + if torch.any(template_mask): + z = residual( + z, + self.embed_templates_pair( + feats, + z, + pair_mask, + tri_start_attn_mask, + tri_end_attn_mask, + templ_dim=-4, + ), + self.training, + ) + + if self.config.extra_msa.enabled: + a = self.extra_msa_embedder(build_extra_msa_feat(feats)) + extra_msa_row_mask = gen_msa_attn_mask( + feats['extra_msa_mask'], + inf=self.inf, + gen_col_mask=False, + ) + z = self.extra_msa_stack( + a, + z, + msa_mask=feats['extra_msa_mask'], + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + pair_mask=pair_mask, + msa_row_attn_mask=extra_msa_row_mask, + msa_col_attn_mask=None, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + ) + + if self.config.template.embed_angles: + template_1d_feat, template_1d_mask = self.embed_templates_angle( + feats) + m = torch.cat([m, template_1d_feat], dim=-3) + msa_mask = torch.cat([feats['msa_mask'], template_1d_mask], dim=-2) + + msa_row_mask, msa_col_mask = gen_msa_attn_mask( + msa_mask, + inf=self.inf, + ) + + m, z, s = self.evoformer( + m, + z, + msa_mask=msa_mask, + pair_mask=pair_mask, + msa_row_attn_mask=msa_row_mask, + msa_col_attn_mask=msa_col_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + ) + return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb + + def iteration_evoformer_structure_module(self, + batch, + m_1_prev, + z_prev, + x_prev, + cycle_no, + num_recycling, + num_ensembles=1): + z, s = 0, 0 + n_seq = batch['msa_feat'].shape[-3] + assert num_ensembles >= 1 + for ensemble_no in range(num_ensembles): + idx = cycle_no * num_ensembles + ensemble_no + + # fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...] + def fetch_cur_batch(t): + return t[min(t.shape[0] - 1, idx), ...] 
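+            # Note: batched features carry the recycling/ensemble iterations in their
+            # leading dimension, so idx = cycle_no * num_ensembles + ensemble_no picks
+            # the slice for the current pass; min(t.shape[0] - 1, idx) clamps tensors
+            # with a shorter leading dimension (e.g. features that do not vary across
+            # cycles) to their last available slice instead of indexing out of range.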
+ + feats = tensor_tree_map(fetch_cur_batch, batch) + m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer( + feats, m_1_prev, z_prev, x_prev) + z += z0 + s += s0 + del z0, s0 + if num_ensembles > 1: + z /= float(num_ensembles) + s /= float(num_ensembles) + + outputs = {} + + outputs['msa'] = m[..., :n_seq, :, :] + outputs['pair'] = z + outputs['single'] = s + + # norm loss + if (not getattr(self, 'inference', + False)) and num_recycling == (cycle_no + 1): + delta_msa = m + delta_msa[..., + 0, :, :] = delta_msa[..., + 0, :, :] - m_1_prev_emb.detach() + delta_pair = z - z_prev_emb.detach() + outputs['delta_msa'] = delta_msa + outputs['delta_pair'] = delta_pair + outputs['msa_norm_mask'] = msa_mask + + outputs['sm'] = self.structure_module( + s, + z, + feats['aatype'], + mask=feats['seq_mask'], + ) + outputs['final_atom_positions'] = atom14_to_atom37( + outputs['sm']['positions'], feats) + outputs['final_atom_mask'] = feats['atom37_atom_exists'] + outputs['pred_frame_tensor'] = outputs['sm']['frames'][-1] + + # use float32 for numerical stability + if (not getattr(self, 'inference', False)): + m_1_prev = m[..., 0, :, :].float() + z_prev = z.float() + x_prev = outputs['final_atom_positions'].float() + else: + m_1_prev = m[..., 0, :, :] + z_prev = z + x_prev = outputs['final_atom_positions'] + + return outputs, m_1_prev, z_prev, x_prev + + def forward(self, batch): + + m_1_prev = batch.get('m_1_prev', None) + z_prev = batch.get('z_prev', None) + x_prev = batch.get('x_prev', None) + + is_grad_enabled = torch.is_grad_enabled() + + num_iters = int(batch['num_recycling_iters']) + 1 + num_ensembles = int(batch['msa_mask'].shape[0]) // num_iters + if self.training: + # don't use ensemble during training + assert num_ensembles == 1 + + # convert dtypes in batch + batch = self.__convert_input_dtype__(batch) + for cycle_no in range(num_iters): + is_final_iter = cycle_no == (num_iters - 1) + with torch.set_grad_enabled(is_grad_enabled and is_final_iter): + ( + outputs, + m_1_prev, + z_prev, + x_prev, + ) = self.iteration_evoformer_structure_module( + batch, + m_1_prev, + z_prev, + x_prev, + cycle_no=cycle_no, + num_recycling=num_iters, + num_ensembles=num_ensembles, + ) + if not is_final_iter: + del outputs + + if 'asym_id' in batch: + outputs['asym_id'] = batch['asym_id'][0, ...] + outputs.update(self.aux_heads(outputs)) + return outputs diff --git a/modelscope/models/science/unifold/modules/attentions.py b/modelscope/models/science/unifold/modules/attentions.py new file mode 100644 index 00000000..d2319079 --- /dev/null +++ b/modelscope/models/science/unifold/modules/attentions.py @@ -0,0 +1,430 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
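+# Note on the masking convention used below: masks are converted into additive
+# attention biases -- 0 where a position is kept and a large negative value where
+# it is masked -- and added to the pre-softmax logits inside softmax_dropout.
+# For example, gen_attn_mask(torch.tensor([1., 1., 0.]), -1e9) returns
+# tensor([0., 0., -1e9]); gen_msa_attn_mask and gen_tri_attn_mask further insert
+# singleton head/query dimensions so the bias broadcasts inside Attention.forward.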
+ +from functools import partialmethod +from typing import List, Optional + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm, softmax_dropout +from unicore.utils import permute_final_dims + +from .common import Linear, chunk_layer + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + + +class Attention(nn.Module): + + def __init__( + self, + q_dim: int, + k_dim: int, + v_dim: int, + head_dim: int, + num_heads: int, + gating: bool = True, + ): + super(Attention, self).__init__() + + self.num_heads = num_heads + total_dim = head_dim * self.num_heads + self.gating = gating + self.linear_q = Linear(q_dim, total_dim, bias=False, init='glorot') + self.linear_k = Linear(k_dim, total_dim, bias=False, init='glorot') + self.linear_v = Linear(v_dim, total_dim, bias=False, init='glorot') + self.linear_o = Linear(total_dim, q_dim, init='final') + self.linear_g = None + if self.gating: + self.linear_g = Linear(q_dim, total_dim, init='gating') + # precompute the 1/sqrt(head_dim) + self.norm = head_dim**-0.5 + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.Tensor = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + g = None + if self.linear_g is not None: + # gating, use raw query input + g = self.linear_g(q) + + q = self.linear_q(q) + q *= self.norm + k = self.linear_k(k) + v = self.linear_v(v) + + q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose( + -2, -3).contiguous() + k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose( + -2, -3).contiguous() + v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3) + + attn = torch.matmul(q, k.transpose(-1, -2)) + del q, k + + attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias) + o = torch.matmul(attn, v) + del attn, v + + o = o.transpose(-2, -3).contiguous() + o = o.view(*o.shape[:-2], -1) + + if g is not None: + o = torch.sigmoid(g) * o + + # merge heads + o = nn.functional.linear(o, self.linear_o.weight) + return o + + def get_output_bias(self): + return self.linear_o.bias + + +class GlobalAttention(nn.Module): + + def __init__(self, input_dim, head_dim, num_heads, inf, eps): + super(GlobalAttention, self).__init__() + + self.num_heads = num_heads + self.inf = inf + self.eps = eps + self.linear_q = Linear( + input_dim, head_dim * num_heads, bias=False, init='glorot') + self.linear_k = Linear(input_dim, head_dim, bias=False, init='glorot') + self.linear_v = Linear(input_dim, head_dim, bias=False, init='glorot') + self.linear_g = Linear(input_dim, head_dim * num_heads, init='gating') + self.linear_o = Linear(head_dim * num_heads, input_dim, init='final') + self.sigmoid = nn.Sigmoid() + # precompute the 1/sqrt(head_dim) + self.norm = head_dim**-0.5 + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + + # gating + g = self.sigmoid(self.linear_g(x)) + + k = self.linear_k(x) + v = self.linear_v(x) + + q = torch.sum( + x * mask.unsqueeze(-1), dim=-2) / ( + torch.sum(mask, dim=-1, keepdims=True) + self.eps) + q = self.linear_q(q) + q *= self.norm + q = q.view(q.shape[:-1] + (self.num_heads, -1)) + + attn = torch.matmul(q, k.transpose(-1, -2)) + del q, k + + attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :] + attn = softmax_dropout(attn, 0, self.training, mask=attn_mask) + + o = torch.matmul( + attn, + v, + ) + del attn, v + + g = g.view(g.shape[:-1] + (self.num_heads, -1)) + o = o.unsqueeze(-3) * g + del g + 
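+        # Note: the query above is a single mask-weighted average over the sequence
+        # dimension, so attention is computed once per column and the gated result is
+        # broadcast back to every sequence position; the reshape below just
+        # concatenates the heads before the output projection.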
+ # merge heads + o = o.reshape(o.shape[:-2] + (-1, )) + return self.linear_o(o) + + +def gen_msa_attn_mask(mask, inf, gen_col_mask=True): + row_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :] + if gen_col_mask: + col_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, + None, :] + return row_mask, col_mask + else: + return row_mask + + +class MSAAttention(nn.Module): + + def __init__( + self, + d_in, + d_hid, + num_heads, + pair_bias=False, + d_pair=None, + ): + super(MSAAttention, self).__init__() + + self.pair_bias = pair_bias + self.layer_norm_m = LayerNorm(d_in) + self.layer_norm_z = None + self.linear_z = None + if self.pair_bias: + self.layer_norm_z = LayerNorm(d_pair) + self.linear_z = Linear( + d_pair, num_heads, bias=False, init='normal') + + self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads) + + @torch.jit.ignore + def _chunk( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + chunk_size: int = None, + ) -> torch.Tensor: + + return chunk_layer( + self._attn_forward, + { + 'm': m, + 'mask': mask, + 'bias': bias + }, + chunk_size=chunk_size, + num_batch_dims=len(m.shape[:-2]), + ) + + @torch.jit.ignore + def _attn_chunk_forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = 2560, + ) -> torch.Tensor: + m = self.layer_norm_m(m) + num_chunk = (m.shape[-3] + chunk_size - 1) // chunk_size + outputs = [] + for i in range(num_chunk): + chunk_start = i * chunk_size + chunk_end = min(m.shape[-3], chunk_start + chunk_size) + cur_m = m[..., chunk_start:chunk_end, :, :] + cur_mask = ( + mask[..., chunk_start:chunk_end, :, :, :] + if mask is not None else None) + outputs.append( + self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)) + return torch.concat(outputs, dim=-3) + + def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None): + m = self.layer_norm_m(m) + return self.mha(q=m, k=m, v=m, mask=mask, bias=bias) + + def forward( + self, + m: torch.Tensor, + z: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + bias = None + if self.pair_bias: + z = self.layer_norm_z(z) + bias = ( + permute_final_dims(self.linear_z(z), + (2, 0, 1)).unsqueeze(-4).contiguous()) + + if chunk_size is not None: + m = self._chunk(m, attn_mask, bias, chunk_size) + else: + attn_chunk_size = 2560 + if m.shape[-3] <= attn_chunk_size: + m = self._attn_forward(m, attn_mask, bias) + else: + # reduce the peak memory cost in extra_msa_stack + return self._attn_chunk_forward( + m, attn_mask, bias, chunk_size=attn_chunk_size) + + return m + + def get_output_bias(self): + return self.mha.get_output_bias() + + +class MSARowAttentionWithPairBias(MSAAttention): + + def __init__(self, d_msa, d_pair, d_hid, num_heads): + super(MSARowAttentionWithPairBias, self).__init__( + d_msa, + d_hid, + num_heads, + pair_bias=True, + d_pair=d_pair, + ) + + +class MSAColumnAttention(MSAAttention): + + def __init__(self, d_msa, d_hid, num_heads): + super(MSAColumnAttention, self).__init__( + d_in=d_msa, + d_hid=d_hid, + num_heads=num_heads, + pair_bias=False, + d_pair=None, + ) + + def forward( + self, + m: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + m = m.transpose(-2, -3) + m = super().forward(m, attn_mask=attn_mask, chunk_size=chunk_size) + m = m.transpose(-2, -3) + + return m + + +class 
MSAColumnGlobalAttention(nn.Module): + + def __init__( + self, + d_in, + d_hid, + num_heads, + inf=1e9, + eps=1e-10, + ): + super(MSAColumnGlobalAttention, self).__init__() + + self.layer_norm_m = LayerNorm(d_in) + self.global_attention = GlobalAttention( + d_in, + d_hid, + num_heads, + inf=inf, + eps=eps, + ) + + @torch.jit.ignore + def _chunk( + self, + m: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._attn_forward, + { + 'm': m, + 'mask': mask + }, + chunk_size=chunk_size, + num_batch_dims=len(m.shape[:-2]), + ) + + def _attn_forward(self, m, mask): + m = self.layer_norm_m(m) + return self.global_attention(m, mask=mask) + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + m = m.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + if chunk_size is not None: + m = self._chunk(m, mask, chunk_size) + else: + m = self._attn_forward(m, mask=mask) + + m = m.transpose(-2, -3) + return m + + +def gen_tri_attn_mask(mask, inf): + start_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :] + end_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, + None, :] + return start_mask, end_mask + + +class TriangleAttention(nn.Module): + + def __init__( + self, + d_in, + d_hid, + num_heads, + starting, + ): + super(TriangleAttention, self).__init__() + self.starting = starting + self.layer_norm = LayerNorm(d_in) + self.linear = Linear(d_in, num_heads, bias=False, init='normal') + self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads) + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + chunk_size: int = None, + ) -> torch.Tensor: + return chunk_layer( + self.mha, + { + 'q': x, + 'k': x, + 'v': x, + 'mask': mask, + 'bias': bias + }, + chunk_size=chunk_size, + num_batch_dims=len(x.shape[:-2]), + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + if not self.starting: + x = x.transpose(-2, -3) + + x = self.layer_norm(x) + triangle_bias = ( + permute_final_dims(self.linear(x), + (2, 0, 1)).unsqueeze(-4).contiguous()) + + if chunk_size is not None: + x = self._chunk(x, attn_mask, triangle_bias, chunk_size) + else: + x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias) + + if not self.starting: + x = x.transpose(-2, -3) + return x + + def get_output_bias(self): + return self.mha.get_output_bias() + + +class TriangleAttentionStarting(TriangleAttention): + __init__ = partialmethod(TriangleAttention.__init__, starting=True) + + +class TriangleAttentionEnding(TriangleAttention): + __init__ = partialmethod(TriangleAttention.__init__, starting=False) diff --git a/modelscope/models/science/unifold/modules/auxillary_heads.py b/modelscope/models/science/unifold/modules/auxillary_heads.py new file mode 100644 index 00000000..2daf5d55 --- /dev/null +++ b/modelscope/models/science/unifold/modules/auxillary_heads.py @@ -0,0 +1,171 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
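+# Note: AuxiliaryHeads bundles the per-residue pLDDT head, the distogram head and
+# the masked-MSA head, plus optional experimentally-resolved and predicted aligned
+# error (PAE) heads. When the PAE head is enabled, pTM/ipTM confidences are derived
+# from its logits in confidence.py and combined as
+#     iptm+ptm = iptm_weight * iptm + (1 - iptm_weight) * ptm.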
+ +from typing import Dict + +import torch.nn as nn +from unicore.modules import LayerNorm + +from .common import Linear +from .confidence import (predicted_aligned_error, predicted_lddt, + predicted_tm_score) + + +class AuxiliaryHeads(nn.Module): + + def __init__(self, config): + super(AuxiliaryHeads, self).__init__() + + self.plddt = PredictedLDDTHead(**config['plddt'], ) + + self.distogram = DistogramHead(**config['distogram'], ) + + self.masked_msa = MaskedMSAHead(**config['masked_msa'], ) + + if config.experimentally_resolved.enabled: + self.experimentally_resolved = ExperimentallyResolvedHead( + **config['experimentally_resolved'], ) + + if config.pae.enabled: + self.pae = PredictedAlignedErrorHead(**config.pae, ) + + self.config = config + + def forward(self, outputs): + aux_out = {} + plddt_logits = self.plddt(outputs['sm']['single']) + aux_out['plddt_logits'] = plddt_logits + + aux_out['plddt'] = predicted_lddt(plddt_logits.detach()) + + distogram_logits = self.distogram(outputs['pair']) + aux_out['distogram_logits'] = distogram_logits + + masked_msa_logits = self.masked_msa(outputs['msa']) + aux_out['masked_msa_logits'] = masked_msa_logits + + if self.config.experimentally_resolved.enabled: + exp_res_logits = self.experimentally_resolved(outputs['single']) + aux_out['experimentally_resolved_logits'] = exp_res_logits + + if self.config.pae.enabled: + pae_logits = self.pae(outputs['pair']) + aux_out['pae_logits'] = pae_logits + pae_logits = pae_logits.detach() + aux_out.update( + predicted_aligned_error( + pae_logits, + **self.config.pae, + )) + aux_out['ptm'] = predicted_tm_score( + pae_logits, interface=False, **self.config.pae) + + iptm_weight = self.config.pae.get('iptm_weight', 0.0) + if iptm_weight > 0.0: + aux_out['iptm'] = predicted_tm_score( + pae_logits, + interface=True, + asym_id=outputs['asym_id'], + **self.config.pae, + ) + aux_out['iptm+ptm'] = ( + iptm_weight * aux_out['iptm'] + # noqa W504 + (1.0 - iptm_weight) * aux_out['ptm']) + + return aux_out + + +class PredictedLDDTHead(nn.Module): + + def __init__(self, num_bins, d_in, d_hid): + super(PredictedLDDTHead, self).__init__() + + self.num_bins = num_bins + self.d_in = d_in + self.d_hid = d_hid + + self.layer_norm = LayerNorm(self.d_in) + + self.linear_1 = Linear(self.d_in, self.d_hid, init='relu') + self.linear_2 = Linear(self.d_hid, self.d_hid, init='relu') + self.act = nn.GELU() + self.linear_3 = Linear(self.d_hid, self.num_bins, init='final') + + def forward(self, s): + s = self.layer_norm(s) + s = self.linear_1(s) + s = self.act(s) + s = self.linear_2(s) + s = self.act(s) + s = self.linear_3(s) + return s + + +class EnhancedHeadBase(nn.Module): + + def __init__(self, d_in, d_out, disable_enhance_head): + super(EnhancedHeadBase, self).__init__() + if disable_enhance_head: + self.layer_norm = None + self.linear_in = None + else: + self.layer_norm = LayerNorm(d_in) + self.linear_in = Linear(d_in, d_in, init='relu') + self.act = nn.GELU() + self.linear = Linear(d_in, d_out, init='final') + + def apply_alphafold_original_mode(self): + self.layer_norm = None + self.linear_in = None + + def forward(self, x): + if self.layer_norm is not None: + x = self.layer_norm(x) + x = self.act(self.linear_in(x)) + logits = self.linear(x) + return logits + + +class DistogramHead(EnhancedHeadBase): + + def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs): + super(DistogramHead, self).__init__( + d_in=d_pair, + d_out=num_bins, + disable_enhance_head=disable_enhance_head, + ) + + def forward(self, x): + logits = 
super().forward(x) + logits = logits + logits.transpose(-2, -3) + return logits + + +class PredictedAlignedErrorHead(EnhancedHeadBase): + + def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs): + super(PredictedAlignedErrorHead, self).__init__( + d_in=d_pair, + d_out=num_bins, + disable_enhance_head=disable_enhance_head, + ) + + +class MaskedMSAHead(EnhancedHeadBase): + + def __init__(self, d_msa, d_out, disable_enhance_head, **kwargs): + super(MaskedMSAHead, self).__init__( + d_in=d_msa, + d_out=d_out, + disable_enhance_head=disable_enhance_head, + ) + + +class ExperimentallyResolvedHead(EnhancedHeadBase): + + def __init__(self, d_single, d_out, disable_enhance_head, **kwargs): + super(ExperimentallyResolvedHead, self).__init__( + d_in=d_single, + d_out=d_out, + disable_enhance_head=disable_enhance_head, + ) diff --git a/modelscope/models/science/unifold/modules/common.py b/modelscope/models/science/unifold/modules/common.py new file mode 100644 index 00000000..186f2567 --- /dev/null +++ b/modelscope/models/science/unifold/modules/common.py @@ -0,0 +1,387 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from unicore.modules import LayerNorm +from unicore.utils import tensor_tree_map + + +class Linear(nn.Linear): + + def __init__( + self, + d_in: int, + d_out: int, + bias: bool = True, + init: str = 'default', + ): + super(Linear, self).__init__(d_in, d_out, bias=bias) + + self.use_bias = bias + + if self.use_bias: + with torch.no_grad(): + self.bias.fill_(0) + + if init == 'default': + self._trunc_normal_init(1.0) + elif init == 'relu': + self._trunc_normal_init(2.0) + elif init == 'glorot': + self._glorot_uniform_init() + elif init == 'gating': + self._zero_init(self.use_bias) + elif init == 'normal': + self._normal_init() + elif init == 'final': + self._zero_init(False) + else: + raise ValueError('Invalid init method.') + + def _trunc_normal_init(self, scale=1.0): + # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) 
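+        # A standard normal truncated at +/- 2 sigma has a standard deviation of only
+        # ~0.8796, so dividing by this factor makes the realized (post-truncation)
+        # std of trunc_normal_ equal to sqrt(scale / fan_in), i.e. a fan-in scaled
+        # truncated-normal initialization.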
+ TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978 + _, fan_in = self.weight.shape + scale = scale / max(1, fan_in) + std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR + nn.init.trunc_normal_(self.weight, mean=0.0, std=std) + + def _glorot_uniform_init(self): + nn.init.xavier_uniform_(self.weight, gain=1) + + def _zero_init(self, use_bias=True): + with torch.no_grad(): + self.weight.fill_(0.0) + if use_bias: + with torch.no_grad(): + self.bias.fill_(1.0) + + def _normal_init(self): + torch.nn.init.kaiming_normal_(self.weight, nonlinearity='linear') + + +class Transition(nn.Module): + + def __init__(self, d_in, n): + + super(Transition, self).__init__() + + self.d_in = d_in + self.n = n + + self.layer_norm = LayerNorm(self.d_in) + self.linear_1 = Linear(self.d_in, self.n * self.d_in, init='relu') + self.act = nn.GELU() + self.linear_2 = Linear(self.n * self.d_in, d_in, init='final') + + def _transition(self, x): + x = self.layer_norm(x) + x = self.linear_1(x) + x = self.act(x) + x = self.linear_2(x) + return x + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._transition, + {'x': x}, + chunk_size=chunk_size, + num_batch_dims=len(x.shape[:-2]), + ) + + def forward( + self, + x: torch.Tensor, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + if chunk_size is not None: + x = self._chunk(x, chunk_size) + else: + x = self._transition(x=x) + + return x + + +class OuterProductMean(nn.Module): + + def __init__(self, d_msa, d_pair, d_hid, eps=1e-3): + super(OuterProductMean, self).__init__() + + self.d_msa = d_msa + self.d_pair = d_pair + self.d_hid = d_hid + self.eps = eps + + self.layer_norm = LayerNorm(d_msa) + self.linear_1 = Linear(d_msa, d_hid) + self.linear_2 = Linear(d_msa, d_hid) + self.linear_out = Linear(d_hid**2, d_pair, init='relu') + self.act = nn.GELU() + self.linear_z = Linear(self.d_pair, self.d_pair, init='final') + self.layer_norm_out = LayerNorm(self.d_pair) + + def _opm(self, a, b): + outer = torch.einsum('...bac,...dae->...bdce', a, b) + outer = outer.reshape(outer.shape[:-2] + (-1, )) + outer = self.linear_out(outer) + return outer + + @torch.jit.ignore + def _chunk(self, a: torch.Tensor, b: torch.Tensor, + chunk_size: int) -> torch.Tensor: + a = a.reshape((-1, ) + a.shape[-3:]) + b = b.reshape((-1, ) + b.shape[-3:]) + out = [] + # TODO: optimize this + for a_prime, b_prime in zip(a, b): + outer = chunk_layer( + partial(self._opm, b=b_prime), + {'a': a_prime}, + chunk_size=chunk_size, + num_batch_dims=1, + ) + out.append(outer) + if len(out) == 1: + outer = out[0].unsqueeze(0) + else: + outer = torch.stack(out, dim=0) + outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) + + return outer + + def apply_alphafold_original_mode(self): + self.linear_z = None + self.layer_norm_out = None + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + m = self.layer_norm(m) + mask = mask.unsqueeze(-1) + if self.layer_norm_out is not None: + # for numerical stability + mask = mask * (mask.size(-2)**-0.5) + a = self.linear_1(m) + b = self.linear_2(m) + if self.training: + a = a * mask + b = b * mask + else: + a *= mask + b *= mask + + a = a.transpose(-2, -3) + b = b.transpose(-2, -3) + + if chunk_size is not None: + z = self._chunk(a, b, chunk_size) + else: + z = self._opm(a, b) + + norm = torch.einsum('...abc,...adc->...bdc', mask, mask) + z /= self.eps + norm + if self.layer_norm_out is not None: + z = 
self.act(z) + z = self.layer_norm_out(z) + z = self.linear_z(z) + return z + + +def residual(residual, x, training): + if training: + return x + residual + else: + residual += x + return residual + + +@torch.jit.script +def fused_bias_dropout_add( + x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + dropmask: torch.Tensor, + prob: float, +) -> torch.Tensor: + return (x + bias) * F.dropout(dropmask, p=prob, training=True) + residual + + +@torch.jit.script +def fused_bias_dropout_add_inference( + x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, +) -> torch.Tensor: + residual += bias + x + return residual + + +def bias_dropout_residual(module, residual, x, dropout_shared_dim, prob, + training): + bias = module.get_output_bias() + if training: + shape = list(x.shape) + shape[dropout_shared_dim] = 1 + with torch.no_grad(): + mask = x.new_ones(shape) + return fused_bias_dropout_add(x, bias, residual, mask, prob) + else: + return fused_bias_dropout_add_inference(x, bias, residual) + + +@torch.jit.script +def fused_bias_gated_dropout_add( + x: torch.Tensor, + bias: torch.Tensor, + g: torch.Tensor, + g_bias: torch.Tensor, + residual: torch.Tensor, + dropout_mask: torch.Tensor, + prob: float, +) -> torch.Tensor: + return (torch.sigmoid(g + g_bias) * (x + bias)) * F.dropout( + dropout_mask, + p=prob, + training=True, + ) + residual + + +def tri_mul_residual( + module, + residual, + outputs, + dropout_shared_dim, + prob, + training, + block_size, +): + if training: + x, g = outputs + bias, g_bias = module.get_output_bias() + shape = list(x.shape) + shape[dropout_shared_dim] = 1 + with torch.no_grad(): + mask = x.new_ones(shape) + return fused_bias_gated_dropout_add( + x, + bias, + g, + g_bias, + residual, + mask, + prob, + ) + elif block_size is None: + x, g = outputs + bias, g_bias = module.get_output_bias() + residual += (torch.sigmoid(g + g_bias) * (x + bias)) + return residual + else: + # gated is not used here + residual += outputs + return residual + + +class SimpleModuleList(nn.ModuleList): + + def __repr__(self): + return str(len(self)) + ' X ...\n' + self[0].__repr__() + + +def chunk_layer( + layer: Callable, + inputs: Dict[str, Any], + chunk_size: int, + num_batch_dims: int, +) -> Any: + # TODO: support inplace add to output + if not (len(inputs) > 0): + raise ValueError('Must provide at least one input') + + def _dict_get_shapes(input): + shapes = [] + if type(input) is torch.Tensor: + shapes.append(input.shape) + elif type(input) is dict: + for v in input.values(): + shapes.extend(_dict_get_shapes(v)) + elif isinstance(input, Iterable): + for v in input: + shapes.extend(_dict_get_shapes(v)) + else: + raise ValueError('Not supported') + + return shapes + + inputs = {k: v for k, v in inputs.items() if v is not None} + initial_dims = [ + shape[:num_batch_dims] for shape in _dict_get_shapes(inputs) + ] + orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) + + flat_batch_dim = 1 + for d in orig_batch_dims: + flat_batch_dim *= d + num_chunks = (flat_batch_dim + chunk_size - 1) // chunk_size + + def _flat_inputs(t): + t = t.view(-1, *t.shape[num_batch_dims:]) + assert ( + t.shape[0] == flat_batch_dim or t.shape[0] == 1 + ), 'batch dimension must be 1 or equal to the flat batch dimension' + return t + + flat_inputs = tensor_tree_map(_flat_inputs, inputs) + + out = None + for i in range(num_chunks): + chunk_start = i * chunk_size + chunk_end = min((i + 1) * chunk_size, flat_batch_dim) + + def select_chunk(t): + if t.shape[0] == 1: + return t[0:1] + else: + 
return t[chunk_start:chunk_end] + + chunkes = tensor_tree_map(select_chunk, flat_inputs) + + output_chunk = layer(**chunkes) + + if out is None: + out = tensor_tree_map( + lambda t: t.new_zeros((flat_batch_dim, ) + t.shape[1:]), + output_chunk) + + out_type = type(output_chunk) + if out_type is tuple: + for x, y in zip(out, output_chunk): + x[chunk_start:chunk_end] = y + elif out_type is torch.Tensor: + out[chunk_start:chunk_end] = output_chunk + else: + raise ValueError('Not supported') + + # reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) + def reshape(t): + return t.view(orig_batch_dims + t.shape[1:]) + + out = tensor_tree_map(reshape, out) + + return out diff --git a/modelscope/models/science/unifold/modules/confidence.py b/modelscope/models/science/unifold/modules/confidence.py new file mode 100644 index 00000000..7574689c --- /dev/null +++ b/modelscope/models/science/unifold/modules/confidence.py @@ -0,0 +1,159 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from typing import Dict, Optional, Tuple + +import torch + + +def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor: + """Computes per-residue pLDDT from logits. + Args: + logits: [num_res, num_bins] output from the PredictedLDDTHead. + Returns: + plddt: [num_res] per-residue pLDDT. + """ + num_bins = plddt_logits.shape[-1] + bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1) + bin_width = 1.0 / num_bins + bounds = torch.arange( + start=0.5 * bin_width, + end=1.0, + step=bin_width, + device=plddt_logits.device) + plddt = torch.sum( + bin_probs + * bounds.view(*((1, ) * len(bin_probs.shape[:-1])), *bounds.shape), + dim=-1, + ) + return plddt + + +def compute_bin_values(breaks: torch.Tensor): + """Gets the bin centers from the bin edges. + Args: + breaks: [num_bins - 1] the error bin edges. + Returns: + bin_centers: [num_bins] the error bin centers. + """ + step = breaks[1] - breaks[0] + bin_values = breaks + step / 2 + bin_values = torch.cat([bin_values, (bin_values[-1] + step).unsqueeze(-1)], + dim=0) + return bin_values + + +def compute_predicted_aligned_error( + bin_edges: torch.Tensor, + bin_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculates expected aligned distance errors for every pair of residues. + Args: + alignment_confidence_breaks: [num_bins - 1] the error bin edges. + aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted + probs for each error bin, for each pair of residues. + Returns: + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + bin_values = compute_bin_values(bin_edges) + return torch.sum(bin_probs * bin_values, dim=-1) + + +def predicted_aligned_error( + pae_logits: torch.Tensor, + max_bin: int = 31, + num_bins: int = 64, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes aligned confidence metrics from logits. + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins - 1] the error bin edges. + Returns: + aligned_confidence_probs: [num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. 
+ max_predicted_aligned_error: The maximum predicted error possible. + """ + bin_probs = torch.nn.functional.softmax(pae_logits.float(), dim=-1) + bin_edges = torch.linspace( + 0, max_bin, steps=(num_bins - 1), device=pae_logits.device) + + predicted_aligned_error = compute_predicted_aligned_error( + bin_edges=bin_edges, + bin_probs=bin_probs, + ) + + return { + 'aligned_error_probs_per_bin': bin_probs, + 'predicted_aligned_error': predicted_aligned_error, + } + + +def predicted_tm_score( + pae_logits: torch.Tensor, + residue_weights: Optional[torch.Tensor] = None, + max_bin: int = 31, + num_bins: int = 64, + eps: float = 1e-8, + asym_id: Optional[torch.Tensor] = None, + interface: bool = False, + **kwargs, +) -> torch.Tensor: + """Computes predicted TM alignment or predicted interface TM alignment score. + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins] the error bins. + residue_weights: [num_res] the per residue weights to use for the + expectation. + asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for + ipTM calculation, i.e. when interface=True. + interface: If True, interface predicted TM score is computed. + Returns: + ptm_score: The predicted TM alignment or the predicted iTM score. + """ + pae_logits = pae_logits.float() + if residue_weights is None: + residue_weights = pae_logits.new_ones(pae_logits.shape[:-2]) + + breaks = torch.linspace( + 0, max_bin, steps=(num_bins - 1), device=pae_logits.device) + + def tm_kernal(nres): + clipped_n = max(nres, 19) + d0 = 1.24 * (clipped_n - 15)**(1.0 / 3.0) - 1.8 + return lambda x: 1.0 / (1.0 + (x / d0)**2) + + def rmsd_kernal(eps): # leave for compute pRMS + return lambda x: 1. / (x + eps) + + bin_centers = compute_bin_values(breaks) + probs = torch.nn.functional.softmax(pae_logits, dim=-1) + tm_per_bin = tm_kernal(nres=pae_logits.shape[-2])(bin_centers) + # tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2)) + # rmsd_per_bin = rmsd_kernal()(bin_centers) + predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) + + pair_mask = predicted_tm_term.new_ones(predicted_tm_term.shape) + if interface: + assert asym_id is not None, 'must provide asym_id for iptm calculation.' + pair_mask *= asym_id[..., :, None] != asym_id[..., None, :] + + predicted_tm_term *= pair_mask + + pair_residue_weights = pair_mask * ( + residue_weights[None, :] * residue_weights[:, None]) + normed_residue_mask = pair_residue_weights / ( + eps + pair_residue_weights.sum(dim=-1, keepdim=True)) + + per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) + weighted = per_alignment * residue_weights + ret = per_alignment.gather( + dim=-1, index=weighted.max(dim=-1, + keepdim=True).indices).squeeze(dim=-1) + return ret diff --git a/modelscope/models/science/unifold/modules/embedders.py b/modelscope/models/science/unifold/modules/embedders.py new file mode 100644 index 00000000..84e87e2d --- /dev/null +++ b/modelscope/models/science/unifold/modules/embedders.py @@ -0,0 +1,290 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
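+# Note: InputEmbedder encodes relative positions by clipping residue-index offsets
+# to [-relpos_k, relpos_k] and one-hot encoding them (2 * relpos_k + 1 bins for
+# monomers); in multimer mode additional bins mark pairs from different chains, a
+# same-entity flag and clipped relative chain (sym_id) offsets, as computed in
+# InputEmbedder._relpos_indices before the linear_relpos projection.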
+ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm +from unicore.utils import one_hot + +from .common import Linear, SimpleModuleList, residual + + +class InputEmbedder(nn.Module): + + def __init__( + self, + tf_dim: int, + msa_dim: int, + d_pair: int, + d_msa: int, + relpos_k: int, + use_chain_relative: bool = False, + max_relative_chain: Optional[int] = None, + **kwargs, + ): + super(InputEmbedder, self).__init__() + + self.tf_dim = tf_dim + self.msa_dim = msa_dim + + self.d_pair = d_pair + self.d_msa = d_msa + + self.linear_tf_z_i = Linear(tf_dim, d_pair) + self.linear_tf_z_j = Linear(tf_dim, d_pair) + self.linear_tf_m = Linear(tf_dim, d_msa) + self.linear_msa_m = Linear(msa_dim, d_msa) + + # RPE stuff + self.relpos_k = relpos_k + self.use_chain_relative = use_chain_relative + self.max_relative_chain = max_relative_chain + if not self.use_chain_relative: + self.num_bins = 2 * self.relpos_k + 1 + else: + self.num_bins = 2 * self.relpos_k + 2 + self.num_bins += 1 # entity id + self.num_bins += 2 * max_relative_chain + 2 + + self.linear_relpos = Linear(self.num_bins, d_pair) + + def _relpos_indices( + self, + res_id: torch.Tensor, + sym_id: Optional[torch.Tensor] = None, + asym_id: Optional[torch.Tensor] = None, + entity_id: Optional[torch.Tensor] = None, + ): + + max_rel_res = self.relpos_k + rp = res_id[..., None] - res_id[..., None, :] + rp = rp.clip(-max_rel_res, max_rel_res) + max_rel_res + if not self.use_chain_relative: + return rp + else: + asym_id_same = asym_id[..., :, None] == asym_id[..., None, :] + rp[~asym_id_same] = 2 * max_rel_res + 1 + entity_id_same = entity_id[..., :, None] == entity_id[..., None, :] + rp_entity_id = entity_id_same.type(rp.dtype)[..., None] + + rel_sym_id = sym_id[..., :, None] - sym_id[..., None, :] + + max_rel_chain = self.max_relative_chain + + clipped_rel_chain = torch.clamp( + rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain) + + clipped_rel_chain[~entity_id_same] = 2 * max_rel_chain + 1 + return rp, rp_entity_id, clipped_rel_chain + + def relpos_emb( + self, + res_id: torch.Tensor, + sym_id: Optional[torch.Tensor] = None, + asym_id: Optional[torch.Tensor] = None, + entity_id: Optional[torch.Tensor] = None, + num_sym: Optional[torch.Tensor] = None, + ): + + dtype = self.linear_relpos.weight.dtype + if not self.use_chain_relative: + rp = self._relpos_indices(res_id=res_id) + return self.linear_relpos( + one_hot(rp, num_classes=self.num_bins, dtype=dtype)) + else: + rp, rp_entity_id, rp_rel_chain = self._relpos_indices( + res_id=res_id, + sym_id=sym_id, + asym_id=asym_id, + entity_id=entity_id) + rp = one_hot(rp, num_classes=(2 * self.relpos_k + 2), dtype=dtype) + rp_entity_id = rp_entity_id.type(dtype) + rp_rel_chain = one_hot( + rp_rel_chain, + num_classes=(2 * self.max_relative_chain + 2), + dtype=dtype) + return self.linear_relpos( + torch.cat([rp, rp_entity_id, rp_rel_chain], dim=-1)) + + def forward( + self, + tf: torch.Tensor, + msa: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # [*, N_res, d_pair] + if self.tf_dim == 21: + # multimer use 21 target dim + tf = tf[..., 1:] + # convert type if necessary + tf = tf.type(self.linear_tf_z_i.weight.dtype) + msa = msa.type(self.linear_tf_z_i.weight.dtype) + n_clust = msa.shape[-3] + + msa_emb = self.linear_msa_m(msa) + # target_feat (aatype) into msa representation + tf_m = ( + self.linear_tf_m(tf).unsqueeze(-3).expand( + ((-1, ) * len(tf.shape[:-2]) + # noqa W504 + (n_clust, -1, -1)))) + msa_emb += tf_m + + tf_emb_i = 
self.linear_tf_z_i(tf) + tf_emb_j = self.linear_tf_z_j(tf) + pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] + + return msa_emb, pair_emb + + +class RecyclingEmbedder(nn.Module): + + def __init__( + self, + d_msa: int, + d_pair: int, + min_bin: float, + max_bin: float, + num_bins: int, + inf: float = 1e8, + **kwargs, + ): + super(RecyclingEmbedder, self).__init__() + + self.d_msa = d_msa + self.d_pair = d_pair + self.min_bin = min_bin + self.max_bin = max_bin + self.num_bins = num_bins + self.inf = inf + + self.squared_bins = None + + self.linear = Linear(self.num_bins, self.d_pair) + self.layer_norm_m = LayerNorm(self.d_msa) + self.layer_norm_z = LayerNorm(self.d_pair) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + m_update = self.layer_norm_m(m) + z_update = self.layer_norm_z(z) + + return m_update, z_update + + def recyle_pos( + self, + x: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if self.squared_bins is None: + bins = torch.linspace( + self.min_bin, + self.max_bin, + self.num_bins, + dtype=torch.float if self.training else x.dtype, + device=x.device, + requires_grad=False, + ) + self.squared_bins = bins**2 + upper = torch.cat( + [self.squared_bins[1:], + self.squared_bins.new_tensor([self.inf])], + dim=-1) + if self.training: + x = x.float() + d = torch.sum( + (x[..., None, :] - x[..., None, :, :])**2, dim=-1, keepdims=True) + d = ((d > self.squared_bins) * # noqa W504 + (d < upper)).type(self.linear.weight.dtype) + d = self.linear(d) + return d + + +class TemplateAngleEmbedder(nn.Module): + + def __init__( + self, + d_in: int, + d_out: int, + **kwargs, + ): + super(TemplateAngleEmbedder, self).__init__() + + self.d_out = d_out + self.d_in = d_in + + self.linear_1 = Linear(self.d_in, self.d_out, init='relu') + self.act = nn.GELU() + self.linear_2 = Linear(self.d_out, self.d_out, init='relu') + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x.type(self.linear_1.weight.dtype)) + x = self.act(x) + x = self.linear_2(x) + return x + + +class TemplatePairEmbedder(nn.Module): + + def __init__( + self, + d_in: int, + v2_d_in: list, + d_out: int, + d_pair: int, + v2_feature: bool = False, + **kwargs, + ): + super(TemplatePairEmbedder, self).__init__() + + self.d_out = d_out + self.v2_feature = v2_feature + if self.v2_feature: + self.d_in = v2_d_in + self.linear = SimpleModuleList() + for d_in in self.d_in: + self.linear.append(Linear(d_in, self.d_out, init='relu')) + self.z_layer_norm = LayerNorm(d_pair) + self.z_linear = Linear(d_pair, self.d_out, init='relu') + else: + self.d_in = d_in + self.linear = Linear(self.d_in, self.d_out, init='relu') + + def forward( + self, + x, + z, + ) -> torch.Tensor: + if not self.v2_feature: + x = self.linear(x.type(self.linear.weight.dtype)) + return x + else: + dtype = self.z_linear.weight.dtype + t = self.linear[0](x[0].type(dtype)) + for i, s in enumerate(x[1:]): + t = residual(t, self.linear[i + 1](s.type(dtype)), + self.training) + t = residual(t, self.z_linear(self.z_layer_norm(z)), self.training) + return t + + +class ExtraMSAEmbedder(nn.Module): + + def __init__( + self, + d_in: int, + d_out: int, + **kwargs, + ): + super(ExtraMSAEmbedder, self).__init__() + + self.d_in = d_in + self.d_out = d_out + self.linear = Linear(self.d_in, self.d_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x.type(self.linear.weight.dtype)) diff --git a/modelscope/models/science/unifold/modules/evoformer.py 
b/modelscope/models/science/unifold/modules/evoformer.py new file mode 100644 index 00000000..b0834986 --- /dev/null +++ b/modelscope/models/science/unifold/modules/evoformer.py @@ -0,0 +1,362 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from unicore.utils import checkpoint_sequential + +from .attentions import (MSAColumnAttention, MSAColumnGlobalAttention, + MSARowAttentionWithPairBias, TriangleAttentionEnding, + TriangleAttentionStarting) +from .common import (Linear, OuterProductMean, SimpleModuleList, Transition, + bias_dropout_residual, residual, tri_mul_residual) +from .triangle_multiplication import (TriangleMultiplicationIncoming, + TriangleMultiplicationOutgoing) + + +class EvoformerIteration(nn.Module): + + def __init__( + self, + d_msa: int, + d_pair: int, + d_hid_msa_att: int, + d_hid_opm: int, + d_hid_mul: int, + d_hid_pair_att: int, + num_heads_msa: int, + num_heads_pair: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + outer_product_mean_first: bool, + inf: float, + eps: float, + _is_extra_msa_stack: bool = False, + ): + super(EvoformerIteration, self).__init__() + + self._is_extra_msa_stack = _is_extra_msa_stack + self.outer_product_mean_first = outer_product_mean_first + + self.msa_att_row = MSARowAttentionWithPairBias( + d_msa=d_msa, + d_pair=d_pair, + d_hid=d_hid_msa_att, + num_heads=num_heads_msa, + ) + + if _is_extra_msa_stack: + self.msa_att_col = MSAColumnGlobalAttention( + d_in=d_msa, + d_hid=d_hid_msa_att, + num_heads=num_heads_msa, + inf=inf, + eps=eps, + ) + else: + self.msa_att_col = MSAColumnAttention( + d_msa, + d_hid_msa_att, + num_heads_msa, + ) + + self.msa_transition = Transition( + d_in=d_msa, + n=transition_n, + ) + + self.outer_product_mean = OuterProductMean( + d_msa, + d_pair, + d_hid_opm, + ) + + self.tri_mul_out = TriangleMultiplicationOutgoing( + d_pair, + d_hid_mul, + ) + self.tri_mul_in = TriangleMultiplicationIncoming( + d_pair, + d_hid_mul, + ) + + self.tri_att_start = TriangleAttentionStarting( + d_pair, + d_hid_pair_att, + num_heads_pair, + ) + self.tri_att_end = TriangleAttentionEnding( + d_pair, + d_hid_pair_att, + num_heads_pair, + ) + + self.pair_transition = Transition( + d_in=d_pair, + n=transition_n, + ) + + self.row_dropout_share_dim = -3 + self.col_dropout_share_dim = -2 + self.msa_dropout = msa_dropout + self.pair_dropout = pair_dropout + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + msa_row_attn_mask: torch.Tensor, + msa_col_attn_mask: Optional[torch.Tensor], + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + chunk_size: Optional[int] = None, + block_size: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if self.outer_product_mean_first: + z = residual( + z, + self.outer_product_mean( + m, mask=msa_mask, chunk_size=chunk_size), self.training) + + m = bias_dropout_residual( + self.msa_att_row, + m, + self.msa_att_row( + m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.msa_dropout, + self.training, + ) + if self._is_extra_msa_stack: + m = residual( + m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size), + self.training) + else: + m = bias_dropout_residual( + self.msa_att_col, + m, + self.msa_att_col( + m, 
attn_mask=msa_col_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.msa_dropout, + self.training, + ) + m = residual(m, self.msa_transition(m, chunk_size=chunk_size), + self.training) + if not self.outer_product_mean_first: + z = residual( + z, + self.outer_product_mean( + m, mask=msa_mask, chunk_size=chunk_size), self.training) + + z = tri_mul_residual( + self.tri_mul_out, + z, + self.tri_mul_out(z, mask=pair_mask, block_size=block_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + block_size=block_size, + ) + + z = tri_mul_residual( + self.tri_mul_in, + z, + self.tri_mul_in(z, mask=pair_mask, block_size=block_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + block_size=block_size, + ) + + z = bias_dropout_residual( + self.tri_att_start, + z, + self.tri_att_start( + z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + ) + + z = bias_dropout_residual( + self.tri_att_end, + z, + self.tri_att_end( + z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.pair_dropout, + self.training, + ) + z = residual(z, self.pair_transition(z, chunk_size=chunk_size), + self.training) + return m, z + + +class EvoformerStack(nn.Module): + + def __init__( + self, + d_msa: int, + d_pair: int, + d_hid_msa_att: int, + d_hid_opm: int, + d_hid_mul: int, + d_hid_pair_att: int, + d_single: int, + num_heads_msa: int, + num_heads_pair: int, + num_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + outer_product_mean_first: bool, + inf: float, + eps: float, + _is_extra_msa_stack: bool = False, + **kwargs, + ): + super(EvoformerStack, self).__init__() + + self._is_extra_msa_stack = _is_extra_msa_stack + + self.blocks = SimpleModuleList() + + for _ in range(num_blocks): + self.blocks.append( + EvoformerIteration( + d_msa=d_msa, + d_pair=d_pair, + d_hid_msa_att=d_hid_msa_att, + d_hid_opm=d_hid_opm, + d_hid_mul=d_hid_mul, + d_hid_pair_att=d_hid_pair_att, + num_heads_msa=num_heads_msa, + num_heads_pair=num_heads_pair, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + outer_product_mean_first=outer_product_mean_first, + inf=inf, + eps=eps, + _is_extra_msa_stack=_is_extra_msa_stack, + )) + if not self._is_extra_msa_stack: + self.linear = Linear(d_msa, d_single) + else: + self.linear = None + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + msa_row_attn_mask: torch.Tensor, + msa_col_attn_mask: torch.Tensor, + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + chunk_size: int, + block_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + blocks = [ + partial( + b, + msa_mask=msa_mask, + pair_mask=pair_mask, + msa_row_attn_mask=msa_row_attn_mask, + msa_col_attn_mask=msa_col_attn_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=chunk_size, + block_size=block_size) for b in self.blocks + ] + + m, z = checkpoint_sequential( + blocks, + input=(m, z), + ) + + s = None + if not self._is_extra_msa_stack: + seq_dim = -3 + index = torch.tensor([0], device=m.device) + s = self.linear(torch.index_select(m, dim=seq_dim, index=index)) + s = s.squeeze(seq_dim) + + return m, z, s + + +class ExtraMSAStack(EvoformerStack): + + def __init__( + self, + d_msa: int, + d_pair: int, + d_hid_msa_att: int, + d_hid_opm: int, + d_hid_mul: int, 
+ d_hid_pair_att: int, + num_heads_msa: int, + num_heads_pair: int, + num_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + outer_product_mean_first: bool, + inf: float, + eps: float, + **kwargs, + ): + super(ExtraMSAStack, self).__init__( + d_msa=d_msa, + d_pair=d_pair, + d_hid_msa_att=d_hid_msa_att, + d_hid_opm=d_hid_opm, + d_hid_mul=d_hid_mul, + d_hid_pair_att=d_hid_pair_att, + d_single=None, + num_heads_msa=num_heads_msa, + num_heads_pair=num_heads_pair, + num_blocks=num_blocks, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + outer_product_mean_first=outer_product_mean_first, + inf=inf, + eps=eps, + _is_extra_msa_stack=True, + ) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: Optional[torch.Tensor] = None, + pair_mask: Optional[torch.Tensor] = None, + msa_row_attn_mask: torch.Tensor = None, + msa_col_attn_mask: torch.Tensor = None, + tri_start_attn_mask: torch.Tensor = None, + tri_end_attn_mask: torch.Tensor = None, + chunk_size: int = None, + block_size: int = None, + ) -> torch.Tensor: + _, z, _ = super().forward( + m, + z, + msa_mask=msa_mask, + pair_mask=pair_mask, + msa_row_attn_mask=msa_row_attn_mask, + msa_col_attn_mask=msa_col_attn_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=chunk_size, + block_size=block_size) + return z diff --git a/modelscope/models/science/unifold/modules/featurization.py b/modelscope/models/science/unifold/modules/featurization.py new file mode 100644 index 00000000..b62adc9d --- /dev/null +++ b/modelscope/models/science/unifold/modules/featurization.py @@ -0,0 +1,195 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
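# A minimal reference sketch of how EvoformerStack (evoformer.py above) applies its
# blocks: each EvoformerIteration is wrapped in functools.partial so that only the
# recurrent tensors (m, z) flow through unicore's checkpoint_sequential. Conceptually
# that call reduces to the plain loop below; gradient checkpointing only changes how
# activations are stored for the backward pass, not the result. This helper is an
# illustration only, not part of the Uni-Fold sources.
from typing import Callable, Sequence, Tuple

import torch


def apply_evoformer_blocks(
    blocks: Sequence[Callable[[torch.Tensor, torch.Tensor],
                              Tuple[torch.Tensor, torch.Tensor]]],
    m: torch.Tensor,
    z: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Reference semantics of checkpoint_sequential(blocks, input=(m, z)).
    for block in blocks:
        m, z = block(m, z)
    return m, z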
+ +from typing import Dict + +import torch +import torch.nn as nn +from unicore.utils import batched_gather, one_hot + +from modelscope.models.science.unifold.data import residue_constants as rc +from .frame import Frame + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + is_gly = aatype == rc.restype_order['G'] + ca_idx = rc.atom_order['CA'] + cb_idx = rc.atom_order['CB'] + pseudo_beta = torch.where( + is_gly[..., None].expand(*((-1, ) * len(is_gly.shape)), 3), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + + if all_atom_masks is not None: + pseudo_beta_mask = torch.where( + is_gly, + all_atom_masks[..., ca_idx], + all_atom_masks[..., cb_idx], + ) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +def atom14_to_atom37(atom14, batch): + atom37_data = batched_gather( + atom14, + batch['residx_atom37_to_atom14'], + dim=-2, + num_batch_dims=len(atom14.shape[:-2]), + ) + + atom37_data = atom37_data * batch['atom37_atom_exists'][..., None] + + return atom37_data + + +def build_template_angle_feat(template_feats, v2_feature=False): + template_aatype = template_feats['template_aatype'] + torsion_angles_sin_cos = template_feats['template_torsion_angles_sin_cos'] + torsion_angles_mask = template_feats['template_torsion_angles_mask'] + if not v2_feature: + alt_torsion_angles_sin_cos = template_feats[ + 'template_alt_torsion_angles_sin_cos'] + template_angle_feat = torch.cat( + [ + one_hot(template_aatype, 22), + torsion_angles_sin_cos.reshape( + *torsion_angles_sin_cos.shape[:-2], 14), + alt_torsion_angles_sin_cos.reshape( + *alt_torsion_angles_sin_cos.shape[:-2], 14), + torsion_angles_mask, + ], + dim=-1, + ) + template_angle_mask = torsion_angles_mask[..., 2] + else: + chi_mask = torsion_angles_mask[..., 3:] + chi_angles_sin = torsion_angles_sin_cos[..., 3:, 0] * chi_mask + chi_angles_cos = torsion_angles_sin_cos[..., 3:, 1] * chi_mask + template_angle_feat = torch.cat( + [ + one_hot(template_aatype, 22), + chi_angles_sin, + chi_angles_cos, + chi_mask, + ], + dim=-1, + ) + template_angle_mask = chi_mask[..., 0] + return template_angle_feat, template_angle_mask + + +def build_template_pair_feat( + batch, + min_bin, + max_bin, + num_bins, + eps=1e-20, + inf=1e8, +): + template_mask = batch['template_pseudo_beta_mask'] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + tpb = batch['template_pseudo_beta'] + dgram = torch.sum( + (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True) + lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot = nn.functional.one_hot( + batch['template_aatype'], + rc.restype_num + 2, + ) + + n_res = batch['template_aatype'].shape[-1] + to_concat.append(aatype_one_hot[..., None, :, :].expand( + *aatype_one_hot.shape[:-2], n_res, -1, -1)) + to_concat.append(aatype_one_hot[..., + None, :].expand(*aatype_one_hot.shape[:-2], + -1, n_res, -1)) + + to_concat.append(template_mask_2d.new_zeros(*template_mask_2d.shape, 3)) + to_concat.append(template_mask_2d[..., None]) + + act = torch.cat(to_concat, dim=-1) + act = act * template_mask_2d[..., None] + + return act + + +def build_template_pair_feat_v2( + batch, + min_bin, + max_bin, + num_bins, + multichain_mask_2d=None, + eps=1e-20, + inf=1e8, +): + template_mask = batch['template_pseudo_beta_mask'] + 
template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + if multichain_mask_2d is not None: + template_mask_2d *= multichain_mask_2d + + tpb = batch['template_pseudo_beta'] + dgram = torch.sum( + (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True) + lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + dgram *= template_mask_2d[..., None] + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot = one_hot( + batch['template_aatype'], + rc.restype_num + 2, + ) + + n_res = batch['template_aatype'].shape[-1] + to_concat.append(aatype_one_hot[..., None, :, :].expand( + *aatype_one_hot.shape[:-2], n_res, -1, -1)) + to_concat.append(aatype_one_hot[..., + None, :].expand(*aatype_one_hot.shape[:-2], + -1, n_res, -1)) + + n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']] + rigids = Frame.make_transform_from_reference( + n_xyz=batch['template_all_atom_positions'][..., n, :], + ca_xyz=batch['template_all_atom_positions'][..., ca, :], + c_xyz=batch['template_all_atom_positions'][..., c, :], + eps=eps, + ) + points = rigids.get_trans()[..., None, :, :] + rigid_vec = rigids[..., None].invert_apply(points) + + inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) + + t_aa_masks = batch['template_all_atom_mask'] + backbone_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., + c] + backbone_mask_2d = backbone_mask[..., :, None] * backbone_mask[..., + None, :] + if multichain_mask_2d is not None: + backbone_mask_2d *= multichain_mask_2d + + inv_distance_scalar = inv_distance_scalar * backbone_mask_2d + unit_vector_data = rigid_vec * inv_distance_scalar[..., None] + to_concat.extend(torch.unbind(unit_vector_data[..., None, :], dim=-1)) + to_concat.append(backbone_mask_2d[..., None]) + + return to_concat + + +def build_extra_msa_feat(batch): + msa_1hot = one_hot(batch['extra_msa'], 23) + msa_feat = [ + msa_1hot, + batch['extra_msa_has_deletion'].unsqueeze(-1), + batch['extra_msa_deletion_value'].unsqueeze(-1), + ] + return torch.cat(msa_feat, dim=-1) diff --git a/modelscope/models/science/unifold/modules/frame.py b/modelscope/models/science/unifold/modules/frame.py new file mode 100644 index 00000000..5a0e4d6a --- /dev/null +++ b/modelscope/models/science/unifold/modules/frame.py @@ -0,0 +1,562 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
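# A standalone sketch of the distance binning used by build_template_pair_feat and
# build_template_pair_feat_v2 (featurization.py above): squared pairwise distances are
# bucketed into a one-hot "distogram" by comparing against the squared lower bin edges,
# with +inf closing the last bin. distogram_one_hot is a hypothetical helper added for
# illustration, not part of the Uni-Fold sources.
import torch


def distogram_one_hot(coords: torch.Tensor, min_bin: float, max_bin: float,
                      num_bins: int, inf: float = 1e8) -> torch.Tensor:
    # coords: [..., N, 3] pseudo-beta positions; returns [..., N, N, num_bins].
    d2 = torch.sum((coords[..., None, :] - coords[..., None, :, :])**2,
                   dim=-1, keepdim=True)
    lower = torch.linspace(min_bin, max_bin, num_bins, device=coords.device)**2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    return ((d2 > lower) * (d2 < upper)).type(coords.dtype)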
+ +from __future__ import annotations # noqa +from typing import Any, Callable, Iterable, Optional, Sequence, Tuple + +import numpy as np +import torch + + +def zero_translation( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, +) -> torch.Tensor: + trans = torch.zeros((*batch_dims, 3), + dtype=dtype, + device=device, + requires_grad=requires_grad) + return trans + + +# pylint: disable=bad-whitespace +_QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) + +_QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr +_QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii +_QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj +_QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk + +_QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij +_QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik +_QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk + +_QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir +_QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr +_QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr + +_QUAT_TO_ROT = _QUAT_TO_ROT.reshape(4, 4, 9) +_QUAT_TO_ROT_tensor = torch.from_numpy(_QUAT_TO_ROT) + +_QUAT_MULTIPLY = np.zeros((4, 4, 4)) +_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], + [0, 0, 0, -1]] + +_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], + [0, 0, -1, 0]] + +_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], + [0, 1, 0, 0]] + +_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], + [1, 0, 0, 0]] + +_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] +_QUAT_MULTIPLY_BY_VEC_tensor = torch.from_numpy(_QUAT_MULTIPLY_BY_VEC) + + +class Rotation: + + def __init__( + self, + mat: torch.Tensor, + ): + if mat.shape[-2:] != (3, 3): + raise ValueError(f'incorrect rotation shape: {mat.shape}') + self._mat = mat + + @staticmethod + def identity( + shape, + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, + ) -> Rotation: + mat = torch.eye( + 3, dtype=dtype, device=device, requires_grad=requires_grad) + mat = mat.view(*((1, ) * len(shape)), 3, 3) + mat = mat.expand(*shape, -1, -1) + return Rotation(mat) + + @staticmethod + def mat_mul_mat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return (a.float() @ b.float()).type(a.dtype) + + @staticmethod + def mat_mul_vec(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (r.float() @ t.float().unsqueeze(-1)).squeeze(-1).type(t.dtype) + + def __getitem__(self, index: Any) -> Rotation: + if not isinstance(index, tuple): + index = (index, ) + return Rotation(mat=self._mat[index + (slice(None), slice(None))]) + + def __mul__(self, right: Any) -> Rotation: + if isinstance(right, (int, float)): + return Rotation(mat=self._mat * right) + elif isinstance(right, torch.Tensor): + return Rotation(mat=self._mat * right[..., None, None]) + else: + raise TypeError( + f'multiplicand must be a tensor or a number, got {type(right)}.' 
+ ) + + def __rmul__(self, left: Any) -> Rotation: + return self.__mul__(left) + + def __matmul__(self, other: Rotation) -> Rotation: + new_mat = Rotation.mat_mul_mat(self.rot_mat, other.rot_mat) + return Rotation(mat=new_mat) + + @property + def _inv_mat(self): + return self._mat.transpose(-1, -2) + + @property + def rot_mat(self) -> torch.Tensor: + return self._mat + + def invert(self) -> Rotation: + return Rotation(mat=self._inv_mat) + + def apply(self, pts: torch.Tensor) -> torch.Tensor: + return Rotation.mat_mul_vec(self._mat, pts) + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + return Rotation.mat_mul_vec(self._inv_mat, pts) + + # inherit tensor behaviors + @property + def shape(self) -> torch.Size: + s = self._mat.shape[:-2] + return s + + @property + def dtype(self) -> torch.dtype: + return self._mat.dtype + + @property + def device(self) -> torch.device: + return self._mat.device + + @property + def requires_grad(self) -> bool: + return self._mat.requires_grad + + def unsqueeze(self, dim: int) -> Rotation: + if dim >= len(self.shape): + raise ValueError('Invalid dimension') + + rot_mats = self._mat.unsqueeze(dim if dim >= 0 else dim - 2) + return Rotation(mat=rot_mats) + + def map_tensor_fn(self, fn: Callable[[torch.Tensor], + torch.Tensor]) -> Rotation: + mat = self._mat.view(self._mat.shape[:-2] + (9, )) + mat = torch.stack(list(map(fn, torch.unbind(mat, dim=-1))), dim=-1) + mat = mat.view(mat.shape[:-1] + (3, 3)) + return Rotation(mat=mat) + + @staticmethod + def cat(rs: Sequence[Rotation], dim: int) -> Rotation: + rot_mats = [r.rot_mat for r in rs] + rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) + + return Rotation(mat=rot_mats) + + def cuda(self) -> Rotation: + return Rotation(mat=self._mat.cuda()) + + def to(self, device: Optional[torch.device], + dtype: Optional[torch.dtype]) -> Rotation: + return Rotation(mat=self._mat.to(device=device, dtype=dtype)) + + def type(self, dtype: Optional[torch.dtype]) -> Rotation: + return Rotation(mat=self._mat.type(dtype)) + + def detach(self) -> Rotation: + return Rotation(mat=self._mat.detach()) + + +class Frame: + + def __init__( + self, + rotation: Optional[Rotation], + translation: Optional[torch.Tensor], + ): + if rotation is None and translation is None: + rotation = Rotation.identity((0, )) + translation = zero_translation((0, )) + elif translation is None: + translation = zero_translation(rotation.shape, rotation.dtype, + rotation.device, + rotation.requires_grad) + + elif rotation is None: + rotation = Rotation.identity( + translation.shape[:-1], + translation.dtype, + translation.device, + translation.requires_grad, + ) + + if (rotation.shape != translation.shape[:-1]) or (rotation.device + != # noqa W504 + translation.device): + raise ValueError('RotationMatrix and translation incompatible') + + self._r = rotation + self._t = translation + + @staticmethod + def identity( + shape: Iterable[int], + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, + ) -> Frame: + return Frame( + Rotation.identity(shape, dtype, device, requires_grad), + zero_translation(shape, dtype, device, requires_grad), + ) + + def __getitem__( + self, + index: Any, + ) -> Frame: + if type(index) != tuple: + index = (index, ) + + return Frame( + self._r[index], + self._t[index + (slice(None), )], + ) + + def __mul__( + self, + right: torch.Tensor, + ) -> Frame: + if not (isinstance(right, torch.Tensor)): + raise TypeError('The other multiplicand 
must be a Tensor') + + new_rots = self._r * right + new_trans = self._t * right[..., None] + + return Frame(new_rots, new_trans) + + def __rmul__( + self, + left: torch.Tensor, + ) -> Frame: + return self.__mul__(left) + + @property + def shape(self) -> torch.Size: + s = self._t.shape[:-1] + return s + + @property + def device(self) -> torch.device: + return self._t.device + + def get_rots(self) -> Rotation: + return self._r + + def get_trans(self) -> torch.Tensor: + return self._t + + def compose( + self, + other: Frame, + ) -> Frame: + new_rot = self._r @ other._r + new_trans = self._r.apply(other._t) + self._t + return Frame(new_rot, new_trans) + + def apply( + self, + pts: torch.Tensor, + ) -> torch.Tensor: + rotated = self._r.apply(pts) + return rotated + self._t + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + pts = pts - self._t + return self._r.invert_apply(pts) + + def invert(self) -> Frame: + rot_inv = self._r.invert() + trn_inv = rot_inv.apply(self._t) + + return Frame(rot_inv, -1 * trn_inv) + + def map_tensor_fn(self, fn: Callable[[torch.Tensor], + torch.Tensor]) -> Frame: + new_rots = self._r.map_tensor_fn(fn) + new_trans = torch.stack( + list(map(fn, torch.unbind(self._t, dim=-1))), dim=-1) + + return Frame(new_rots, new_trans) + + def to_tensor_4x4(self) -> torch.Tensor: + tensor = self._t.new_zeros((*self.shape, 4, 4)) + tensor[..., :3, :3] = self._r.rot_mat + tensor[..., :3, 3] = self._t + tensor[..., 3, 3] = 1 + return tensor + + @staticmethod + def from_tensor_4x4(t: torch.Tensor) -> Frame: + if t.shape[-2:] != (4, 4): + raise ValueError('Incorrectly shaped input tensor') + + rots = Rotation(mat=t[..., :3, :3]) + trans = t[..., :3, 3] + + return Frame(rots, trans) + + @staticmethod + def from_3_points( + p_neg_x_axis: torch.Tensor, + origin: torch.Tensor, + p_xy_plane: torch.Tensor, + eps: float = 1e-8, + ) -> Frame: + p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) + origin = torch.unbind(origin, dim=-1) + p_xy_plane = torch.unbind(p_xy_plane, dim=-1) + + e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] + e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] + + denom = torch.sqrt(sum((c * c for c in e0)) + eps) + e0 = [c / denom for c in e0] + dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) + e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] + denom = torch.sqrt(sum((c * c for c in e1)) + eps) + e1 = [c / denom for c in e1] + e2 = [ + e0[1] * e1[2] - e0[2] * e1[1], + e0[2] * e1[0] - e0[0] * e1[2], + e0[0] * e1[1] - e0[1] * e1[0], + ] + + rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) + rots = rots.reshape(rots.shape[:-1] + (3, 3)) + + rot_obj = Rotation(mat=rots) + + return Frame(rot_obj, torch.stack(origin, dim=-1)) + + def unsqueeze( + self, + dim: int, + ) -> Frame: + if dim >= len(self.shape): + raise ValueError('Invalid dimension') + rots = self._r.unsqueeze(dim) + trans = self._t.unsqueeze(dim if dim >= 0 else dim - 1) + + return Frame(rots, trans) + + @staticmethod + def cat( + Ts: Sequence[Frame], + dim: int, + ) -> Frame: + rots = Rotation.cat([T._r for T in Ts], dim) + trans = torch.cat([T._t for T in Ts], dim=dim if dim >= 0 else dim - 1) + + return Frame(rots, trans) + + def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Frame: + return Frame(fn(self._r), self._t) + + def apply_trans_fn(self, fn: Callable[[torch.Tensor], + torch.Tensor]) -> Frame: + return Frame(self._r, fn(self._t)) + + def scale_translation(self, trans_scale_factor: float) -> Frame: + # fn = lambda t: t * trans_scale_factor + def 
fn(t): + return t * trans_scale_factor + + return self.apply_trans_fn(fn) + + def stop_rot_gradient(self) -> Frame: + # fn = lambda r: r.detach() + def fn(r): + return r.detach() + + return self.apply_rot_fn(fn) + + @staticmethod + def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): + input_dtype = ca_xyz.dtype + n_xyz = n_xyz.float() + ca_xyz = ca_xyz.float() + c_xyz = c_xyz.float() + n_xyz = n_xyz - ca_xyz + c_xyz = c_xyz - ca_xyz + + c_x, c_y, d_pair = [c_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + c_x**2 + c_y**2) + sin_c1 = -c_y / norm + cos_c1 = c_x / norm + + c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3)) + c1_rots[..., 0, 0] = cos_c1 + c1_rots[..., 0, 1] = -1 * sin_c1 + c1_rots[..., 1, 0] = sin_c1 + c1_rots[..., 1, 1] = cos_c1 + c1_rots[..., 2, 2] = 1 + + norm = torch.sqrt(eps + c_x**2 + c_y**2 + d_pair**2) + sin_c2 = d_pair / norm + cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm + + c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + c2_rots[..., 0, 0] = cos_c2 + c2_rots[..., 0, 2] = sin_c2 + c2_rots[..., 1, 1] = 1 + c2_rots[..., 2, 0] = -1 * sin_c2 + c2_rots[..., 2, 2] = cos_c2 + + c_rots = Rotation.mat_mul_mat(c2_rots, c1_rots) + n_xyz = Rotation.mat_mul_vec(c_rots, n_xyz) + + _, n_y, n_z = [n_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + n_y**2 + n_z**2) + sin_n = -n_z / norm + cos_n = n_y / norm + + n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + n_rots[..., 0, 0] = 1 + n_rots[..., 1, 1] = cos_n + n_rots[..., 1, 2] = -1 * sin_n + n_rots[..., 2, 1] = sin_n + n_rots[..., 2, 2] = cos_n + + rots = Rotation.mat_mul_mat(n_rots, c_rots) + + rots = rots.transpose(-1, -2) + rot_obj = Rotation(mat=rots.type(input_dtype)) + + return Frame(rot_obj, ca_xyz.type(input_dtype)) + + def cuda(self) -> Frame: + return Frame(self._r.cuda(), self._t.cuda()) + + @property + def dtype(self) -> torch.dtype: + assert self._r.dtype == self._t.dtype + return self._r.dtype + + def type(self, dtype) -> Frame: + return Frame(self._r.type(dtype), self._t.type(dtype)) + + +class Quaternion: + + def __init__(self, quaternion: torch.Tensor, translation: torch.Tensor): + if quaternion.shape[-1] != 4: + raise ValueError(f'incorrect quaternion shape: {quaternion.shape}') + self._q = quaternion + self._t = translation + + @staticmethod + def identity( + shape: Iterable[int], + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, + ) -> Quaternion: + trans = zero_translation(shape, dtype, device, requires_grad) + quats = torch.zeros((*shape, 4), + dtype=dtype, + device=device, + requires_grad=requires_grad) + with torch.no_grad(): + quats[..., 0] = 1 + return Quaternion(quats, trans) + + def get_quats(self): + return self._q + + def get_trans(self): + return self._t + + def get_rot_mats(self): + quats = self.get_quats() + rot_mats = Quaternion.quat_to_rot(quats) + return rot_mats + + @staticmethod + def quat_to_rot(normalized_quat): + global _QUAT_TO_ROT_tensor + dtype = normalized_quat.dtype + normalized_quat = normalized_quat.float() + if _QUAT_TO_ROT_tensor.device != normalized_quat.device: + _QUAT_TO_ROT_tensor = _QUAT_TO_ROT_tensor.to( + normalized_quat.device) + rot_tensor = torch.sum( + _QUAT_TO_ROT_tensor * normalized_quat[..., :, None, None] + * normalized_quat[..., None, :, None], + dim=(-3, -2), + ) + rot_tensor = rot_tensor.type(dtype) + rot_tensor = rot_tensor.view(*rot_tensor.shape[:-1], 3, 3) + return rot_tensor + + @staticmethod + def normalize_quat(quats): + dtype = quats.dtype + 
quats = quats.float() + quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) + quats = quats.type(dtype) + return quats + + @staticmethod + def quat_multiply_by_vec(quat, vec): + dtype = quat.dtype + quat = quat.float() + vec = vec.float() + global _QUAT_MULTIPLY_BY_VEC_tensor + if _QUAT_MULTIPLY_BY_VEC_tensor.device != quat.device: + _QUAT_MULTIPLY_BY_VEC_tensor = _QUAT_MULTIPLY_BY_VEC_tensor.to( + quat.device) + mat = _QUAT_MULTIPLY_BY_VEC_tensor + reshaped_mat = mat.view((1, ) * len(quat.shape[:-1]) + mat.shape) + return torch.sum( + reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], + dim=(-3, -2), + ).type(dtype) + + def compose_q_update_vec(self, + q_update_vec: torch.Tensor, + normalize_quats: bool = True) -> torch.Tensor: + quats = self.get_quats() + new_quats = quats + Quaternion.quat_multiply_by_vec( + quats, q_update_vec) + if normalize_quats: + new_quats = Quaternion.normalize_quat(new_quats) + return new_quats + + def compose_update_vec( + self, + update_vec: torch.Tensor, + pre_rot_mat: Rotation, + ) -> Quaternion: + q_vec, t_vec = update_vec[..., :3], update_vec[..., 3:] + new_quats = self.compose_q_update_vec(q_vec) + + trans_update = pre_rot_mat.apply(t_vec) + new_trans = self._t + trans_update + + return Quaternion(new_quats, new_trans) + + def stop_rot_gradient(self) -> Quaternion: + return Quaternion(self._q.detach(), self._t) diff --git a/modelscope/models/science/unifold/modules/structure_module.py b/modelscope/models/science/unifold/modules/structure_module.py new file mode 100644 index 00000000..5d4da30b --- /dev/null +++ b/modelscope/models/science/unifold/modules/structure_module.py @@ -0,0 +1,592 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
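# A small usage sketch for the Frame and Rotation classes defined in frame.py above
# (import path as introduced by this patch). A Frame is a rigid transform
# x -> R @ x + t: apply() maps points into the global frame, invert_apply() maps them
# back, and compose() chains transforms. The values below assume an identity rotation.
import torch

from modelscope.models.science.unifold.modules.frame import Frame, Rotation

f = Frame(Rotation.identity((1, )), torch.tensor([[1.0, 0.0, 0.0]]))
pts = torch.zeros(1, 3)
moved = f.apply(pts)          # -> tensor([[1., 0., 0.]])
back = f.invert_apply(moved)  # -> tensor([[0., 0., 0.]])
g = f.compose(f)              # translation accumulates to [[2., 0., 0.]]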
+ +import math +from typing import Tuple + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm, softmax_dropout +from unicore.utils import dict_multimap, one_hot, permute_final_dims + +from modelscope.models.science.unifold.data.residue_constants import ( + restype_atom14_mask, restype_atom14_rigid_group_positions, + restype_atom14_to_rigid_group, restype_rigid_group_default_frame) +from .attentions import gen_attn_mask +from .common import Linear, SimpleModuleList, residual +from .frame import Frame, Quaternion, Rotation + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +def torsion_angles_to_frames( + frame: Frame, + alpha: torch.Tensor, + aatype: torch.Tensor, + default_frames: torch.Tensor, +): + default_frame = Frame.from_tensor_4x4(default_frames[aatype, ...]) + + bb_rot = alpha.new_zeros((*((1, ) * len(alpha.shape[:-1])), 2)) + bb_rot[..., 1] = 1 + + alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], + dim=-2) + + all_rots = alpha.new_zeros(default_frame.get_rots().rot_mat.shape) + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:] = alpha + + all_rots = Frame(Rotation(mat=all_rots), None) + + all_frames = default_frame.compose(all_rots) + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + all_frames_to_bb = Frame.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = frame[..., None].compose(all_frames_to_bb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + frame: Frame, + aatype: torch.Tensor, + default_frames, + group_idx, + atom_mask, + lit_positions, +): + group_mask = group_idx[aatype, ...] + group_mask = one_hot( + group_mask, + num_classes=default_frames.shape[-3], + ) + + t_atoms_to_global = frame[..., None, :] * group_mask + t_atoms_to_global = t_atoms_to_global.map_tensor_fn( + lambda x: torch.sum(x, dim=-1)) + + atom_mask = atom_mask[aatype, ...].unsqueeze(-1) + + lit_positions = lit_positions[aatype, ...] 
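    # group_mask one-hot selects, per atom, which rigid-group frame it belongs to, and
    # lit_positions[aatype] holds the idealized atom14 coordinates expressed in that
    # local frame. Applying the per-atom global frames below maps them to global
    # coordinates, and atom_mask zeroes atoms that do not exist for the residue type.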
+ pred_positions = t_atoms_to_global.apply(lit_positions) + pred_positions = pred_positions * atom_mask + + return pred_positions + + +class SideChainAngleResnetIteration(nn.Module): + + def __init__(self, d_hid): + super(SideChainAngleResnetIteration, self).__init__() + + self.d_hid = d_hid + + self.linear_1 = Linear(self.d_hid, self.d_hid, init='relu') + self.act = nn.GELU() + self.linear_2 = Linear(self.d_hid, self.d_hid, init='final') + + def forward(self, s: torch.Tensor) -> torch.Tensor: + + x = self.act(s) + x = self.linear_1(x) + x = self.act(x) + x = self.linear_2(x) + + return residual(s, x, self.training) + + +class SidechainAngleResnet(nn.Module): + + def __init__(self, d_in, d_hid, num_blocks, num_angles): + super(SidechainAngleResnet, self).__init__() + + self.linear_in = Linear(d_in, d_hid) + self.act = nn.GELU() + self.linear_initial = Linear(d_in, d_hid) + + self.layers = SimpleModuleList() + for _ in range(num_blocks): + self.layers.append(SideChainAngleResnetIteration(d_hid=d_hid)) + + self.linear_out = Linear(d_hid, num_angles * 2) + + def forward(self, s: torch.Tensor, + initial_s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + initial_s = self.linear_initial(self.act(initial_s)) + s = self.linear_in(self.act(s)) + + s = s + initial_s + + for layer in self.layers: + s = layer(s) + + s = self.linear_out(self.act(s)) + + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s.float()**2, dim=-1, keepdim=True), + min=1e-12, + )) + s = s.float() / norm_denom + + return unnormalized_s, s.type(unnormalized_s.dtype) + + +class InvariantPointAttention(nn.Module): + + def __init__( + self, + d_single: int, + d_pair: int, + d_hid: int, + num_heads: int, + num_qk_points: int, + num_v_points: int, + separate_kv: bool = False, + bias: bool = True, + eps: float = 1e-8, + ): + super(InvariantPointAttention, self).__init__() + + self.d_hid = d_hid + self.num_heads = num_heads + self.num_qk_points = num_qk_points + self.num_v_points = num_v_points + self.eps = eps + + hc = self.d_hid * self.num_heads + self.linear_q = Linear(d_single, hc, bias=bias) + self.separate_kv = separate_kv + if self.separate_kv: + self.linear_k = Linear(d_single, hc, bias=bias) + self.linear_v = Linear(d_single, hc, bias=bias) + else: + self.linear_kv = Linear(d_single, 2 * hc, bias=bias) + + hpq = self.num_heads * self.num_qk_points * 3 + self.linear_q_points = Linear(d_single, hpq) + hpk = self.num_heads * self.num_qk_points * 3 + hpv = self.num_heads * self.num_v_points * 3 + if self.separate_kv: + self.linear_k_points = Linear(d_single, hpk) + self.linear_v_points = Linear(d_single, hpv) + else: + hpkv = hpk + hpv + self.linear_kv_points = Linear(d_single, hpkv) + + self.linear_b = Linear(d_pair, self.num_heads) + + self.head_weights = nn.Parameter(torch.zeros((num_heads))) + ipa_point_weights_init_(self.head_weights) + + concat_out_dim = self.num_heads * ( + d_pair + self.d_hid + self.num_v_points * 4) + self.linear_out = Linear(concat_out_dim, d_single, init='final') + + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: torch.Tensor, + f: Frame, + square_mask: torch.Tensor, + ) -> torch.Tensor: + q = self.linear_q(s) + + q = q.view(q.shape[:-1] + (self.num_heads, -1)) + + if self.separate_kv: + k = self.linear_k(s) + v = self.linear_v(s) + k = k.view(k.shape[:-1] + (self.num_heads, -1)) + v = v.view(v.shape[:-1] + (self.num_heads, -1)) + else: + kv = self.linear_kv(s) + kv = kv.view(kv.shape[:-1] + 
(self.num_heads, -1)) + k, v = torch.split(kv, self.d_hid, dim=-1) + + q_pts = self.linear_q_points(s) + + def process_points(pts, no_points): + shape = pts.shape[:-1] + (pts.shape[-1] // 3, 3) + if self.separate_kv: + # alphafold-multimer uses different layout + pts = pts.view(pts.shape[:-1] + + (self.num_heads, no_points * 3)) + pts = torch.split(pts, pts.shape[-1] // 3, dim=-1) + pts = torch.stack(pts, dim=-1).view(*shape) + pts = f[..., None].apply(pts) + + pts = pts.view(pts.shape[:-2] + (self.num_heads, no_points, 3)) + return pts + + q_pts = process_points(q_pts, self.num_qk_points) + + if self.separate_kv: + k_pts = self.linear_k_points(s) + v_pts = self.linear_v_points(s) + k_pts = process_points(k_pts, self.num_qk_points) + v_pts = process_points(v_pts, self.num_v_points) + else: + kv_pts = self.linear_kv_points(s) + + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = f[..., None].apply(kv_pts) + + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3)) + + k_pts, v_pts = torch.split( + kv_pts, [self.num_qk_points, self.num_v_points], dim=-2) + + bias = self.linear_b(z) + + attn = torch.matmul( + permute_final_dims(q, (1, 0, 2)), + permute_final_dims(k, (1, 2, 0)), + ) + + if self.training: + attn = attn * math.sqrt(1.0 / (3 * self.d_hid)) + attn = attn + ( + math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1))) + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att = pt_att.float()**2 + else: + attn *= math.sqrt(1.0 / (3 * self.d_hid)) + attn += (math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1))) + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att *= pt_att + + pt_att = pt_att.sum(dim=-1) + head_weights = self.softplus(self.head_weights).view( + *((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) + head_weights = head_weights * math.sqrt( + 1.0 / (3 * (self.num_qk_points * 9.0 / 2))) + pt_att *= head_weights * (-0.5) + + pt_att = torch.sum(pt_att, dim=-1) + + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + attn += square_mask + attn = softmax_dropout( + attn, 0, self.training, bias=pt_att.type(attn.dtype)) + del pt_att, q_pts, k_pts, bias + o = torch.matmul(attn, v.transpose(-2, -3)).transpose(-2, -3) + o = o.contiguous().view(*o.shape[:-2], -1) + del q, k, v + o_pts = torch.sum( + (attn[..., None, :, :, None] + * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), + dim=-2, + ) + + o_pts = permute_final_dims(o_pts, (2, 0, 3, 1)) + o_pts = f[..., None, None].invert_apply(o_pts) + if self.training: + o_pts_norm = torch.sqrt( + torch.sum(o_pts.float()**2, dim=-1) + self.eps).type( + o_pts.dtype) + else: + o_pts_norm = torch.sqrt(torch.sum(o_pts**2, dim=-1) + + self.eps).type(o_pts.dtype) + + o_pts_norm = o_pts_norm.view(*o_pts_norm.shape[:-2], -1) + + o_pts = o_pts.view(*o_pts.shape[:-3], -1, 3) + + o_pair = torch.matmul(attn.transpose(-2, -3), z) + + o_pair = o_pair.view(*o_pair.shape[:-2], -1) + + s = self.linear_out( + torch.cat((o, *torch.unbind(o_pts, dim=-1), o_pts_norm, o_pair), + dim=-1)) + + return s + + +class BackboneUpdate(nn.Module): + + def __init__(self, d_single): + super(BackboneUpdate, self).__init__() + self.linear = Linear(d_single, 6, init='final') + + def forward(self, s: torch.Tensor): + return self.linear(s) + + +class StructureModuleTransitionLayer(nn.Module): + + def __init__(self, c): + super(StructureModuleTransitionLayer, self).__init__() + + self.linear_1 = Linear(c, c, init='relu') + self.linear_2 = Linear(c, c, init='relu') + self.act = nn.GELU() + 
self.linear_3 = Linear(c, c, init='final') + + def forward(self, s): + s_old = s + s = self.linear_1(s) + s = self.act(s) + s = self.linear_2(s) + s = self.act(s) + s = self.linear_3(s) + + s = residual(s_old, s, self.training) + + return s + + +class StructureModuleTransition(nn.Module): + + def __init__(self, c, num_layers, dropout_rate): + super(StructureModuleTransition, self).__init__() + + self.num_layers = num_layers + self.dropout_rate = dropout_rate + + self.layers = SimpleModuleList() + for _ in range(self.num_layers): + self.layers.append(StructureModuleTransitionLayer(c)) + + self.dropout = nn.Dropout(self.dropout_rate) + self.layer_norm = LayerNorm(c) + + def forward(self, s): + for layer in self.layers: + s = layer(s) + + s = self.dropout(s) + s = self.layer_norm(s) + + return s + + +class StructureModule(nn.Module): + + def __init__( + self, + d_single, + d_pair, + d_ipa, + d_angle, + num_heads_ipa, + num_qk_points, + num_v_points, + dropout_rate, + num_blocks, + no_transition_layers, + num_resnet_blocks, + num_angles, + trans_scale_factor, + separate_kv, + ipa_bias, + epsilon, + inf, + **kwargs, + ): + super(StructureModule, self).__init__() + + self.num_blocks = num_blocks + self.trans_scale_factor = trans_scale_factor + self.default_frames = None + self.group_idx = None + self.atom_mask = None + self.lit_positions = None + self.inf = inf + + self.layer_norm_s = LayerNorm(d_single) + self.layer_norm_z = LayerNorm(d_pair) + + self.linear_in = Linear(d_single, d_single) + + self.ipa = InvariantPointAttention( + d_single, + d_pair, + d_ipa, + num_heads_ipa, + num_qk_points, + num_v_points, + separate_kv=separate_kv, + bias=ipa_bias, + eps=epsilon, + ) + + self.ipa_dropout = nn.Dropout(dropout_rate) + self.layer_norm_ipa = LayerNorm(d_single) + + self.transition = StructureModuleTransition( + d_single, + no_transition_layers, + dropout_rate, + ) + + self.bb_update = BackboneUpdate(d_single) + + self.angle_resnet = SidechainAngleResnet( + d_single, + d_angle, + num_resnet_blocks, + num_angles, + ) + + def forward( + self, + s, + z, + aatype, + mask=None, + ): + if mask is None: + mask = s.new_ones(s.shape[:-1]) + + # generate square mask + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = gen_attn_mask(square_mask, -self.inf).unsqueeze(-3) + s = self.layer_norm_s(s) + z = self.layer_norm_z(z) + initial_s = s + s = self.linear_in(s) + + quat_encoder = Quaternion.identity( + s.shape[:-1], + s.dtype, + s.device, + requires_grad=False, + ) + backb_to_global = Frame( + Rotation(mat=quat_encoder.get_rot_mats(), ), + quat_encoder.get_trans(), + ) + outputs = [] + for i in range(self.num_blocks): + s = residual(s, self.ipa(s, z, backb_to_global, square_mask), + self.training) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # update quaternion encoder + # use backb_to_global to avoid quat-to-rot conversion + quat_encoder = quat_encoder.compose_update_vec( + self.bb_update(s), pre_rot_mat=backb_to_global.get_rots()) + + # initial_s is always used to update the backbone + unnormalized_angles, angles = self.angle_resnet(s, initial_s) + + # convert quaternion to rotation matrix + backb_to_global = Frame( + Rotation(mat=quat_encoder.get_rot_mats(), ), + quat_encoder.get_trans(), + ) + if i == self.num_blocks - 1: + all_frames_to_global = self.torsion_angles_to_frames( + backb_to_global.scale_translation(self.trans_scale_factor), + angles, + aatype, + ) + + pred_positions = self.frames_and_literature_positions_to_atom14_pos( + 
all_frames_to_global, + aatype, + ) + + preds = { + 'frames': + backb_to_global.scale_translation( + self.trans_scale_factor).to_tensor_4x4(), + 'unnormalized_angles': + unnormalized_angles, + 'angles': + angles, + } + + outputs.append(preds) + if i < (self.num_blocks - 1): + # stop gradient in iteration + quat_encoder = quat_encoder.stop_rot_gradient() + backb_to_global = backb_to_global.stop_rot_gradient() + + outputs = dict_multimap(torch.stack, outputs) + outputs['sidechain_frames'] = all_frames_to_global.to_tensor_4x4() + outputs['positions'] = pred_positions + outputs['single'] = s + + return outputs + + def _init_residue_constants(self, float_dtype, device): + if self.default_frames is None: + self.default_frames = torch.tensor( + restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + if self.group_idx is None: + self.group_idx = torch.tensor( + restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ) + if self.atom_mask is None: + self.atom_mask = torch.tensor( + restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + if self.lit_positions is None: + self.lit_positions = torch.tensor( + restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + + def torsion_angles_to_frames(self, frame, alpha, aatype): + self._init_residue_constants(alpha.dtype, alpha.device) + return torsion_angles_to_frames(frame, alpha, aatype, + self.default_frames) + + def frames_and_literature_positions_to_atom14_pos(self, frame, aatype): + self._init_residue_constants(frame.get_rots().dtype, + frame.get_rots().device) + return frames_and_literature_positions_to_atom14_pos( + frame, + aatype, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) diff --git a/modelscope/models/science/unifold/modules/template.py b/modelscope/models/science/unifold/modules/template.py new file mode 100644 index 00000000..49e5bec0 --- /dev/null +++ b/modelscope/models/science/unifold/modules/template.py @@ -0,0 +1,330 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
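# A minimal sketch of the backbone update performed inside StructureModule
# (structure_module.py above), using the Quaternion and Rotation helpers from frame.py
# (import path as introduced by this patch). BackboneUpdate emits 6 numbers per
# residue: the first 3 act as the vector part of a quaternion update and the last 3 as
# a translation update applied in the current backbone frame. The zero update below is
# a hypothetical input chosen so the frame stays unchanged.
import torch

from modelscope.models.science.unifold.modules.frame import Quaternion, Rotation

quat_encoder = Quaternion.identity((2, ), dtype=torch.float,
                                   device=torch.device('cpu'))
update = torch.zeros(2, 6)  # per-residue [quaternion vector part | translation]
pre_rot = Rotation(mat=quat_encoder.get_rot_mats())
quat_encoder = quat_encoder.compose_update_vec(update, pre_rot_mat=pre_rot)
# With a zero update, the quaternion stays the identity and the translation stays zero.
assert torch.allclose(quat_encoder.get_quats()[..., 0], torch.ones(2))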
+ +import math +from functools import partial +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm +from unicore.utils import (checkpoint_sequential, permute_final_dims, + tensor_tree_map) + +from .attentions import (Attention, TriangleAttentionEnding, + TriangleAttentionStarting, gen_attn_mask) +from .common import (Linear, SimpleModuleList, Transition, + bias_dropout_residual, chunk_layer, residual, + tri_mul_residual) +from .featurization import build_template_pair_feat_v2 +from .triangle_multiplication import (TriangleMultiplicationIncoming, + TriangleMultiplicationOutgoing) + + +class TemplatePointwiseAttention(nn.Module): + + def __init__(self, d_template, d_pair, d_hid, num_heads, inf, **kwargs): + super(TemplatePointwiseAttention, self).__init__() + + self.inf = inf + + self.mha = Attention( + d_pair, + d_template, + d_template, + d_hid, + num_heads, + gating=False, + ) + + def _chunk( + self, + z: torch.Tensor, + t: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + mha_inputs = { + 'q': z, + 'k': t, + 'v': t, + 'mask': mask, + } + return chunk_layer( + self.mha, + mha_inputs, + chunk_size=chunk_size, + num_batch_dims=len(z.shape[:-2]), + ) + + def forward( + self, + t: torch.Tensor, + z: torch.Tensor, + template_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + if template_mask is None: + template_mask = t.new_ones(t.shape[:-3]) + + mask = gen_attn_mask(template_mask, -self.inf)[..., None, None, None, + None, :] + z = z.unsqueeze(-2) + + t = permute_final_dims(t, (1, 2, 0, 3)) + + if chunk_size is not None: + z = self._chunk(z, t, mask, chunk_size) + else: + z = self.mha(z, t, t, mask=mask) + + z = z.squeeze(-2) + + return z + + +class TemplateProjection(nn.Module): + + def __init__(self, d_template, d_pair, **kwargs): + super(TemplateProjection, self).__init__() + + self.d_pair = d_pair + self.act = nn.ReLU() + self.output_linear = Linear(d_template, d_pair, init='relu') + + def forward(self, t, z) -> torch.Tensor: + if t is None: + # handle for non-template case + shape = z.shape + shape[-1] = self.d_pair + t = torch.zeros(shape, dtype=z.dtype, device=z.device) + t = self.act(t) + z_t = self.output_linear(t) + return z_t + + +class TemplatePairStackBlock(nn.Module): + + def __init__( + self, + d_template: int, + d_hid_tri_att: int, + d_hid_tri_mul: int, + num_heads: int, + pair_transition_n: int, + dropout_rate: float, + tri_attn_first: bool, + inf: float, + **kwargs, + ): + super(TemplatePairStackBlock, self).__init__() + + self.tri_att_start = TriangleAttentionStarting( + d_template, + d_hid_tri_att, + num_heads, + ) + self.tri_att_end = TriangleAttentionEnding( + d_template, + d_hid_tri_att, + num_heads, + ) + + self.tri_mul_out = TriangleMultiplicationOutgoing( + d_template, + d_hid_tri_mul, + ) + self.tri_mul_in = TriangleMultiplicationIncoming( + d_template, + d_hid_tri_mul, + ) + + self.pair_transition = Transition( + d_template, + pair_transition_n, + ) + self.tri_attn_first = tri_attn_first + self.dropout = dropout_rate + self.row_dropout_share_dim = -3 + self.col_dropout_share_dim = -2 + + def forward( + self, + s: torch.Tensor, + mask: torch.Tensor, + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + chunk_size: Optional[int] = None, + block_size: Optional[int] = None, + ): + if self.tri_attn_first: + s = bias_dropout_residual( + self.tri_att_start, + s, + self.tri_att_start( + s, 
attn_mask=tri_start_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + ) + + s = bias_dropout_residual( + self.tri_att_end, + s, + self.tri_att_end( + s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.dropout, + self.training, + ) + s = tri_mul_residual( + self.tri_mul_out, + s, + self.tri_mul_out(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + + s = tri_mul_residual( + self.tri_mul_in, + s, + self.tri_mul_in(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + else: + s = tri_mul_residual( + self.tri_mul_out, + s, + self.tri_mul_out(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + + s = tri_mul_residual( + self.tri_mul_in, + s, + self.tri_mul_in(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + + s = bias_dropout_residual( + self.tri_att_start, + s, + self.tri_att_start( + s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + ) + + s = bias_dropout_residual( + self.tri_att_end, + s, + self.tri_att_end( + s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.dropout, + self.training, + ) + s = residual(s, self.pair_transition( + s, + chunk_size=chunk_size, + ), self.training) + return s + + +class TemplatePairStack(nn.Module): + + def __init__( + self, + d_template, + d_hid_tri_att, + d_hid_tri_mul, + num_blocks, + num_heads, + pair_transition_n, + dropout_rate, + tri_attn_first, + inf=1e9, + **kwargs, + ): + super(TemplatePairStack, self).__init__() + + self.blocks = SimpleModuleList() + for _ in range(num_blocks): + self.blocks.append( + TemplatePairStackBlock( + d_template=d_template, + d_hid_tri_att=d_hid_tri_att, + d_hid_tri_mul=d_hid_tri_mul, + num_heads=num_heads, + pair_transition_n=pair_transition_n, + dropout_rate=dropout_rate, + inf=inf, + tri_attn_first=tri_attn_first, + )) + + self.layer_norm = LayerNorm(d_template) + + def forward( + self, + single_templates: Tuple[torch.Tensor], + mask: torch.tensor, + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + templ_dim: int, + chunk_size: int, + block_size: int, + return_mean: bool, + ): + + def one_template(i): + (s, ) = checkpoint_sequential( + functions=[ + partial( + b, + mask=mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=chunk_size, + block_size=block_size, + ) for b in self.blocks + ], + input=(single_templates[i], ), + ) + return s + + n_templ = len(single_templates) + if n_templ > 0: + new_single_templates = [one_template(0)] + if return_mean: + t = self.layer_norm(new_single_templates[0]) + for i in range(1, n_templ): + s = one_template(i) + if return_mean: + t = residual(t, self.layer_norm(s), self.training) + else: + new_single_templates.append(s) + + if return_mean: + if n_templ > 0: + t /= n_templ + else: + t = None + else: + t = torch.cat( + [s.unsqueeze(templ_dim) for s in new_single_templates], + dim=templ_dim) + t = self.layer_norm(t) + + return t diff --git a/modelscope/models/science/unifold/modules/triangle_multiplication.py b/modelscope/models/science/unifold/modules/triangle_multiplication.py new file mode 100644 
index 00000000..c4094cd2 --- /dev/null +++ b/modelscope/models/science/unifold/modules/triangle_multiplication.py @@ -0,0 +1,158 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from functools import partialmethod +from typing import List, Optional + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm +from unicore.utils import permute_final_dims + +from .common import Linear + + +class TriangleMultiplication(nn.Module): + + def __init__(self, d_pair, d_hid, outgoing=True): + super(TriangleMultiplication, self).__init__() + self.outgoing = outgoing + + self.linear_ab_p = Linear(d_pair, d_hid * 2) + self.linear_ab_g = Linear(d_pair, d_hid * 2, init='gating') + + self.linear_g = Linear(d_pair, d_pair, init='gating') + self.linear_z = Linear(d_hid, d_pair, init='final') + + self.layer_norm_in = LayerNorm(d_pair) + self.layer_norm_out = LayerNorm(d_hid) + + self._alphafold_original_mode = False + + def _chunk_2d( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + block_size: int = None, + ) -> torch.Tensor: + + # avoid too small chunk size + # block_size = max(block_size, 256) + new_z = z.new_zeros(z.shape) + dim1 = z.shape[-3] + + def _slice_linear(z, linear: Linear, a=True): + d_hid = linear.bias.shape[0] // 2 + index = 0 if a else d_hid + p = ( + nn.functional.linear(z, linear.weight[index:index + d_hid]) + + linear.bias[index:index + d_hid]) + return p + + def _chunk_projection(z, mask, a=True): + p = _slice_linear(z, self.linear_ab_p, a) * mask + p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a)) + return p + + num_chunk = (dim1 + block_size - 1) // block_size + for i in range(num_chunk): + chunk_start = i * block_size + chunk_end = min(chunk_start + block_size, dim1) + if self.outgoing: + a_chunk = _chunk_projection( + z[..., chunk_start:chunk_end, :, :], + mask[..., chunk_start:chunk_end, :, :], + a=True, + ) + a_chunk = permute_final_dims(a_chunk, (2, 0, 1)) + else: + a_chunk = _chunk_projection( + z[..., :, chunk_start:chunk_end, :], + mask[..., :, chunk_start:chunk_end, :], + a=True, + ) + a_chunk = a_chunk.transpose(-1, -3) + + for j in range(num_chunk): + j_chunk_start = j * block_size + j_chunk_end = min(j_chunk_start + block_size, dim1) + if self.outgoing: + b_chunk = _chunk_projection( + z[..., j_chunk_start:j_chunk_end, :, :], + mask[..., j_chunk_start:j_chunk_end, :, :], + a=False, + ) + b_chunk = b_chunk.transpose(-1, -3) + else: + b_chunk = _chunk_projection( + z[..., :, j_chunk_start:j_chunk_end, :], + mask[..., :, j_chunk_start:j_chunk_end, :], + a=False, + ) + b_chunk = permute_final_dims(b_chunk, (2, 0, 1)) + x_chunk = torch.matmul(a_chunk, b_chunk) + del b_chunk + x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) + x_chunk = self.layer_norm_out(x_chunk) + x_chunk = self.linear_z(x_chunk) + x_chunk *= torch.sigmoid( + self.linear_g(z[..., chunk_start:chunk_end, + j_chunk_start:j_chunk_end, :])) + new_z[..., chunk_start:chunk_end, + j_chunk_start:j_chunk_end, :] = x_chunk + del x_chunk + del a_chunk + return new_z + + def forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + block_size=None, + ) -> torch.Tensor: + + mask = mask.unsqueeze(-1) + if not self._alphafold_original_mode: + # divided by 1/sqrt(dim) for numerical stability + mask = mask * (mask.shape[-2]**-0.5) + + z = self.layer_norm_in(z) + if not self.training and block_size is not None: + return self._chunk_2d(z, mask, 
block_size=block_size) + + g = nn.functional.linear(z, self.linear_g.weight) + if self.training: + ab = self.linear_ab_p(z) * mask * torch.sigmoid( + self.linear_ab_g(z)) + else: + ab = self.linear_ab_p(z) + ab *= mask + ab *= torch.sigmoid(self.linear_ab_g(z)) + a, b = torch.chunk(ab, 2, dim=-1) + del z, ab + + if self.outgoing: + a = permute_final_dims(a, (2, 0, 1)) + b = b.transpose(-1, -3) + else: + b = permute_final_dims(b, (2, 0, 1)) + a = a.transpose(-1, -3) + x = torch.matmul(a, b) + del a, b + + x = permute_final_dims(x, (1, 2, 0)) + + x = self.layer_norm_out(x) + x = nn.functional.linear(x, self.linear_z.weight) + return x, g + + def get_output_bias(self): + return self.linear_z.bias, self.linear_g.bias + + +class TriangleMultiplicationOutgoing(TriangleMultiplication): + __init__ = partialmethod(TriangleMultiplication.__init__, outgoing=True) + + +class TriangleMultiplicationIncoming(TriangleMultiplication): + __init__ = partialmethod(TriangleMultiplication.__init__, outgoing=False) diff --git a/modelscope/models/science/unifold/msa/__init__.py b/modelscope/models/science/unifold/msa/__init__.py new file mode 100644 index 00000000..2121062c --- /dev/null +++ b/modelscope/models/science/unifold/msa/__init__.py @@ -0,0 +1 @@ +""" Scripts for MSA & template searching. """ diff --git a/modelscope/models/science/unifold/msa/mmcif.py b/modelscope/models/science/unifold/msa/mmcif.py new file mode 100644 index 00000000..cf67239f --- /dev/null +++ b/modelscope/models/science/unifold/msa/mmcif.py @@ -0,0 +1,483 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parses the mmCIF file format.""" +import collections +import dataclasses +import functools +import io +from typing import Any, Mapping, Optional, Sequence, Tuple + +from absl import logging +from Bio import PDB +from Bio.Data import SCOPData +from Bio.PDB.MMCIFParser import MMCIFParser + +# Type aliases: +ChainId = str +PdbHeader = Mapping[str, Any] +PdbStructure = PDB.Structure.Structure +SeqRes = str +MmCIFDict = Mapping[str, Sequence[str]] + + +@dataclasses.dataclass(frozen=True) +class Monomer: + id: str + num: int + + +# Note - mmCIF format provides no guarantees on the type of author-assigned +# sequence numbers. They need not be integers. +@dataclasses.dataclass(frozen=True) +class AtomSite: + residue_name: str + author_chain_id: str + mmcif_chain_id: str + author_seq_num: str + mmcif_seq_num: int + insertion_code: str + hetatm_atom: str + model_num: int + + +# Used to map SEQRES index to a residue in the structure. +@dataclasses.dataclass(frozen=True) +class ResiduePosition: + chain_id: str + residue_number: int + insertion_code: str + + +@dataclasses.dataclass(frozen=True) +class ResidueAtPosition: + position: Optional[ResiduePosition] + name: str + is_missing: bool + hetflag: str + + +@dataclasses.dataclass(frozen=True) +class MmcifObject: + """Representation of a parsed mmCIF file. + + Contains: + file_id: A meaningful name, e.g. a pdb_id. 
Should be unique amongst all + files being processed. + header: Biopython header. + structure: Biopython structure. + chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g. + {'A': 'ABCDEFG'} + seqres_to_structure: Dict; for each chain_id contains a mapping between + SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, 1: ResidueAtPosition, ...}} + raw_string: The raw string used to construct the MmcifObject. + """ + + file_id: str + header: PdbHeader + structure: PdbStructure + chain_to_seqres: Mapping[ChainId, SeqRes] + seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] + raw_string: Any + mmcif_to_author_chain_id: Mapping[ChainId, ChainId] + valid_chains: Mapping[ChainId, str] + + +@dataclasses.dataclass(frozen=True) +class ParsingResult: + """Returned by the parse function. + + Contains: + mmcif_object: A MmcifObject, may be None if no chain could be successfully + parsed. + errors: A dict mapping (file_id, chain_id) to any exception generated. + """ + + mmcif_object: Optional[MmcifObject] + errors: Mapping[Tuple[str, str], Any] + + +class ParseError(Exception): + """An error indicating that an mmCIF file could not be parsed.""" + + +def mmcif_loop_to_list(prefix: str, + parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a list. + + Reference for loop_ in mmCIF: + http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. + """ + cols = [] + data = [] + for key, value in parsed_info.items(): + if key.startswith(prefix): + cols.append(key) + data.append(value) + + assert all([ + len(xs) == len(data[0]) for xs in data + ]), ('mmCIF error: Not all loops are the same length: %s' % cols) + + return [dict(zip(cols, xs)) for xs in zip(*data)] + + +def mmcif_loop_to_dict( + prefix: str, + index: str, + parsed_info: MmCIFDict, +) -> Mapping[str, Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a dictionary. + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + index: Which item of loop data should serve as the key. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, + indexed by the index column. + """ + entries = mmcif_loop_to_list(prefix, parsed_info) + return {entry[index]: entry for entry in entries} + + +@functools.lru_cache(16, typed=False) +def fast_parse(*, + file_id: str, + mmcif_string: str, + catch_all_errors: bool = True) -> ParsingResult: + """Entry point, parses an mmcif_string. + + Args: + file_id: A string identifier for this file. Should be unique within the + collection of files being processed. + mmcif_string: Contents of an mmCIF file. + catch_all_errors: If True, all exceptions are caught and error messages are + returned as part of the ParsingResult. If False exceptions will be allowed + to propagate. + + Returns: + A ParsingResult. 
+ """ + errors = {} + try: + parser = MMCIFParser(QUIET=True) + # handle = io.StringIO(mmcif_string) + # full_structure = parser.get_structure('', handle) + parsed_info = parser._mmcif_dict # pylint:disable=protected-access + + # Ensure all values are lists, even if singletons. + for key, value in parsed_info.items(): + if not isinstance(value, list): + parsed_info[key] = [value] + + header = _get_header(parsed_info) + + # Determine the protein chains, and their start numbers according to the + # internal mmCIF numbering scheme (likely but not guaranteed to be 1). + valid_chains = _get_protein_chains(parsed_info=parsed_info) + if not valid_chains: + return ParsingResult( + None, {(file_id, ''): 'No protein chains found in this file.'}) + + mmcif_to_author_chain_id = {} + # seq_to_structure_mappings = {} + for atom in _get_atom_site_list(parsed_info): + if atom.model_num != '1': + # We only process the first model at the moment. + continue + mmcif_to_author_chain_id[ + atom.mmcif_chain_id] = atom.author_chain_id + + mmcif_object = MmcifObject( + file_id=file_id, + header=header, + structure=None, + chain_to_seqres=None, + seqres_to_structure=None, + raw_string=parsed_info, + mmcif_to_author_chain_id=mmcif_to_author_chain_id, + valid_chains=valid_chains, + ) + + return ParsingResult(mmcif_object=mmcif_object, errors=errors) + except Exception as e: # pylint:disable=broad-except + errors[(file_id, '')] = e + if not catch_all_errors: + raise + return ParsingResult(mmcif_object=None, errors=errors) + + +@functools.lru_cache(16, typed=False) +def parse(*, + file_id: str, + mmcif_string: str, + catch_all_errors: bool = True) -> ParsingResult: + """Entry point, parses an mmcif_string. + + Args: + file_id: A string identifier for this file. Should be unique within the + collection of files being processed. + mmcif_string: Contents of an mmCIF file. + catch_all_errors: If True, all exceptions are caught and error messages are + returned as part of the ParsingResult. If False exceptions will be allowed + to propagate. + + Returns: + A ParsingResult. + """ + errors = {} + try: + parser = PDB.MMCIFParser(QUIET=True) + handle = io.StringIO(mmcif_string) + full_structure = parser.get_structure('', handle) + first_model_structure = _get_first_model(full_structure) + # Extract the _mmcif_dict from the parser, which contains useful fields not + # reflected in the Biopython structure. + parsed_info = parser._mmcif_dict # pylint:disable=protected-access + + # Ensure all values are lists, even if singletons. + for key, value in parsed_info.items(): + if not isinstance(value, list): + parsed_info[key] = [value] + + header = _get_header(parsed_info) + + # Determine the protein chains, and their start numbers according to the + # internal mmCIF numbering scheme (likely but not guaranteed to be 1). + valid_chains = _get_protein_chains(parsed_info=parsed_info) + if not valid_chains: + return ParsingResult( + None, {(file_id, ''): 'No protein chains found in this file.'}) + seq_start_num = { + chain_id: min([monomer.num for monomer in seq]) + for chain_id, seq in valid_chains.items() + } + + # Loop over the atoms for which we have coordinates. Populate two mappings: + # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used + # the authors / Biopython). + # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition). 
+ mmcif_to_author_chain_id = {} + seq_to_structure_mappings = {} + for atom in _get_atom_site_list(parsed_info): + if atom.model_num != '1': + # We only process the first model at the moment. + continue + + mmcif_to_author_chain_id[ + atom.mmcif_chain_id] = atom.author_chain_id + + if atom.mmcif_chain_id in valid_chains: + hetflag = ' ' + if atom.hetatm_atom == 'HETATM': + # Water atoms are assigned a special hetflag of W in Biopython. We + # need to do the same, so that this hetflag can be used to fetch + # a residue from the Biopython structure by id. + if atom.residue_name in ('HOH', 'WAT'): + hetflag = 'W' + else: + hetflag = 'H_' + atom.residue_name + insertion_code = atom.insertion_code + if not _is_set(atom.insertion_code): + insertion_code = ' ' + position = ResiduePosition( + chain_id=atom.author_chain_id, + residue_number=int(atom.author_seq_num), + insertion_code=insertion_code, + ) + seq_idx = int( + atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] + current = seq_to_structure_mappings.get( + atom.author_chain_id, {}) + current[seq_idx] = ResidueAtPosition( + position=position, + name=atom.residue_name, + is_missing=False, + hetflag=hetflag, + ) + seq_to_structure_mappings[atom.author_chain_id] = current + + # Add missing residue information to seq_to_structure_mappings. + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + current_mapping = seq_to_structure_mappings[author_chain] + for idx, monomer in enumerate(seq_info): + if idx not in current_mapping: + current_mapping[idx] = ResidueAtPosition( + position=None, + name=monomer.id, + is_missing=True, + hetflag=' ') + + author_chain_to_sequence = {} + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + seq = [] + for monomer in seq_info: + code = SCOPData.protein_letters_3to1.get(monomer.id, 'X') + seq.append(code if len(code) == 1 else 'X') + seq = ''.join(seq) + author_chain_to_sequence[author_chain] = seq + + mmcif_object = MmcifObject( + file_id=file_id, + header=header, + structure=first_model_structure, + chain_to_seqres=author_chain_to_sequence, + seqres_to_structure=seq_to_structure_mappings, + raw_string=parsed_info, + mmcif_to_author_chain_id=mmcif_to_author_chain_id, + valid_chains=valid_chains, + ) + + return ParsingResult(mmcif_object=mmcif_object, errors=errors) + except Exception as e: # pylint:disable=broad-except + errors[(file_id, '')] = e + if not catch_all_errors: + raise + return ParsingResult(mmcif_object=None, errors=errors) + + +def _get_first_model(structure: PdbStructure) -> PdbStructure: + """Returns the first model in a Biopython structure.""" + return next(structure.get_models()) + + +_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21 + + +def get_release_date(parsed_info: MmCIFDict) -> str: + """Returns the oldest revision date.""" + revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date'] + return min(revision_dates) + + +def _get_header(parsed_info: MmCIFDict) -> PdbHeader: + """Returns a basic header containing method, release date and resolution.""" + header = {} + + experiments = mmcif_loop_to_list('_exptl.', parsed_info) + header['structure_method'] = ','.join( + [experiment['_exptl.method'].lower() for experiment in experiments]) + + # Note: The release_date here corresponds to the oldest revision. We prefer to + # use this for dataset filtering over the deposition_date. 
+ if '_pdbx_audit_revision_history.revision_date' in parsed_info: + header['release_date'] = get_release_date(parsed_info) + else: + logging.warning('Could not determine release_date: %s', + parsed_info['_entry.id']) + + header['resolution'] = 0.00 + for res_key in ( + '_refine.ls_d_res_high', + '_em_3d_reconstruction.resolution', + '_reflns.d_resolution_high', + ): + if res_key in parsed_info: + try: + raw_resolution = parsed_info[res_key][0] + header['resolution'] = float(raw_resolution) + except ValueError: + logging.debug('Invalid resolution format: %s', + parsed_info[res_key]) + + return header + + +def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: + """Returns list of atom sites; contains data not present in the structure.""" + return [ + AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension + parsed_info['_atom_site.label_comp_id'], + parsed_info['_atom_site.auth_asym_id'], + parsed_info['_atom_site.label_asym_id'], + parsed_info['_atom_site.auth_seq_id'], + parsed_info['_atom_site.label_seq_id'], + parsed_info['_atom_site.pdbx_PDB_ins_code'], + parsed_info['_atom_site.group_PDB'], + parsed_info['_atom_site.pdbx_PDB_model_num'], + ) + ] + + +def _get_protein_chains( + *, parsed_info: Mapping[str, + Any]) -> Mapping[ChainId, Sequence[Monomer]]: + """Extracts polymer information for protein chains only. + + Args: + parsed_info: _mmcif_dict produced by the Biopython parser. + + Returns: + A dict mapping mmcif chain id to a list of Monomers. + """ + # Get polymer information for each entity in the structure. + entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info) + + polymers = collections.defaultdict(list) + for entity_poly_seq in entity_poly_seqs: + polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append( + Monomer( + id=entity_poly_seq['_entity_poly_seq.mon_id'], + num=int(entity_poly_seq['_entity_poly_seq.num']), + )) + + # Get chemical compositions. Will allow us to identify which of these polymers + # are proteins. + chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', + parsed_info) + + # Get chains information for each entity. Necessary so that we can return a + # dict keyed on chain id rather than entity. + struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info) + + entity_to_mmcif_chains = collections.defaultdict(list) + for struct_asym in struct_asyms: + chain_id = struct_asym['_struct_asym.id'] + entity_id = struct_asym['_struct_asym.entity_id'] + entity_to_mmcif_chains[entity_id].append(chain_id) + + # Identify and return the valid protein chains. + valid_chains = {} + for entity_id, seq_info in polymers.items(): + chain_ids = entity_to_mmcif_chains[entity_id] + + # Reject polymers without any peptide-like components, such as DNA/RNA. 
+ if any([ + 'peptide' in chem_comps[monomer.id]['_chem_comp.type'] + for monomer in seq_info + ]): + for chain_id in chain_ids: + valid_chains[chain_id] = seq_info + return valid_chains + + +def _is_set(data: str) -> bool: + """Returns False if data is a special mmCIF character indicating 'unset'.""" + return data not in ('.', '?') diff --git a/modelscope/models/science/unifold/msa/msa_identifiers.py b/modelscope/models/science/unifold/msa/msa_identifiers.py new file mode 100644 index 00000000..366239db --- /dev/null +++ b/modelscope/models/science/unifold/msa/msa_identifiers.py @@ -0,0 +1,88 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for extracting identifiers from MSA sequence descriptions.""" + +import dataclasses +import re +from typing import Optional + +# Sequences coming from UniProtKB database come in the +# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` +# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). +_UNIPROT_PATTERN = re.compile( + r""" + ^ + # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot + (?:tr|sp) + \| + # A primary accession number of the UniProtKB entry. + (?P<AccessionIdentifier>[A-Za-z0-9]{6,10}) + # Occasionally there is a _0 or _1 isoform suffix, which we ignore. + (?:_\d)? + \| + # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic + # protein ID code. + (?:[A-Za-z0-9]+) + _ + # A mnemonic species identification code. + (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5}) + # Small BFD uses a final value after an underscore, which we ignore. + (?:_\d+)? + $ + """, + re.VERBOSE, +) + + +@dataclasses.dataclass(frozen=True) +class Identifiers: + species_id: str = '' + + +def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: + """Gets accession id and species from an msa sequence identifier. + + The sequence identifier has the format specified by + _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN. + An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` + + Args: + msa_sequence_identifier: a sequence identifier. + + Returns: + An `Identifiers` instance with a species_id. These + can be empty in the case where no identifier was found. + """ + matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) + if matches: + return Identifiers(species_id=matches.group('SpeciesIdentifier')) + return Identifiers() + + +def _extract_sequence_identifier(description: str) -> Optional[str]: + """Extracts sequence identifier from description. 
Returns None if no match.""" + split_description = description.split() + if split_description: + return split_description[0].partition('/')[0] + else: + return None + + +def get_identifiers(description: str) -> Identifiers: + """Computes extra MSA features from the description.""" + sequence_identifier = _extract_sequence_identifier(description) + if sequence_identifier is None: + return Identifiers() + else: + return _parse_sequence_identifier(sequence_identifier) diff --git a/modelscope/models/science/unifold/msa/parsers.py b/modelscope/models/science/unifold/msa/parsers.py new file mode 100644 index 00000000..bf36c816 --- /dev/null +++ b/modelscope/models/science/unifold/msa/parsers.py @@ -0,0 +1,627 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functions for parsing various file formats.""" +import collections +import dataclasses +import itertools +import re +import string +from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple + +DeletionMatrix = Sequence[Sequence[int]] + + +@dataclasses.dataclass(frozen=True) +class Msa: + """Class representing a parsed MSA file.""" + + sequences: Sequence[str] + deletion_matrix: DeletionMatrix + descriptions: Sequence[str] + + def __post_init__(self): + if not (len(self.sequences) == len(self.deletion_matrix) == len( + self.descriptions)): + raise ValueError( + 'All fields for an MSA must have the same length. ' + f'Got {len(self.sequences)} sequences, ' + f'{len(self.deletion_matrix)} rows in the deletion matrix and ' + f'{len(self.descriptions)} descriptions.') + + def __len__(self): + return len(self.sequences) + + def truncate(self, max_seqs: int): + return Msa( + sequences=self.sequences[:max_seqs], + deletion_matrix=self.deletion_matrix[:max_seqs], + descriptions=self.descriptions[:max_seqs], + ) + + +@dataclasses.dataclass(frozen=True) +class TemplateHit: + """Class representing a template hit.""" + + index: int + name: str + aligned_cols: int + sum_probs: Optional[float] + query: str + hit_sequence: str + indices_query: List[int] + indices_hit: List[int] + + +def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. + """ + sequences = [] + descriptions = [] + index = -1 + for line in fasta_string.splitlines(): + line = line.strip() + if line.startswith('>'): + index += 1 + descriptions.append(line[1:]) # Remove the '>' at the beginning. + sequences.append('') + continue + elif not line: + continue # Skip blank lines. + sequences[index] += line + + return sequences, descriptions + + +def parse_stockholm(stockholm_string: str) -> Msa: + """Parses sequences and deletion matrix from stockholm format alignment. 
+ + Args: + stockholm_string: The string contents of a stockholm file. The first + sequence in the file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + * The names of the targets matched, including the jackhmmer subsequence + suffix. + """ + name_to_sequence = collections.OrderedDict() + for line in stockholm_string.splitlines(): + line = line.strip() + if not line or line.startswith(('#', '//')): + continue + name, sequence = line.split() + if name not in name_to_sequence: + name_to_sequence[name] = '' + name_to_sequence[name] += sequence + + msa = [] + deletion_matrix = [] + + query = '' + keep_columns = [] + for seq_index, sequence in enumerate(name_to_sequence.values()): + if seq_index == 0: + # Gather the columns with gaps from the query + query = sequence + keep_columns = [i for i, res in enumerate(query) if res != '-'] + + # Remove the columns with gaps in the query from all sequences. + aligned_sequence = ''.join([sequence[c] for c in keep_columns]) + + msa.append(aligned_sequence) + + # Count the number of deletions w.r.t. query. + deletion_vec = [] + deletion_count = 0 + for seq_res, query_res in zip(sequence, query): + if seq_res != '-' or query_res != '-': + if query_res == '-': + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + return Msa( + sequences=msa, + deletion_matrix=deletion_matrix, + descriptions=list(name_to_sequence.keys()), + ) + + +def parse_a3m(a3m_string: str) -> Msa: + """Parses sequences and deletion matrix from a3m format alignment. + + Args: + a3m_string: The string contents of a a3m file. The first sequence in the + file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + * A list of descriptions, one per sequence, from the a3m file. + """ + sequences, descriptions = parse_fasta(a3m_string) + deletion_matrix = [] + for msa_sequence in sequences: + deletion_vec = [] + deletion_count = 0 + for j in msa_sequence: + if j.islower(): + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + # Make the MSA matrix out of aligned (deletion-free) sequences. 
+ deletion_table = str.maketrans('', '', string.ascii_lowercase) + aligned_sequences = [s.translate(deletion_table) for s in sequences] + return Msa( + sequences=aligned_sequences, + deletion_matrix=deletion_matrix, + descriptions=descriptions, + ) + + +def _convert_sto_seq_to_a3m(query_non_gaps: Sequence[bool], + sto_seq: str) -> Iterable[str]: + for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): + if is_query_res_non_gap: + yield sequence_res + elif sequence_res != '-': + yield sequence_res.lower() + + +def convert_stockholm_to_a3m( + stockholm_format: str, + max_sequences: Optional[int] = None, + remove_first_row_gaps: bool = True, +) -> str: + """Converts MSA in Stockholm format to the A3M format.""" + descriptions = {} + sequences = {} + reached_max_sequences = False + + for line in stockholm_format.splitlines(): + reached_max_sequences = max_sequences and len( + sequences) >= max_sequences + if line.strip() and not line.startswith(('#', '//')): + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. + seqname, aligned_seq = line.split(maxsplit=1) + if seqname not in sequences: + if reached_max_sequences: + continue + sequences[seqname] = '' + sequences[seqname] += aligned_seq + + for line in stockholm_format.splitlines(): + if line[:4] == '#=GS': + # Description row - example format is: + # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... + columns = line.split(maxsplit=3) + seqname, feature = columns[1:3] + value = columns[3] if len(columns) == 4 else '' + if feature != 'DE': + continue + if reached_max_sequences and seqname not in sequences: + continue + descriptions[seqname] = value + if len(descriptions) == len(sequences): + break + + # Convert sto format to a3m line by line + a3m_sequences = {} + if remove_first_row_gaps: + # query_sequence is assumed to be the first sequence + query_sequence = next(iter(sequences.values())) + query_non_gaps = [res != '-' for res in query_sequence] + for seqname, sto_sequence in sequences.items(): + # Dots are optional in a3m format and are commonly removed. + out_sequence = sto_sequence.replace('.', '') + if remove_first_row_gaps: + out_sequence = ''.join( + _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)) + a3m_sequences[seqname] = out_sequence + + fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" + for k in a3m_sequences) + return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. + + +def _keep_line(line: str, seqnames: Set[str]) -> bool: + """Function to decide which lines to keep.""" + if not line.strip(): + return True + if line.strip() == '//': # End tag + return True + if line.startswith('# STOCKHOLM'): # Start tag + return True + if line.startswith('#=GC RF'): # Reference Annotation Line + return True + if line[:4] == '#=GS': # Description lines - keep if sequence in list. + _, seqname, _ = line.split(maxsplit=2) + return seqname in seqnames + elif line.startswith('#'): # Other markup - filter out + return False + else: # Alignment data - keep if sequence in list. + seqname = line.partition(' ')[0] + return seqname in seqnames + + +def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str: + """Truncates a stockholm file to a maximum number of sequences.""" + seqnames = set() + filtered_lines = [] + for line in stockholm_msa.splitlines(): + if line.strip() and not line.startswith(('#', '//')): + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. 
+ seqname = line.partition(' ')[0] + seqnames.add(seqname) + if len(seqnames) >= max_sequences: + break + + for line in stockholm_msa.splitlines(): + if _keep_line(line, seqnames): + filtered_lines.append(line) + + return '\n'.join(filtered_lines) + '\n' + + +def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str: + """Removes empty columns (dashes-only) from a Stockholm MSA.""" + processed_lines = {} + unprocessed_lines = {} + for i, line in enumerate(stockholm_msa.splitlines()): + if line.startswith('#=GC RF'): + reference_annotation_i = i + reference_annotation_line = line + # Reached the end of this chunk of the alignment. Process chunk. + _, _, first_alignment = line.rpartition(' ') + mask = [] + for j in range(len(first_alignment)): + for _, unprocessed_line in unprocessed_lines.items(): + prefix, _, alignment = unprocessed_line.rpartition(' ') + if alignment[j] != '-': + mask.append(True) + break + else: # Every row contained a hyphen - empty column. + mask.append(False) + # Add reference annotation for processing with mask. + unprocessed_lines[ + reference_annotation_i] = reference_annotation_line + + if not any( + mask + ): # All columns were empty. Output empty lines for chunk. + for line_index in unprocessed_lines: + processed_lines[line_index] = '' + else: + for line_index, unprocessed_line in unprocessed_lines.items(): + prefix, _, alignment = unprocessed_line.rpartition(' ') + masked_alignment = ''.join( + itertools.compress(alignment, mask)) + processed_lines[ + line_index] = f'{prefix} {masked_alignment}' + + # Clear raw_alignments. + unprocessed_lines = {} + elif line.strip() and not line.startswith(('#', '//')): + unprocessed_lines[i] = line + else: + processed_lines[i] = line + return '\n'.join((processed_lines[i] for i in range(len(processed_lines)))) + + +def deduplicate_stockholm_msa(stockholm_msa: str) -> str: + """Remove duplicate sequences (ignoring insertions wrt query).""" + sequence_dict = collections.defaultdict(str) + + # First we must extract all sequences from the MSA. + for line in stockholm_msa.splitlines(): + # Only consider the alignments - ignore reference annotation, empty lines, + # descriptions or markup. + if line.strip() and not line.startswith(('#', '//')): + line = line.strip() + seqname, alignment = line.split() + sequence_dict[seqname] += alignment + + seen_sequences = set() + seqnames = set() + # First alignment is the query. + query_align = next(iter(sequence_dict.values())) + mask = [c != '-' for c in query_align] # Mask is False for insertions. + for seqname, alignment in sequence_dict.items(): + # Apply mask to remove all insertions from the string. 
+ masked_alignment = ''.join(itertools.compress(alignment, mask)) + if masked_alignment in seen_sequences: + continue + else: + seen_sequences.add(masked_alignment) + seqnames.add(seqname) + + filtered_lines = [] + for line in stockholm_msa.splitlines(): + if _keep_line(line, seqnames): + filtered_lines.append(line) + + return '\n'.join(filtered_lines) + '\n' + + +def _get_hhr_line_regex_groups(regex_pattern: str, + line: str) -> Sequence[Optional[str]]: + match = re.match(regex_pattern, line) + if match is None: + raise RuntimeError(f'Could not parse query line {line}') + return match.groups() + + +def _update_hhr_residue_indices_list(sequence: str, start_index: int, + indices_list: List[int]): + """Computes the relative indices for each residue with respect to the original sequence.""" + counter = start_index + for symbol in sequence: + if symbol == '-': + indices_list.append(-1) + else: + indices_list.append(counter) + counter += 1 + + +def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: + """Parses the detailed HMM HMM comparison section for a single Hit. + + This works on .hhr files generated from both HHBlits and HHSearch. + + Args: + detailed_lines: A list of lines from a single comparison section between 2 + sequences (which each have their own HMM's) + + Returns: + A dictionary with the information from that detailed comparison section + + Raises: + RuntimeError: If a certain line cannot be processed + """ + # Parse first 2 lines. + number_of_hit = int(detailed_lines[0].split()[-1]) + name_hit = detailed_lines[1][1:] + + # Parse the summary line. + pattern = ( + 'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' + ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' + ']*Template_Neff=(.*)') + match = re.match(pattern, detailed_lines[2]) + if match is None: + raise RuntimeError( + 'Could not parse section: %s. Expected this: \n%s to contain summary.' + % (detailed_lines, detailed_lines[2])) + (_, _, _, aligned_cols, _, _, sum_probs, + _) = [float(x) for x in match.groups()] + + # The next section reads the detailed comparisons. These are in a 'human + # readable' format which has a fixed length. The strategy employed is to + # assume that each block starts with the query sequence line, and to parse + # that with a regexp in order to deduce the fixed length used for that block. + query = '' + hit_sequence = '' + indices_query = [] + indices_hit = [] + length_block = None + + for line in detailed_lines[3:]: + # Parse the query sequence line + if (line.startswith('Q ') and not line.startswith('Q ss_dssp') + and not line.startswith('Q ss_pred') + and not line.startswith('Q Consensus')): + # Thus the first 17 characters must be 'Q ', and we can parse + # everything after that. + # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + + # Get the length of the parsed block using the start and finish indices, + # and ensure it is the same as the actual block length. + start = int(groups[0]) - 1 # Make index zero based. + delta_query = groups[1] + end = int(groups[2]) + num_insertions = len([x for x in delta_query if x == '-']) + length_block = end - start + num_insertions + assert length_block == len(delta_query) + + # Update the query sequence and indices list. + query += delta_query + _update_hhr_residue_indices_list(delta_query, start, indices_query) + + elif line.startswith('T '): + # Parse the hit sequence. 
+ if (not line.startswith('T ss_dssp') + and not line.startswith('T ss_pred') + and not line.startswith('T Consensus')): + # Thus the first 17 characters must be 'T ', and we can + # parse everything after that. + # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + start = int(groups[0]) - 1 # Make index zero based. + delta_hit_sequence = groups[1] + assert length_block == len(delta_hit_sequence) + + # Update the hit sequence and indices list. + hit_sequence += delta_hit_sequence + _update_hhr_residue_indices_list(delta_hit_sequence, start, + indices_hit) + + return TemplateHit( + index=number_of_hit, + name=name_hit, + aligned_cols=int(aligned_cols), + sum_probs=sum_probs, + query=query, + hit_sequence=hit_sequence, + indices_query=indices_query, + indices_hit=indices_hit, + ) + + +def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: + """Parses the content of an entire HHR file.""" + lines = hhr_string.splitlines() + + # Each .hhr file starts with a results table, then has a sequence of hit + # "paragraphs", each paragraph starting with a line 'No '. We + # iterate through each paragraph to parse each hit. + + block_starts = [ + i for i, line in enumerate(lines) if line.startswith('No ') + ] + + hits = [] + if block_starts: + block_starts.append(len(lines)) # Add the end of the final block. + for i in range(len(block_starts) - 1): + hits.append( + _parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) + return hits + + +def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: + """Parse target to e-value mapping parsed from Jackhmmer tblout string.""" + e_values = {'query': 0} + lines = [line for line in tblout.splitlines() if line[0] != '#'] + # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are + # space-delimited. Relevant fields are (1) target name: and + # (5) E-value (full sequence) (numbering from 1). + for line in lines: + fields = line.split() + e_value = fields[4] + target_name = fields[0] + e_values[target_name] = float(e_value) + return e_values + + +def _get_indices(sequence: str, start: int) -> List[int]: + """Returns indices for non-gap/insert residues starting at the given index.""" + indices = [] + counter = start + for symbol in sequence: + # Skip gaps but add a placeholder so that the alignment is preserved. + if symbol == '-': + indices.append(-1) + # Skip deleted residues, but increase the counter. + elif symbol.islower(): + counter += 1 + # Normal aligned residue. Increase the counter and append to indices. 
+ else: + indices.append(counter) + counter += 1 + return indices + + +@dataclasses.dataclass(frozen=True) +class HitMetadata: + pdb_id: str + chain: str + start: int + end: int + length: int + text: str + + +def _parse_hmmsearch_description(description: str) -> HitMetadata: + """Parses the hmmsearch A3M sequence description line.""" + # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text + # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352 + match = re.match( + r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$', + description.strip(), + ) + + if not match: + raise ValueError(f'Could not parse description: "{description}".') + + return HitMetadata( + pdb_id=match[1], + chain=match[2], + start=int(match[3]), + end=int(match[4]), + length=int(match[5]), + text=match[6], + ) + + +def parse_hmmsearch_a3m(query_sequence: str, + a3m_string: str, + skip_first: bool = True) -> Sequence[TemplateHit]: + """Parses an a3m string produced by hmmsearch. + + Args: + query_sequence: The query sequence. + a3m_string: The a3m string produced by hmmsearch. + skip_first: Whether to skip the first sequence in the a3m string. + + Returns: + A sequence of `TemplateHit` results. + """ + # Zip the descriptions and MSAs together, skip the first query sequence. + parsed_a3m = list(zip(*parse_fasta(a3m_string))) + if skip_first: + parsed_a3m = parsed_a3m[1:] + + indices_query = _get_indices(query_sequence, start=0) + + hits = [] + for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1): + if 'mol:protein' not in hit_description: + continue # Skip non-protein chains. + metadata = _parse_hmmsearch_description(hit_description) + # Aligned columns are only the match states. + aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence]) + indices_hit = _get_indices(hit_sequence, start=metadata.start - 1) + + hit = TemplateHit( + index=i, + name=f'{metadata.pdb_id}_{metadata.chain}', + aligned_cols=aligned_cols, + sum_probs=None, + query=query_sequence, + hit_sequence=hit_sequence.upper(), + indices_query=indices_query, + indices_hit=indices_hit, + ) + hits.append(hit) + + return hits diff --git a/modelscope/models/science/unifold/msa/pipeline.py b/modelscope/models/science/unifold/msa/pipeline.py new file mode 100644 index 00000000..b7889bff --- /dev/null +++ b/modelscope/models/science/unifold/msa/pipeline.py @@ -0,0 +1,282 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
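As a quick illustration of the deletion-matrix convention documented in parsers.py above, the short sketch below (not part of the patch) runs parse_a3m on a hypothetical two-sequence a3m; the toy alignment and the commented outputs are assumptions derived from the parsing logic, not recorded results.

# Illustrative sketch only, assuming the msa package added in this PR is importable.
from modelscope.models.science.unifold.msa.parsers import parse_a3m

toy_a3m = '>query\nMKTAY\n>hit1\nMK-aAY\n'  # hypothetical alignment
msa = parse_a3m(toy_a3m)
# The lowercase insertion 'a' is stripped from the aligned sequence ...
print(msa.sequences)           # ['MKTAY', 'MK-AY']
# ... and counted as one deletion just before the 4th aligned column of hit1.
print(msa.deletion_matrix[1])  # [0, 0, 0, 1, 0]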
+"""Functions for building the input features for the unifold model.""" + +import os +from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union + +import numpy as np +from absl import logging + +from modelscope.models.science.unifold.data import residue_constants +from modelscope.models.science.unifold.msa import (msa_identifiers, parsers, + templates) +from modelscope.models.science.unifold.msa.tools import (hhblits, hhsearch, + hmmsearch, jackhmmer) + +FeatureDict = MutableMapping[str, np.ndarray] +TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] + + +def make_sequence_features(sequence: str, description: str, + num_res: int) -> FeatureDict: + """Constructs a feature dict of sequence features.""" + features = {} + features['aatype'] = residue_constants.sequence_to_onehot( + sequence=sequence, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True, + ) + features['between_segment_residues'] = np.zeros((num_res, ), + dtype=np.int32) + features['domain_name'] = np.array([description.encode('utf-8')], + dtype=np.object_) + features['residue_index'] = np.array(range(num_res), dtype=np.int32) + features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) + features['sequence'] = np.array([sequence.encode('utf-8')], + dtype=np.object_) + return features + + +def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: + """Constructs a feature dict of MSA features.""" + if not msas: + raise ValueError('At least one MSA must be provided.') + + int_msa = [] + deletion_matrix = [] + species_ids = [] + seen_sequences = set() + for msa_index, msa in enumerate(msas): + if not msa: + raise ValueError( + f'MSA {msa_index} must contain at least one sequence.') + for sequence_index, sequence in enumerate(msa.sequences): + if sequence in seen_sequences: + continue + seen_sequences.add(sequence) + int_msa.append( + [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) + deletion_matrix.append(msa.deletion_matrix[sequence_index]) + identifiers = msa_identifiers.get_identifiers( + msa.descriptions[sequence_index]) + species_ids.append(identifiers.species_id.encode('utf-8')) + + num_res = len(msas[0].sequences[0]) + num_alignments = len(int_msa) + features = {} + features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) + features['msa'] = np.array(int_msa, dtype=np.int32) + features['num_alignments'] = np.array( + [num_alignments] * num_res, dtype=np.int32) + features['msa_species_identifiers'] = np.array( + species_ids, dtype=np.object_) + return features + + +def run_msa_tool( + msa_runner, + input_fasta_path: str, + msa_out_path: str, + msa_format: str, + use_precomputed_msas: bool, +) -> Mapping[str, Any]: + """Runs an MSA tool, checking if output already exists first.""" + if not use_precomputed_msas or not os.path.exists(msa_out_path): + result = msa_runner.query(input_fasta_path)[0] + with open(msa_out_path, 'w') as f: + f.write(result[msa_format]) + else: + logging.warning('Reading MSA from file %s', msa_out_path) + with open(msa_out_path, 'r') as f: + result = {msa_format: f.read()} + return result + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__( + self, + jackhmmer_binary_path: str, + hhblits_binary_path: str, + uniref90_database_path: str, + mgnify_database_path: str, + bfd_database_path: Optional[str], + uniclust30_database_path: Optional[str], + small_bfd_database_path: Optional[str], + uniprot_database_path: Optional[str], + 
template_searcher: TemplateSearcher, + template_featurizer: templates.TemplateHitFeaturizer, + use_small_bfd: bool, + mgnify_max_hits: int = 501, + uniref_max_hits: int = 10000, + use_precomputed_msas: bool = False, + ): + """Initializes the data pipeline.""" + self._use_small_bfd = use_small_bfd + self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniref90_database_path) + if use_small_bfd: + self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=small_bfd_database_path) + else: + self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( + binary_path=hhblits_binary_path, + databases=[bfd_database_path, uniclust30_database_path], + ) + self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=mgnify_database_path) + self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniprot_database_path) + self.template_searcher = template_searcher + self.template_featurizer = template_featurizer + self.mgnify_max_hits = mgnify_max_hits + self.uniref_max_hits = uniref_max_hits + self.use_precomputed_msas = use_precomputed_msas + + def process(self, input_fasta_path: str, + msa_output_dir: str) -> FeatureDict: + """Runs alignment tools on the input sequence and creates features.""" + with open(input_fasta_path) as f: + input_fasta_str = f.read() + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) + if len(input_seqs) != 1: + raise ValueError( + f'More than one input sequence found in {input_fasta_path}.') + input_sequence = input_seqs[0] + input_description = input_descs[0] + num_res = len(input_sequence) + + uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') + jackhmmer_uniref90_result = run_msa_tool( + self.jackhmmer_uniref90_runner, + input_fasta_path, + uniref90_out_path, + 'sto', + self.use_precomputed_msas, + ) + mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') + jackhmmer_mgnify_result = run_msa_tool( + self.jackhmmer_mgnify_runner, + input_fasta_path, + mgnify_out_path, + 'sto', + self.use_precomputed_msas, + ) + + msa_for_templates = jackhmmer_uniref90_result['sto'] + msa_for_templates = parsers.truncate_stockholm_msa( + msa_for_templates, max_sequences=self.uniref_max_hits) + msa_for_templates = parsers.deduplicate_stockholm_msa( + msa_for_templates) + msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa( + msa_for_templates) + + if self.template_searcher.input_format == 'sto': + pdb_templates_result = self.template_searcher.query( + msa_for_templates) + elif self.template_searcher.input_format == 'a3m': + uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( + msa_for_templates) + pdb_templates_result = self.template_searcher.query( + uniref90_msa_as_a3m) + else: + raise ValueError('Unrecognized template input format: ' + f'{self.template_searcher.input_format}') + + pdb_hits_out_path = os.path.join( + msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}') + with open(pdb_hits_out_path, 'w') as f: + f.write(pdb_templates_result) + + uniref90_msa = parsers.parse_stockholm( + jackhmmer_uniref90_result['sto']) + uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits) + mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) + mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) + + pdb_template_hits = self.template_searcher.get_template_hits( + output_string=pdb_templates_result, 
input_sequence=input_sequence) + + if self._use_small_bfd: + bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') + jackhmmer_small_bfd_result = run_msa_tool( + self.jackhmmer_small_bfd_runner, + input_fasta_path, + bfd_out_path, + 'sto', + self.use_precomputed_msas, + ) + bfd_msa = parsers.parse_stockholm( + jackhmmer_small_bfd_result['sto']) + else: + bfd_out_path = os.path.join(msa_output_dir, + 'bfd_uniclust_hits.a3m') + hhblits_bfd_uniclust_result = run_msa_tool( + self.hhblits_bfd_uniclust_runner, + input_fasta_path, + bfd_out_path, + 'a3m', + self.use_precomputed_msas, + ) + bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) + + templates_result = self.template_featurizer.get_templates( + query_sequence=input_sequence, hits=pdb_template_hits) + + sequence_features = make_sequence_features( + sequence=input_sequence, + description=input_description, + num_res=num_res) + + msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa)) + + logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa)) + logging.info('BFD MSA size: %d sequences.', len(bfd_msa)) + logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa)) + logging.info( + 'Final (deduplicated) MSA size: %d sequences.', + msa_features['num_alignments'][0], + ) + logging.info( + 'Total number of templates (NB: this can include bad ' + 'templates and is later filtered to top 4): %d.', + templates_result.features['template_domain_names'].shape[0], + ) + + return { + **sequence_features, + **msa_features, + **templates_result.features + } + + def process_uniprot(self, input_fasta_path: str, + msa_output_dir: str) -> FeatureDict: + uniprot_path = os.path.join(msa_output_dir, 'uniprot_hits.sto') + uniprot_result = run_msa_tool( + self.jackhmmer_uniprot_runner, + input_fasta_path, + uniprot_path, + 'sto', + self.use_precomputed_msas, + ) + msa = parsers.parse_stockholm(uniprot_result['sto']) + msa = msa.truncate(max_seqs=50000) + all_seq_dict = make_msa_features([msa]) + return all_seq_dict diff --git a/modelscope/models/science/unifold/msa/templates.py b/modelscope/models/science/unifold/msa/templates.py new file mode 100644 index 00000000..fe3bcef9 --- /dev/null +++ b/modelscope/models/science/unifold/msa/templates.py @@ -0,0 +1,1110 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
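To make the shape of the assembled features concrete, the sketch below (not part of the patch) calls the two feature builders from pipeline.py directly on a toy five-residue query, bypassing the MSA and template tools; the toy sequence and the commented shapes are illustrative assumptions.

# Illustrative sketch only, assuming the msa package added in this PR is importable.
from modelscope.models.science.unifold.msa import parsers
from modelscope.models.science.unifold.msa.pipeline import (
    make_msa_features, make_sequence_features)

toy_seq = 'MKTAY'  # hypothetical query
seq_feats = make_sequence_features(
    sequence=toy_seq, description='toy_query', num_res=len(toy_seq))
# 'aatype' is a per-residue one-hot over the 21-letter alphabet (20 AAs + X);
# 'residue_index' simply counts 0..num_res-1.
print(seq_feats['aatype'].shape)   # (5, 21)
print(seq_feats['residue_index'])  # [0 1 2 3 4]

toy_msa = parsers.parse_a3m('>toy_query\nMKTAY\n')
msa_feats = make_msa_features([toy_msa])
# With a single aligned sequence, every residue reports num_alignments == 1.
print(msa_feats['msa'].shape)               # (1, 5)
print(int(msa_feats['num_alignments'][0]))  # 1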
+"""Functions for getting templates and calculating template features.""" +import abc +import dataclasses +import datetime +import functools +import glob +import os +import re +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple + +import numpy as np +from absl import logging + +from modelscope.models.science.unifold.data import residue_constants +from modelscope.models.science.unifold.msa import mmcif, parsers +from modelscope.models.science.unifold.msa.tools import kalign + + +class Error(Exception): + """Base class for exceptions.""" + + +class NoChainsError(Error): + """An error indicating that template mmCIF didn't have any chains.""" + + +class SequenceNotInTemplateError(Error): + """An error indicating that template mmCIF didn't contain the sequence.""" + + +class NoAtomDataInTemplateError(Error): + """An error indicating that template mmCIF didn't contain atom positions.""" + + +class TemplateAtomMaskAllZerosError(Error): + """An error indicating that template mmCIF had all atom positions masked.""" + + +class QueryToTemplateAlignError(Error): + """An error indicating that the query can't be aligned to the template.""" + + +class CaDistanceError(Error): + """An error indicating that a CA atom distance exceeds a threshold.""" + + +class MultipleChainsError(Error): + """An error indicating that multiple chains were found for a given ID.""" + + +# Prefilter exceptions. +class PrefilterError(Exception): + """A base class for template prefilter exceptions.""" + + +class DateError(PrefilterError): + """An error indicating that the hit date was after the max allowed date.""" + + +class AlignRatioError(PrefilterError): + """An error indicating that the hit align ratio to the query was too small.""" + + +class DuplicateError(PrefilterError): + """An error indicating that the hit was an exact subsequence of the query.""" + + +class LengthError(PrefilterError): + """An error indicating that the hit was too short.""" + + +TEMPLATE_FEATURES = { + 'template_aatype': np.float32, + 'template_all_atom_mask': np.float32, + 'template_all_atom_positions': np.float32, + 'template_domain_names': np.object_, + 'template_sequence': np.object_, + 'template_sum_probs': np.float32, +} + + +def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]: + """Returns PDB id and chain id for an HHSearch Hit.""" + # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown. + id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name) + if not id_match: + raise ValueError( + f'hit.name did not start with PDBID_chain: {hit.name}') + pdb_id, chain_id = id_match.group(0).split('_') + return pdb_id.lower(), chain_id + + +def _is_after_cutoff( + pdb_id: str, + release_dates: Mapping[str, datetime.datetime], + release_date_cutoff: Optional[datetime.datetime], +) -> bool: + """Checks if the template date is after the release date cutoff. + + Args: + pdb_id: 4 letter pdb code. + release_dates: Dictionary mapping PDB ids to their structure release dates. + release_date_cutoff: Max release date that is valid for this query. + + Returns: + True if the template release date is after the cutoff, False otherwise. + """ + if release_date_cutoff is None: + raise ValueError('The release_date_cutoff must not be None.') + if pdb_id in release_dates: + return release_dates[pdb_id] > release_date_cutoff + else: + # Since this is just a quick prefilter to reduce the number of mmCIF files + # we need to parse, we don't have to worry about returning True here. 
+ return False + + +def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, Optional[str]]: + """Parses the data file from PDB that lists which pdb_ids are obsolete.""" + with open(obsolete_file_path) as f: + result = {} + for line in f: + line = line.strip() + # Format: Date From To + # 'OBSLTE 06-NOV-19 6G9Y' - Removed, rare + # 'OBSLTE 31-JUL-94 116L 216L' - Replaced, common + # 'OBSLTE 26-SEP-06 2H33 2JM5 2OWI' - Replaced by multiple, rare + if line.startswith('OBSLTE'): + if len(line) > 30: + # Replaced by at least one structure. + from_id = line[20:24].lower() + to_id = line[29:33].lower() + result[from_id] = to_id + elif len(line) == 24: + # Removed. + from_id = line[20:24].lower() + result[from_id] = None + return result + + +def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]: + """Parses release dates file, returns a mapping from PDBs to release dates.""" + if path.endswith('txt'): + release_dates = {} + with open(path, 'r') as f: + for line in f: + pdb_id, date = line.split(':') + date = date.strip() + # Python 3.6 doesn't have datetime.date.fromisoformat() which is about + # 90x faster than strptime. However, splitting the string manually is + # about 10x faster than strptime. + release_dates[pdb_id.strip()] = datetime.datetime( + year=int(date[:4]), + month=int(date[5:7]), + day=int(date[8:10])) + return release_dates + else: + raise ValueError('Invalid format of the release date file %s.' % path) + + +def _assess_hhsearch_hit( + hit: parsers.TemplateHit, + hit_pdb_code: str, + query_sequence: str, + release_dates: Mapping[str, datetime.datetime], + release_date_cutoff: datetime.datetime, + max_subsequence_ratio: float = 0.95, + min_align_ratio: float = 0.1, +) -> bool: + """Determines if template is valid (without parsing the template mmcif file). + + Args: + hit: HhrHit for the template. + hit_pdb_code: The 4 letter pdb code of the template hit. This might be + different from the value in the actual hit since the original pdb might + have become obsolete. + query_sequence: Amino acid sequence of the query. + release_dates: Dictionary mapping pdb codes to their structure release + dates. + release_date_cutoff: Max release date that is valid for this query. + max_subsequence_ratio: Exclude any exact matches with this much overlap. + min_align_ratio: Minimum overlap between the template and query. + + Returns: + True if the hit passed the prefilter. Raises an exception otherwise. + + Raises: + DateError: If the hit date was after the max allowed date. + AlignRatioError: If the hit align ratio to the query was too small. + DuplicateError: If the hit was an exact subsequence of the query. + LengthError: If the hit was too short. + """ + aligned_cols = hit.aligned_cols + align_ratio = aligned_cols / len(query_sequence) + + template_sequence = hit.hit_sequence.replace('-', '') + length_ratio = float(len(template_sequence)) / len(query_sequence) + + # Check whether the template is a large subsequence or duplicate of original + # query. This can happen due to duplicate entries in the PDB database. + duplicate = ( + template_sequence in query_sequence + and length_ratio > max_subsequence_ratio) + + if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff): + raise DateError( + f'Date ({release_dates[hit_pdb_code]}) > max template date ' + f'({release_date_cutoff}).') + + if align_ratio <= min_align_ratio: + raise AlignRatioError( + 'Proportion of residues aligned to query too small. 
' + f'Align ratio: {align_ratio}.') + + if duplicate: + raise DuplicateError( + 'Template is an exact subsequence of query with large ' + f'coverage. Length ratio: {length_ratio}.') + + if len(template_sequence) < 10: + raise LengthError( + f'Template too short. Length: {len(template_sequence)}.') + + return True + + +def _find_template_in_pdb( + template_chain_id: str, template_sequence: str, + mmcif_object: mmcif.MmcifObject) -> Tuple[str, str, int]: + """Tries to find the template chain in the given pdb file. + + This method tries the three following things in order: + 1. Tries if there is an exact match in both the chain ID and the sequence. + If yes, the chain sequence is returned. Otherwise: + 2. Tries if there is an exact match only in the sequence. + If yes, the chain sequence is returned. Otherwise: + 3. Tries if there is a fuzzy match (X = wildcard) in the sequence. + If yes, the chain sequence is returned. + If none of these succeed, a SequenceNotInTemplateError is thrown. + + Args: + template_chain_id: The template chain ID. + template_sequence: The template chain sequence. + mmcif_object: The PDB object to search for the template in. + + Returns: + A tuple with: + * The chain sequence that was found to match the template in the PDB object. + * The ID of the chain that is being returned. + * The offset where the template sequence starts in the chain sequence. + + Raises: + SequenceNotInTemplateError: If no match is found after the steps described + above. + """ + # Try if there is an exact match in both the chain ID and the (sub)sequence. + pdb_id = mmcif_object.file_id + chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id) + if chain_sequence and (template_sequence in chain_sequence): + logging.info('Found an exact template match %s_%s.', pdb_id, + template_chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, template_chain_id, mapping_offset + + # Try if there is an exact match in the (sub)sequence only. + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + if chain_sequence and (template_sequence in chain_sequence): + logging.info('Found a sequence-only match %s_%s.', pdb_id, + chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, chain_id, mapping_offset + + # Return a chain sequence that fuzzy matches (X = wildcard) the template. + # Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit. + regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence] + regex = re.compile(''.join(regex)) + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + match = re.search(regex, chain_sequence) + if match: + logging.info('Found a fuzzy sequence-only match %s_%s.', pdb_id, + chain_id) + mapping_offset = match.start() + return chain_sequence, chain_id, mapping_offset + + # No hits, raise an error. + raise SequenceNotInTemplateError( + 'Could not find the template sequence in %s_%s. Template sequence: %s, ' + 'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence, + mmcif_object.chain_to_seqres)) + + +def _realign_pdb_template_to_query( + old_template_sequence: str, + template_chain_id: str, + mmcif_object: mmcif.MmcifObject, + old_mapping: Mapping[int, int], + kalign_binary_path: str, +) -> Tuple[str, Mapping[int, int]]: + """Aligns template from the mmcif_object to the query. 
+ + In case PDB70 contains a different version of the template sequence, we need + to perform a realignment to the actual sequence that is in the mmCIF file. + This method performs such realignment, but returns the new sequence and + mapping only if the sequence in the mmCIF file is 90% identical to the old + sequence. + + Note that the old_template_sequence comes from the hit, and contains only that + part of the chain that matches with the query while the new_template_sequence + is the full chain. + + Args: + old_template_sequence: The template sequence that was returned by the PDB + template search (typically done using HHSearch). + template_chain_id: The template chain id was returned by the PDB template + search (typically done using HHSearch). This is used to find the right + chain in the mmcif_object chain_to_seqres mapping. + mmcif_object: A mmcif_object which holds the actual template data. + old_mapping: A mapping from the query sequence to the template sequence. + This mapping will be used to compute the new mapping from the query + sequence to the actual mmcif_object template sequence by aligning the + old_template_sequence and the actual template sequence. + kalign_binary_path: The path to a kalign executable. + + Returns: + A tuple (new_template_sequence, new_query_to_template_mapping) where: + * new_template_sequence is the actual template sequence that was found in + the mmcif_object. + * new_query_to_template_mapping is the new mapping from the query to the + actual template found in the mmcif_object. + + Raises: + QueryToTemplateAlignError: + * If there was an error thrown by the alignment tool. + * Or if the actual template sequence differs by more than 10% from the + old_template_sequence. + """ + aligner = kalign.Kalign(binary_path=kalign_binary_path) + new_template_sequence = mmcif_object.chain_to_seqres.get( + template_chain_id, '') + + # Sometimes the template chain id is unknown. But if there is only a single + # sequence within the mmcif_object, it is safe to assume it is that one. + if not new_template_sequence: + if len(mmcif_object.chain_to_seqres) == 1: + logging.info( + 'Could not find %s in %s, but there is only 1 sequence, so ' + 'using that one.', + template_chain_id, + mmcif_object.file_id, + ) + new_template_sequence = list( + mmcif_object.chain_to_seqres.values())[0] + else: + raise QueryToTemplateAlignError( + f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. ' + 'If there are no mmCIF parsing errors, it is possible it was not a ' + 'protein chain.') + + try: + parsed_a3m = parsers.parse_a3m( + aligner.align([old_template_sequence, new_template_sequence])) + old_aligned_template, new_aligned_template = parsed_a3m.sequences + except Exception as e: + raise QueryToTemplateAlignError( + 'Could not align old template %s to template %s (%s_%s). 
Error: %s' + % ( + old_template_sequence, + new_template_sequence, + mmcif_object.file_id, + template_chain_id, + str(e), + )) + + logging.info( + 'Old aligned template: %s\nNew aligned template: %s', + old_aligned_template, + new_aligned_template, + ) + + old_to_new_template_mapping = {} + old_template_index = -1 + new_template_index = -1 + num_same = 0 + for old_template_aa, new_template_aa in zip(old_aligned_template, + new_aligned_template): + if old_template_aa != '-': + old_template_index += 1 + if new_template_aa != '-': + new_template_index += 1 + if old_template_aa != '-' and new_template_aa != '-': + old_to_new_template_mapping[ + old_template_index] = new_template_index + if old_template_aa == new_template_aa: + num_same += 1 + + # Require at least 90 % sequence identity wrt to the shorter of the sequences. + if (float(num_same) + / min(len(old_template_sequence), len(new_template_sequence)) + < # noqa W504 + 0.9): + raise QueryToTemplateAlignError( + 'Insufficient similarity of the sequence in the database: %s to the ' + 'actual sequence in the mmCIF file %s_%s: %s. We require at least ' + '90 %% similarity wrt to the shorter of the sequences. This is not a ' + 'problem unless you think this is a template that should be included.' + % ( + old_template_sequence, + mmcif_object.file_id, + template_chain_id, + new_template_sequence, + )) + + new_query_to_template_mapping = {} + for query_index, old_template_index in old_mapping.items(): + new_query_to_template_mapping[ + query_index] = old_to_new_template_mapping.get( + old_template_index, -1) + + new_template_sequence = new_template_sequence.replace('-', '') + + return new_template_sequence, new_query_to_template_mapping + + +def _check_residue_distances(all_positions: np.ndarray, + all_positions_mask: np.ndarray, + max_ca_ca_distance: float): + """Checks if the distance between unmasked neighbor residues is ok.""" + ca_position = residue_constants.atom_order['CA'] + prev_is_unmasked = False + prev_calpha = None + for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)): + this_is_unmasked = bool(mask[ca_position]) + if this_is_unmasked: + this_calpha = coords[ca_position] + if prev_is_unmasked: + distance = np.linalg.norm(this_calpha - prev_calpha) + if distance > max_ca_ca_distance: + raise CaDistanceError( + 'The distance between residues %d and %d is %f > limit %f.' + % (i, i + 1, distance, max_ca_ca_distance)) + prev_calpha = this_calpha + prev_is_unmasked = this_is_unmasked + + +def _get_atom_positions( + mmcif_object: mmcif.MmcifObject, auth_chain_id: str, + max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]: + """Gets atom positions and mask from a list of Biopython Residues.""" + num_res = len(mmcif_object.chain_to_seqres[auth_chain_id]) + + relevant_chains = [ + c for c in mmcif_object.structure.get_chains() if c.id == auth_chain_id + ] + if len(relevant_chains) != 1: + raise MultipleChainsError( + f'Expected exactly one chain in structure with id {auth_chain_id}.' 
+ ) + chain = relevant_chains[0] + + all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3]) + all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num], + dtype=np.int64) + for res_index in range(num_res): + pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32) + mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32) + res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][ + res_index] + if not res_at_position.is_missing: + res = chain[( + res_at_position.hetflag, + res_at_position.position.residue_number, + res_at_position.position.insertion_code, + )] + for atom in res.get_atoms(): + atom_name = atom.get_name() + x, y, z = atom.get_coord() + if atom_name in residue_constants.atom_order.keys(): + pos[residue_constants.atom_order[atom_name]] = [x, y, z] + mask[residue_constants.atom_order[atom_name]] = 1.0 + elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE': + # Put the coordinates of the selenium atom in the sulphur column. + pos[residue_constants.atom_order['SD']] = [x, y, z] + mask[residue_constants.atom_order['SD']] = 1.0 + + # Fix naming errors in arginine residues where NH2 is incorrectly + # assigned to be closer to CD than NH1. + cd = residue_constants.atom_order['CD'] + nh1 = residue_constants.atom_order['NH1'] + nh2 = residue_constants.atom_order['NH2'] + if (res.get_resname() == 'ARG' + and all(mask[atom_index] for atom_index in (cd, nh1, nh2)) + and (np.linalg.norm(pos[nh1] - pos[cd]) > # noqa W504 + np.linalg.norm(pos[nh2] - pos[cd]))): + pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy() + mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy() + + all_positions[res_index] = pos + all_positions_mask[res_index] = mask + _check_residue_distances(all_positions, all_positions_mask, + max_ca_ca_distance) + return all_positions, all_positions_mask + + +def _extract_template_features( + mmcif_object: mmcif.MmcifObject, + pdb_id: str, + mapping: Mapping[int, int], + template_sequence: str, + query_sequence: str, + template_chain_id: str, + kalign_binary_path: str, +) -> Tuple[Dict[str, Any], Optional[str]]: + """Parses atom positions in the target structure and aligns with the query. + + Atoms for each residue in the template structure are indexed to coincide + with their corresponding residue in the query sequence, according to the + alignment mapping provided. + + Args: + mmcif_object: mmcif_parsing.MmcifObject representing the template. + pdb_id: PDB code for the template. + mapping: Dictionary mapping indices in the query sequence to indices in + the template sequence. + template_sequence: String describing the amino acid sequence for the + template protein. + query_sequence: String describing the amino acid sequence for the query + protein. + template_chain_id: String ID describing which chain in the structure proto + should be used. + kalign_binary_path: The path to a kalign executable used for template + realignment. + + Returns: + A tuple with: + * A dictionary containing the extra features derived from the template + protein structure. + * A warning message if the hit was realigned to the actual mmCIF sequence. + Otherwise None. + + Raises: + NoChainsError: If the mmcif object doesn't contain any chains. + SequenceNotInTemplateError: If the given chain id / sequence can't + be found in the mmcif object. + QueryToTemplateAlignError: If the actual template in the mmCIF file + can't be aligned to the query. 
+ NoAtomDataInTemplateError: If the mmcif object doesn't contain + atom positions. + TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any + unmasked residues. + """ + if mmcif_object is None or not mmcif_object.chain_to_seqres: + raise NoChainsError('No chains in PDB: %s_%s' % + (pdb_id, template_chain_id)) + + warning = None + try: + seqres, chain_id, mapping_offset = _find_template_in_pdb( + template_chain_id=template_chain_id, + template_sequence=template_sequence, + mmcif_object=mmcif_object, + ) + except SequenceNotInTemplateError: + # If PDB70 contains a different version of the template, we use the sequence + # from the mmcif_object. + chain_id = template_chain_id + warning = ( + f'The exact sequence {template_sequence} was not found in ' + f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.' + ) + logging.warning(warning) + # This throws an exception if it fails to realign the hit. + seqres, mapping = _realign_pdb_template_to_query( + old_template_sequence=template_sequence, + template_chain_id=template_chain_id, + mmcif_object=mmcif_object, + old_mapping=mapping, + kalign_binary_path=kalign_binary_path, + ) + logging.info( + 'Sequence in %s_%s: %s successfully realigned to %s', + pdb_id, + chain_id, + template_sequence, + seqres, + ) + # The template sequence changed. + template_sequence = seqres + # No mapping offset, the query is aligned to the actual sequence. + mapping_offset = 0 + + try: + # Essentially set to infinity - we don't want to reject templates unless + # they're really really bad. + all_atom_positions, all_atom_mask = _get_atom_positions( + mmcif_object, chain_id, max_ca_ca_distance=150.0) + except (CaDistanceError, KeyError) as ex: + raise NoAtomDataInTemplateError('Could not get atom data (%s_%s): %s' % + (pdb_id, chain_id, str(ex))) from ex + + all_atom_positions = np.split(all_atom_positions, + all_atom_positions.shape[0]) + all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0]) + + output_templates_sequence = [] + templates_all_atom_positions = [] + templates_all_atom_masks = [] + + for _ in query_sequence: + # Residues in the query_sequence that are not in the template_sequence: + templates_all_atom_positions.append( + np.zeros((residue_constants.atom_type_num, 3))) + templates_all_atom_masks.append( + np.zeros(residue_constants.atom_type_num)) + output_templates_sequence.append('-') + + for k, v in mapping.items(): + template_index = v + mapping_offset + templates_all_atom_positions[k] = all_atom_positions[template_index][0] + templates_all_atom_masks[k] = all_atom_masks[template_index][0] + output_templates_sequence[k] = template_sequence[v] + + # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O). + if np.sum(templates_all_atom_masks) < 5: + raise TemplateAtomMaskAllZerosError( + 'Template all atom mask was all zeros: %s_%s. 
Residue range: %d-%d' + % ( + pdb_id, + chain_id, + min(mapping.values()) + mapping_offset, + max(mapping.values()) + mapping_offset, + )) + + output_templates_sequence = ''.join(output_templates_sequence) + + templates_aatype = residue_constants.sequence_to_onehot( + output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID) + + return ( + { + 'template_all_atom_positions': + np.array(templates_all_atom_positions), + 'template_all_atom_mask': np.array(templates_all_atom_masks), + 'template_sequence': output_templates_sequence.encode(), + 'template_aatype': np.array(templates_aatype), + 'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(), + }, + warning, + ) + + +def _build_query_to_hit_index_mapping( + hit_query_sequence: str, + hit_sequence: str, + indices_hit: Sequence[int], + indices_query: Sequence[int], + original_query_sequence: str, +) -> Mapping[int, int]: + """Gets mapping from indices in original query sequence to indices in the hit. + + hit_query_sequence and hit_sequence are two aligned sequences containing gap + characters. hit_query_sequence contains only the part of the original query + sequence that matched the hit. When interpreting the indices from the .hhr, we + need to correct for this to recover a mapping from original query sequence to + the hit sequence. + + Args: + hit_query_sequence: The portion of the query sequence that is in the .hhr + hit + hit_sequence: The portion of the hit sequence that is in the .hhr + indices_hit: The indices for each aminoacid relative to the hit sequence + indices_query: The indices for each aminoacid relative to the original query + sequence + original_query_sequence: String describing the original query sequence. + + Returns: + Dictionary with indices in the original query sequence as keys and indices + in the hit sequence as values. + """ + # If the hit is empty (no aligned residues), return empty mapping + if not hit_query_sequence: + return {} + + # Remove gaps and find the offset of hit.query relative to original query. + hhsearch_query_sequence = hit_query_sequence.replace('-', '') + hit_sequence = hit_sequence.replace('-', '') + hhsearch_query_offset = original_query_sequence.find( + hhsearch_query_sequence) + + # Index of -1 used for gap characters. Subtract the min index ignoring gaps. + min_idx = min(x for x in indices_hit if x > -1) + fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit] + + min_idx = min(x for x in indices_query if x > -1) + fixed_indices_query = [ + x - min_idx if x > -1 else -1 for x in indices_query + ] + + # Zip the corrected indices, ignore case where both seqs have gap characters. 
+ mapping = {} + for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit): + if q_t != -1 and q_i != -1: + if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len( + original_query_sequence): + continue + mapping[q_i + hhsearch_query_offset] = q_t + + return mapping + + +@dataclasses.dataclass(frozen=True) +class SingleHitResult: + features: Optional[Mapping[str, Any]] + error: Optional[str] + warning: Optional[str] + + +@functools.lru_cache(16, typed=False) +def _read_file(path): + with open(path, 'r') as f: + file_data = f.read() + return file_data + + +def _process_single_hit( + query_sequence: str, + hit: parsers.TemplateHit, + mmcif_dir: str, + max_template_date: datetime.datetime, + release_dates: Mapping[str, datetime.datetime], + obsolete_pdbs: Mapping[str, Optional[str]], + kalign_binary_path: str, + strict_error_check: bool = False, +) -> SingleHitResult: + """Tries to extract template features from a single HHSearch hit.""" + # Fail hard if we can't get the PDB ID and chain name from the hit. + hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit) + + # This hit has been removed (obsoleted) from PDB, skip it. + if hit_pdb_code in obsolete_pdbs and obsolete_pdbs[hit_pdb_code] is None: + return SingleHitResult( + features=None, + error=None, + warning=f'Hit {hit_pdb_code} is obsolete.') + + if hit_pdb_code not in release_dates: + if hit_pdb_code in obsolete_pdbs: + hit_pdb_code = obsolete_pdbs[hit_pdb_code] + + # Pass hit_pdb_code since it might have changed due to the pdb being obsolete. + try: + _assess_hhsearch_hit( + hit=hit, + hit_pdb_code=hit_pdb_code, + query_sequence=query_sequence, + release_dates=release_dates, + release_date_cutoff=max_template_date, + ) + except PrefilterError as e: + msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}' + logging.info(msg) + if strict_error_check and isinstance(e, (DateError, DuplicateError)): + # In strict mode we treat some prefilter cases as errors. + return SingleHitResult(features=None, error=msg, warning=None) + + return SingleHitResult(features=None, error=None, warning=None) + + mapping = _build_query_to_hit_index_mapping(hit.query, hit.hit_sequence, + hit.indices_hit, + hit.indices_query, + query_sequence) + + # The mapping is from the query to the actual hit sequence, so we need to + # remove gaps (which regardless have a missing confidence score). + template_sequence = hit.hit_sequence.replace('-', '') + + cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif') + logging.debug( + 'Reading PDB entry from %s. Query: %s, template: %s', + cif_path, + query_sequence, + template_sequence, + ) + # Fail if we can't find the mmCIF file. + cif_string = _read_file(cif_path) + + parsing_result = mmcif.parse(file_id=hit_pdb_code, mmcif_string=cif_string) + + if parsing_result.mmcif_object is not None: + hit_release_date = datetime.datetime.strptime( + parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d') + if hit_release_date > max_template_date: + error = 'Template %s date (%s) > max template date (%s).' 
% ( + hit_pdb_code, + hit_release_date, + max_template_date, + ) + if strict_error_check: + return SingleHitResult( + features=None, error=error, warning=None) + else: + logging.debug(error) + return SingleHitResult(features=None, error=None, warning=None) + + try: + features, realign_warning = _extract_template_features( + mmcif_object=parsing_result.mmcif_object, + pdb_id=hit_pdb_code, + mapping=mapping, + template_sequence=template_sequence, + query_sequence=query_sequence, + template_chain_id=hit_chain_id, + kalign_binary_path=kalign_binary_path, + ) + if hit.sum_probs is None: + features['template_sum_probs'] = [0] + else: + features['template_sum_probs'] = [hit.sum_probs] + + # It is possible there were some errors when parsing the other chains in the + # mmCIF file, but the template features for the chain we want were still + # computed. In such case the mmCIF parsing errors are not relevant. + return SingleHitResult( + features=features, error=None, warning=realign_warning) + except ( + NoChainsError, + NoAtomDataInTemplateError, + TemplateAtomMaskAllZerosError, + ) as e: + # These 3 errors indicate missing mmCIF experimental data rather than a + # problem with the template search, so turn them into warnings. + warning = ( + '%s_%s (sum_probs: %s, rank: %s): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' % ( + hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors, + )) + if strict_error_check: + return SingleHitResult(features=None, error=warning, warning=None) + else: + return SingleHitResult(features=None, error=None, warning=warning) + except Error as e: + error = ( + '%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' % ( + hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors, + )) + return SingleHitResult(features=None, error=error, warning=None) + + +@dataclasses.dataclass(frozen=True) +class TemplateSearchResult: + features: Mapping[str, Any] + errors: Sequence[str] + warnings: Sequence[str] + + +class TemplateHitFeaturizer(abc.ABC): + """An abstract base class for turning template hits to template features.""" + + def __init__( + self, + mmcif_dir: str, + max_template_date: str, + max_hits: int, + kalign_binary_path: str, + release_dates_path: Optional[str], + obsolete_pdbs_path: Optional[str], + strict_error_check: bool = False, + ): + """Initializes the Template Search. + + Args: + mmcif_dir: Path to a directory with mmCIF structures. Once a template ID + is found by HHSearch, this directory is used to retrieve the template + data. + max_template_date: The maximum date permitted for template structures. No + template with date higher than this date will be returned. In ISO8601 + date format, YYYY-MM-DD. + max_hits: The maximum number of templates that will be returned. + kalign_binary_path: The path to a kalign executable used for template + realignment. + release_dates_path: An optional path to a file with a mapping from PDB IDs + to their release dates. Thanks to this we don't have to redundantly + parse mmCIF files to get that information. + obsolete_pdbs_path: An optional path to a file containing a mapping from + obsolete PDB IDs to the PDB IDs of their replacements. + strict_error_check: If True, then the following will be treated as errors: + * If any template date is after the max_template_date. + * If any template has identical PDB ID to the query. + * If any template is a duplicate of the query. 
+ * Any feature computation errors. + """ + self._mmcif_dir = mmcif_dir + if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')): + logging.error('Could not find CIFs in %s', self._mmcif_dir) + raise ValueError(f'Could not find CIFs in {self._mmcif_dir}') + + try: + self._max_template_date = datetime.datetime.strptime( + max_template_date, '%Y-%m-%d') + except ValueError: + raise ValueError( + 'max_template_date must be set and have format YYYY-MM-DD.') + self._max_hits = max_hits + self._kalign_binary_path = kalign_binary_path + self._strict_error_check = strict_error_check + + if release_dates_path: + logging.info('Using precomputed release dates %s.', + release_dates_path) + self._release_dates = _parse_release_dates(release_dates_path) + else: + self._release_dates = {} + + if obsolete_pdbs_path: + logging.info('Using precomputed obsolete pdbs %s.', + obsolete_pdbs_path) + self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path) + else: + self._obsolete_pdbs = {} + + @abc.abstractmethod + def get_templates( + self, query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence.""" + + +class HhsearchHitFeaturizer(TemplateHitFeaturizer): + """A class for turning a3m hits from hhsearch to template features.""" + + def get_templates( + self, query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence (more details above).""" + logging.info('Searching for template for: %s', query_sequence) + + template_features = {} + for template_feature_name in TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + num_hits = 0 + errors = [] + warnings = [] + + for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True): + # We got all the templates we wanted, stop processing hits. + if num_hits >= self._max_hits: + break + + result = _process_single_hit( + query_sequence=query_sequence, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=self._max_template_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path, + ) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.info( + 'Skipped invalid hit %s, error: %s, warning: %s', + hit.name, + result.error, + result.warning, + ) + else: + # Increment the hit counter, since we got features out of this hit. + num_hits += 1 + for k in template_features: + template_features[k].append(result.features[k]) + + for name in template_features: + if num_hits > 0: + template_features[name] = np.stack( + template_features[name], + axis=0).astype(TEMPLATE_FEATURES[name]) + else: + # Make sure the feature has correct dtype even if empty. 
+ template_features[name] = np.array( + [], dtype=TEMPLATE_FEATURES[name]) + + return TemplateSearchResult( + features=template_features, errors=errors, warnings=warnings) + + +class HmmsearchHitFeaturizer(TemplateHitFeaturizer): + """A class for turning a3m hits from hmmsearch to template features.""" + + def get_templates( + self, query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence (more details above).""" + logging.info('Searching for template for: %s', query_sequence) + + template_features = {} + for template_feature_name in TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + already_seen = set() + errors = [] + warnings = [] + + if not hits or hits[0].sum_probs is None: + sorted_hits = hits + else: + sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True) + + for hit in sorted_hits: + # We got all the templates we wanted, stop processing hits. + if len(already_seen) >= self._max_hits: + break + + result = _process_single_hit( + query_sequence=query_sequence, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=self._max_template_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path, + ) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.debug( + 'Skipped invalid hit %s, error: %s, warning: %s', + hit.name, + result.error, + result.warning, + ) + else: + already_seen_key = result.features['template_sequence'] + if already_seen_key in already_seen: + continue + # Increment the hit counter, since we got features out of this hit. + already_seen.add(already_seen_key) + for k in template_features: + template_features[k].append(result.features[k]) + + if already_seen: + for name in template_features: + template_features[name] = np.stack( + template_features[name], + axis=0).astype(TEMPLATE_FEATURES[name]) + else: + num_res = len(query_sequence) + # Construct a default template with all zeros. + template_features = { + 'template_aatype': + np.zeros( + (1, num_res, len( + residue_constants.restypes_with_x_and_gap)), + np.float32, + ), + 'template_all_atom_mask': + np.zeros((1, num_res, residue_constants.atom_type_num), + np.float32), + 'template_all_atom_positions': + np.zeros((1, num_res, residue_constants.atom_type_num, 3), + np.float32), + 'template_domain_names': + np.array([''.encode()], dtype=np.object), + 'template_sequence': + np.array([''.encode()], dtype=np.object), + 'template_sum_probs': + np.array([0], dtype=np.float32), + } + return TemplateSearchResult( + features=template_features, errors=errors, warnings=warnings) diff --git a/modelscope/models/science/unifold/msa/tools/__init__.py b/modelscope/models/science/unifold/msa/tools/__init__.py new file mode 100644 index 00000000..903d0979 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
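The featurizer classes above are normally driven by a template search (HHSearch or hmmsearch) together with a local mmCIF mirror. A minimal sketch of that flow, assuming the featurizers are importable as modelscope.models.science.unifold.msa.templates and using placeholder binary and database paths that are not part of this patch:

    from modelscope.models.science.unifold.msa import templates  # assumed module path
    from modelscope.models.science.unifold.msa.tools import hhsearch

    query_sequence = 'MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ'  # toy query for illustration

    searcher = hhsearch.HHSearch(
        binary_path='/usr/bin/hhsearch',          # placeholder
        databases=['/data/pdb70/pdb70'])          # placeholder PDB70 prefix

    featurizer = templates.HhsearchHitFeaturizer(
        mmcif_dir='/data/pdb_mmcif/mmcif_files',  # placeholder mmCIF mirror
        max_template_date='2020-05-14',
        max_hits=20,
        kalign_binary_path='/usr/bin/kalign',     # placeholder
        release_dates_path=None,
        obsolete_pdbs_path=None)

    hhr = searcher.query(f'>query\n{query_sequence}\n')    # single-sequence a3m
    hits = searcher.get_template_hits(hhr, query_sequence)
    result = featurizer.get_templates(query_sequence, hits)
    print(result.features['template_domain_names'], result.errors, result.warnings)

Per-hit errors and warnings are collected on the returned TemplateSearchResult rather than raised, so a caller can log them and still use whatever templates survived the prefilter.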
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python wrappers for third party tools.""" diff --git a/modelscope/models/science/unifold/msa/tools/hhblits.py b/modelscope/models/science/unifold/msa/tools/hhblits.py new file mode 100644 index 00000000..ee442e39 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hhblits.py @@ -0,0 +1,170 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Library to run HHblits from Python.""" + +import glob +import os +import subprocess +from typing import Any, List, Mapping, Optional, Sequence + +from absl import logging + +from . import utils + +_HHBLITS_DEFAULT_P = 20 +_HHBLITS_DEFAULT_Z = 500 + + +class HHBlits: + """Python wrapper of the HHblits binary.""" + + def __init__( + self, + *, + binary_path: str, + databases: Sequence[str], + n_cpu: int = 4, + n_iter: int = 3, + e_value: float = 0.001, + maxseq: int = 1_000_000, + realign_max: int = 100_000, + maxfilt: int = 100_000, + min_prefilter_hits: int = 1000, + all_seqs: bool = False, + alt: Optional[int] = None, + p: int = _HHBLITS_DEFAULT_P, + z: int = _HHBLITS_DEFAULT_Z, + ): + """Initializes the Python HHblits wrapper. + + Args: + binary_path: The path to the HHblits executable. + databases: A sequence of HHblits database paths. This should be the + common prefix for the database files (i.e. up to but not including + _hhm.ffindex etc.) + n_cpu: The number of CPUs to give HHblits. + n_iter: The number of HHblits iterations. + e_value: The E-value, see HHblits docs for more details. + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. + maxfilt: Max number of hits allowed to pass the 2nd prefilter. + HHblits default: 20000. + min_prefilter_hits: Min number of hits to pass prefilter. + HHblits default: 100. + all_seqs: Return all sequences in the MSA / Do not filter the result MSA. + HHblits default: False. + alt: Show up to this many alternative alignments. + p: Minimum Prob for a hit to be included in the output hhr file. + HHblits default: 20. + z: Hard cap on number of hits reported in the hhr file. + HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. + + Raises: + RuntimeError: If HHblits binary not found within the path. 
+ """ + self.binary_path = binary_path + self.databases = databases + + for database_path in self.databases: + if not glob.glob(database_path + '_*'): + logging.error('Could not find HHBlits database %s', + database_path) + raise ValueError( + f'Could not find HHBlits database {database_path}') + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.maxseq = maxseq + self.realign_max = realign_max + self.maxfilt = maxfilt + self.min_prefilter_hits = min_prefilter_hits + self.all_seqs = all_seqs + self.alt = alt + self.p = p + self.z = z + + def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]: + """Queries the database using HHblits.""" + with utils.tmpdir_manager() as query_tmp_dir: + a3m_path = os.path.join(query_tmp_dir, 'output.a3m') + + db_cmd = [] + for db_path in self.databases: + db_cmd.append('-d') + db_cmd.append(db_path) + cmd = [ + self.binary_path, + '-i', + input_fasta_path, + '-cpu', + str(self.n_cpu), + '-oa3m', + a3m_path, + '-o', + '/dev/null', + '-n', + str(self.n_iter), + '-e', + str(self.e_value), + '-maxseq', + str(self.maxseq), + '-realign_max', + str(self.realign_max), + '-maxfilt', + str(self.maxfilt), + '-min_prefilter_hits', + str(self.min_prefilter_hits), + ] + if self.all_seqs: + cmd += ['-all'] + if self.alt: + cmd += ['-alt', str(self.alt)] + if self.p != _HHBLITS_DEFAULT_P: + cmd += ['-p', str(self.p)] + if self.z != _HHBLITS_DEFAULT_Z: + cmd += ['-Z', str(self.z)] + cmd += db_cmd + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with utils.timing('HHblits query'): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + # Logs have a 15k character limit, so log HHblits error line by line. + logging.error('HHblits failed. HHblits stderr begin:') + for error_line in stderr.decode('utf-8').splitlines(): + if error_line.strip(): + logging.error(error_line.strip()) + logging.error('HHblits stderr end') + raise RuntimeError( + 'HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) + + with open(a3m_path) as f: + a3m = f.read() + + raw_output = dict( + a3m=a3m, + output=stdout, + stderr=stderr, + n_iter=self.n_iter, + e_value=self.e_value, + ) + return [raw_output] diff --git a/modelscope/models/science/unifold/msa/tools/hhsearch.py b/modelscope/models/science/unifold/msa/tools/hhsearch.py new file mode 100644 index 00000000..ac7f3b55 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hhsearch.py @@ -0,0 +1,111 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Library to run HHsearch from Python.""" + +import glob +import os +import subprocess +from typing import Sequence + +from absl import logging + +from modelscope.models.science.unifold.msa import parsers +from . 
import utils + + +class HHSearch: + """Python wrapper of the HHsearch binary.""" + + def __init__(self, + *, + binary_path: str, + databases: Sequence[str], + maxseq: int = 1_000_000): + """Initializes the Python HHsearch wrapper. + + Args: + binary_path: The path to the HHsearch executable. + databases: A sequence of HHsearch database paths. This should be the + common prefix for the database files (i.e. up to but not including + _hhm.ffindex etc.) + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + + Raises: + RuntimeError: If HHsearch binary not found within the path. + """ + self.binary_path = binary_path + self.databases = databases + self.maxseq = maxseq + + for database_path in self.databases: + if not glob.glob(database_path + '_*'): + logging.error('Could not find HHsearch database %s', + database_path) + raise ValueError( + f'Could not find HHsearch database {database_path}') + + @property + def output_format(self) -> str: + return 'hhr' + + @property + def input_format(self) -> str: + return 'a3m' + + def query(self, a3m: str) -> str: + """Queries the database using HHsearch using a given a3m.""" + with utils.tmpdir_manager() as query_tmp_dir: + input_path = os.path.join(query_tmp_dir, 'query.a3m') + hhr_path = os.path.join(query_tmp_dir, 'output.hhr') + with open(input_path, 'w') as f: + f.write(a3m) + + db_cmd = [] + for db_path in self.databases: + db_cmd.append('-d') + db_cmd.append(db_path) + cmd = [ + self.binary_path, + '-i', + input_path, + '-o', + hhr_path, + '-maxseq', + str(self.maxseq), + ] + db_cmd + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing('HHsearch query'): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + # Stderr is truncated to prevent proto size errors in Beam. + raise RuntimeError( + 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) + + with open(hhr_path) as f: + hhr = f.read() + return hhr + + def get_template_hits( + self, output_string: str, + input_sequence: str) -> Sequence[parsers.TemplateHit]: + """Gets parsed template hits from the raw string output by the tool.""" + del input_sequence # Used by hmmseach but not needed for hhsearch. + return parsers.parse_hhr(output_string) diff --git a/modelscope/models/science/unifold/msa/tools/hmmbuild.py b/modelscope/models/science/unifold/msa/tools/hmmbuild.py new file mode 100644 index 00000000..84f205d6 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hmmbuild.py @@ -0,0 +1,143 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" + +import os +import re +import subprocess + +from absl import logging + +from . 
import utils + + +class Hmmbuild(object): + """Python wrapper of the hmmbuild binary.""" + + def __init__(self, *, binary_path: str, singlemx: bool = False): + """Initializes the Python hmmbuild wrapper. + + Args: + binary_path: The path to the hmmbuild executable. + singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to + just use a common substitution score matrix. + + Raises: + RuntimeError: If hmmbuild binary not found within the path. + """ + self.binary_path = binary_path + self.singlemx = singlemx + + def build_profile_from_sto(self, + sto: str, + model_construction='fast') -> str: + """Builds a HHM for the aligned sequences given as an A3M string. + + Args: + sto: A string with the aligned sequences in the Stockholm format. + model_construction: Whether to use reference annotation in the msa to + determine consensus columns ('hand') or default ('fast'). + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + """ + return self._build_profile(sto, model_construction=model_construction) + + def build_profile_from_a3m(self, a3m: str) -> str: + """Builds a HHM for the aligned sequences given as an A3M string. + + Args: + a3m: A string with the aligned sequences in the A3M format. + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + """ + lines = [] + for line in a3m.splitlines(): + if not line.startswith('>'): + line = re.sub('[a-z]+', '', line) # Remove inserted residues. + lines.append(line + '\n') + msa = ''.join(lines) + return self._build_profile(msa, model_construction='fast') + + def _build_profile(self, + msa: str, + model_construction: str = 'fast') -> str: + """Builds a HMM for the aligned sequences given as an MSA string. + + Args: + msa: A string with the aligned sequences, in A3M or STO format. + model_construction: Whether to use reference annotation in the msa to + determine consensus columns ('hand') or default ('fast'). + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + ValueError: If unspecified arguments are provided. 
+ """ + if model_construction not in {'hand', 'fast'}: + raise ValueError( + f'Invalid model_construction {model_construction} - only' + 'hand and fast supported.') + + with utils.tmpdir_manager() as query_tmp_dir: + input_query = os.path.join(query_tmp_dir, 'query.msa') + output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') + + with open(input_query, 'w') as f: + f.write(msa) + + cmd = [self.binary_path] + # If adding flags, we have to do so before the output and input: + + if model_construction == 'hand': + cmd.append(f'--{model_construction}') + if self.singlemx: + cmd.append('--singlemx') + cmd.extend([ + '--amino', + output_hmm_path, + input_query, + ]) + + logging.info('Launching subprocess %s', cmd) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with utils.timing('hmmbuild query'): + stdout, stderr = process.communicate() + retcode = process.wait() + logging.info( + 'hmmbuild stdout:\n%s\n\nstderr:\n%s\n', + stdout.decode('utf-8'), + stderr.decode('utf-8'), + ) + + if retcode: + raise RuntimeError( + 'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(output_hmm_path, encoding='utf-8') as f: + hmm = f.read() + + return hmm diff --git a/modelscope/models/science/unifold/msa/tools/hmmsearch.py b/modelscope/models/science/unifold/msa/tools/hmmsearch.py new file mode 100644 index 00000000..445970ca --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hmmsearch.py @@ -0,0 +1,146 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Python wrapper for hmmsearch - search profile against a sequence db.""" + +import os +import subprocess +from typing import Optional, Sequence + +from absl import logging + +from modelscope.models.science.unifold.msa import parsers +from . import hmmbuild, utils + + +class Hmmsearch(object): + """Python wrapper of the hmmsearch binary.""" + + def __init__( + self, + *, + binary_path: str, + hmmbuild_binary_path: str, + database_path: str, + flags: Optional[Sequence[str]] = None, + ): + """Initializes the Python hmmsearch wrapper. + + Args: + binary_path: The path to the hmmsearch executable. + hmmbuild_binary_path: The path to the hmmbuild executable. Used to build + an hmm from an input a3m. + database_path: The path to the hmmsearch database (FASTA format). + flags: List of flags to be used by hmmsearch. + + Raises: + RuntimeError: If hmmsearch binary not found within the path. + """ + self.binary_path = binary_path + self.hmmbuild_runner = hmmbuild.Hmmbuild( + binary_path=hmmbuild_binary_path) + self.database_path = database_path + if flags is None: + # Default hmmsearch run settings. 
+ flags = [ + '--F1', + '0.1', + '--F2', + '0.1', + '--F3', + '0.1', + '--incE', + '100', + '-E', + '100', + '--domE', + '100', + '--incdomE', + '100', + ] + self.flags = flags + + if not os.path.exists(self.database_path): + logging.error('Could not find hmmsearch database %s', + database_path) + raise ValueError( + f'Could not find hmmsearch database {database_path}') + + @property + def output_format(self) -> str: + return 'sto' + + @property + def input_format(self) -> str: + return 'sto' + + def query(self, msa_sto: str) -> str: + """Queries the database using hmmsearch using a given stockholm msa.""" + hmm = self.hmmbuild_runner.build_profile_from_sto( + msa_sto, model_construction='hand') + return self.query_with_hmm(hmm) + + def query_with_hmm(self, hmm: str) -> str: + """Queries the database using hmmsearch using a given hmm.""" + with utils.tmpdir_manager() as query_tmp_dir: + hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') + out_path = os.path.join(query_tmp_dir, 'output.sto') + with open(hmm_input_path, 'w') as f: + f.write(hmm) + + cmd = [ + self.binary_path, + '--noali', # Don't include the alignment in stdout. + '--cpu', + '8', + ] + # If adding flags, we have to do so before the output and input: + if self.flags: + cmd.extend(self.flags) + cmd.extend([ + '-A', + out_path, + hmm_input_path, + self.database_path, + ]) + + logging.info('Launching sub-process %s', cmd) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing( + f'hmmsearch ({os.path.basename(self.database_path)}) query' + ): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + raise RuntimeError( + 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(out_path) as f: + out_msa = f.read() + + return out_msa + + def get_template_hits( + self, output_string: str, + input_sequence: str) -> Sequence[parsers.TemplateHit]: + """Gets parsed template hits from the raw string output by the tool.""" + a3m_string = parsers.convert_stockholm_to_a3m( + output_string, remove_first_row_gaps=False) + template_hits = parsers.parse_hmmsearch_a3m( + query_sequence=input_sequence, + a3m_string=a3m_string, + skip_first=False) + return template_hits diff --git a/modelscope/models/science/unifold/msa/tools/jackhmmer.py b/modelscope/models/science/unifold/msa/tools/jackhmmer.py new file mode 100644 index 00000000..3e29eec9 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/jackhmmer.py @@ -0,0 +1,224 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Library to run Jackhmmer from Python.""" + +import glob +import os +import subprocess +from concurrent import futures +from typing import Any, Callable, Mapping, Optional, Sequence +from urllib import request + +from absl import logging + +from . 
import utils + + +class Jackhmmer: + """Python wrapper of the Jackhmmer binary.""" + + def __init__( + self, + *, + binary_path: str, + database_path: str, + n_cpu: int = 8, + n_iter: int = 1, + e_value: float = 0.0001, + z_value: Optional[int] = None, + get_tblout: bool = False, + filter_f1: float = 0.0005, + filter_f2: float = 0.00005, + filter_f3: float = 0.0000005, + incdom_e: Optional[float] = None, + dom_e: Optional[float] = None, + num_streamed_chunks: Optional[int] = None, + streaming_callback: Optional[Callable[[int], None]] = None, + ): + """Initializes the Python Jackhmmer wrapper. + + Args: + binary_path: The path to the jackhmmer executable. + database_path: The path to the jackhmmer database (FASTA format). + n_cpu: The number of CPUs to give Jackhmmer. + n_iter: The number of Jackhmmer iterations. + e_value: The E-value, see Jackhmmer docs for more details. + z_value: The Z-value, see Jackhmmer docs for more details. + get_tblout: Whether to save tblout string. + filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. + filter_f2: Viterbi pre-filter, set to >1.0 to turn off. + filter_f3: Forward pre-filter, set to >1.0 to turn off. + incdom_e: Domain e-value criteria for inclusion of domains in MSA/next + round. + dom_e: Domain e-value criteria for inclusion in tblout. + num_streamed_chunks: Number of database chunks to stream over. + streaming_callback: Callback function run after each chunk iteration with + the iteration number as argument. + """ + self.binary_path = binary_path + self.database_path = database_path + self.num_streamed_chunks = num_streamed_chunks + + if not os.path.exists( + self.database_path) and num_streamed_chunks is None: + logging.error('Could not find Jackhmmer database %s', + database_path) + raise ValueError( + f'Could not find Jackhmmer database {database_path}') + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.z_value = z_value + self.filter_f1 = filter_f1 + self.filter_f2 = filter_f2 + self.filter_f3 = filter_f3 + self.incdom_e = incdom_e + self.dom_e = dom_e + self.get_tblout = get_tblout + self.streaming_callback = streaming_callback + + def _query_chunk(self, input_fasta_path: str, + database_path: str) -> Mapping[str, Any]: + """Queries the database chunk using Jackhmmer.""" + with utils.tmpdir_manager() as query_tmp_dir: + sto_path = os.path.join(query_tmp_dir, 'output.sto') + + # The F1/F2/F3 are the expected proportion to pass each of the filtering + # stages (which get progressively more expensive), reducing these + # speeds up the pipeline at the expensive of sensitivity. They are + # currently set very low to make querying Mgnify run in a reasonable + # amount of time. + cmd_flags = [ + # Don't pollute stdout with Jackhmmer output. + '-o', + '/dev/null', + '-A', + sto_path, + '--noali', + '--F1', + str(self.filter_f1), + '--F2', + str(self.filter_f2), + '--F3', + str(self.filter_f3), + '--incE', + str(self.e_value), + # Report only sequences with E-values <= x in per-sequence output. 
+ '-E', + str(self.e_value), + '--cpu', + str(self.n_cpu), + '-N', + str(self.n_iter), + ] + if self.get_tblout: + tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') + cmd_flags.extend(['--tblout', tblout_path]) + + if self.z_value: + cmd_flags.extend(['-Z', str(self.z_value)]) + + if self.dom_e is not None: + cmd_flags.extend(['--domE', str(self.dom_e)]) + + if self.incdom_e is not None: + cmd_flags.extend(['--incdomE', str(self.incdom_e)]) + + cmd = [self.binary_path + ] + cmd_flags + [input_fasta_path, database_path] + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing( + f'Jackhmmer ({os.path.basename(database_path)}) query'): + _, stderr = process.communicate() + retcode = process.wait() + + if retcode: + raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n' + % stderr.decode('utf-8')) + + # Get e-values for each target name + tbl = '' + if self.get_tblout: + with open(tblout_path) as f: + tbl = f.read() + + with open(sto_path) as f: + sto = f.read() + + raw_output = dict( + sto=sto, + tbl=tbl, + stderr=stderr, + n_iter=self.n_iter, + e_value=self.e_value) + + return raw_output + + def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]: + """Queries the database using Jackhmmer.""" + if self.num_streamed_chunks is None: + return [self._query_chunk(input_fasta_path, self.database_path)] + + db_basename = os.path.basename(self.database_path) + + def db_remote_chunk(db_idx): + return f'{self.database_path}.{db_idx}' + + def db_local_chunk(db_idx): + return f'/tmp/ramdisk/{db_basename}.{db_idx}' + + # db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' + # db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}' + + # Remove existing files to prevent OOM + for f in glob.glob(db_local_chunk('[0-9]*')): + try: + os.remove(f) + except OSError: + print(f'OSError while deleting {f}') + + # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk + with futures.ThreadPoolExecutor(max_workers=2) as executor: + chunked_output = [] + for i in range(1, self.num_streamed_chunks + 1): + # Copy the chunk locally + if i == 1: + future = executor.submit(request.urlretrieve, + db_remote_chunk(i), + db_local_chunk(i)) + if i < self.num_streamed_chunks: + next_future = executor.submit( + request.urlretrieve, + db_remote_chunk(i + 1), + db_local_chunk(i + 1), + ) + + # Run Jackhmmer with the chunk + future.result() + chunked_output.append( + self._query_chunk(input_fasta_path, db_local_chunk(i))) + + # Remove the local copy of the chunk + os.remove(db_local_chunk(i)) + # Do not set next_future for the last chunk so that this works even for + # databases with only 1 chunk. + if i < self.num_streamed_chunks: + future = next_future + if self.streaming_callback: + self.streaming_callback(i) + return chunked_output diff --git a/modelscope/models/science/unifold/msa/tools/kalign.py b/modelscope/models/science/unifold/msa/tools/kalign.py new file mode 100644 index 00000000..1ea997fa --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/kalign.py @@ -0,0 +1,110 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
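The Jackhmmer wrapper defined above returns one raw-output mapping per database chunk. A small usage sketch, with placeholder binary, database, and input paths that are not taken from this patch:

    from modelscope.models.science.unifold.msa.tools import jackhmmer

    runner = jackhmmer.Jackhmmer(
        binary_path='/usr/bin/jackhmmer',               # placeholder
        database_path='/data/uniref90/uniref90.fasta',  # placeholder
        n_cpu=8,
        n_iter=1)

    # With num_streamed_chunks unset, query() searches a single chunk and
    # returns a one-element list; each element carries the Stockholm MSA
    # under the 'sto' key.
    results = runner.query('/path/to/query.fasta')      # placeholder input
    sto_msa = results[0]['sto']

When num_streamed_chunks is set, the same call downloads each remote chunk to /tmp/ramdisk while the previous chunk is being searched, and returns one mapping per chunk.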
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Python wrapper for Kalign.""" +import os +import subprocess +from typing import Sequence + +from absl import logging + +from . import utils + + +def _to_a3m(sequences: Sequence[str]) -> str: + """Converts sequences to an a3m file.""" + names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] + a3m = [] + for sequence, name in zip(sequences, names): + a3m.append('>' + name + '\n') + a3m.append(sequence + '\n') + return ''.join(a3m) + + +class Kalign: + """Python wrapper of the Kalign binary.""" + + def __init__(self, *, binary_path: str): + """Initializes the Python Kalign wrapper. + + Args: + binary_path: The path to the Kalign binary. + + Raises: + RuntimeError: If Kalign binary not found within the path. + """ + self.binary_path = binary_path + + def align(self, sequences: Sequence[str]) -> str: + """Aligns the sequences and returns the alignment in A3M string. + + Args: + sequences: A list of query sequence strings. The sequences have to be at + least 6 residues long (Kalign requires this). Note that the order in + which you give the sequences might alter the output slightly as + different alignment tree might get constructed. + + Returns: + A string with the alignment in a3m format. + + Raises: + RuntimeError: If Kalign fails. + ValueError: If any of the sequences is less than 6 residues long. + """ + logging.info('Aligning %d sequences', len(sequences)) + + for s in sequences: + if len(s) < 6: + raise ValueError( + 'Kalign requires all sequences to be at least 6 ' + 'residues long. Got %s (%d residues).' % (s, len(s))) + + with utils.tmpdir_manager() as query_tmp_dir: + input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') + output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') + + with open(input_fasta_path, 'w') as f: + f.write(_to_a3m(sequences)) + + cmd = [ + self.binary_path, + '-i', + input_fasta_path, + '-o', + output_a3m_path, + '-format', + 'fasta', + ] + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with utils.timing('Kalign query'): + stdout, stderr = process.communicate() + retcode = process.wait() + logging.info( + 'Kalign stdout:\n%s\n\nstderr:\n%s\n', + stdout.decode('utf-8'), + stderr.decode('utf-8'), + ) + + if retcode: + raise RuntimeError( + 'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(output_a3m_path) as f: + a3m = f.read() + + return a3m diff --git a/modelscope/models/science/unifold/msa/tools/utils.py b/modelscope/models/science/unifold/msa/tools/utils.py new file mode 100644 index 00000000..1c2af936 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/utils.py @@ -0,0 +1,40 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
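Within this patch, the Kalign wrapper is what _realign_pdb_template_to_query uses to reconcile an outdated hit sequence with the sequence actually present in the mmCIF file. A short sketch of that pattern, with the binary path as a placeholder:

    from modelscope.models.science.unifold.msa import parsers
    from modelscope.models.science.unifold.msa.tools import kalign

    aligner = kalign.Kalign(binary_path='/usr/bin/kalign')  # placeholder path

    # Kalign requires every input sequence to be at least 6 residues long.
    a3m = aligner.align(['MKTAYIAKQR', 'MKTAYIVKQR'])
    old_aligned, new_aligned = parsers.parse_a3m(a3m).sequences

The aligned strings can then be walked column by column, as _realign_pdb_template_to_query does, to build an index mapping between the two versions of the sequence.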
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common utilities for data pipeline tools.""" +import contextlib +import shutil +import tempfile +import time +from typing import Optional + +from absl import logging + + +@contextlib.contextmanager +def tmpdir_manager(base_dir: Optional[str] = None): + """Context manager that deletes a temporary directory on exit.""" + tmpdir = tempfile.mkdtemp(dir=base_dir) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +@contextlib.contextmanager +def timing(msg: str): + logging.info('Started %s', msg) + tic = time.time() + yield + toc = time.time() + logging.info('Finished %s in %.3f seconds', msg, toc - tic) diff --git a/modelscope/models/science/unifold/msa/utils.py b/modelscope/models/science/unifold/msa/utils.py new file mode 100644 index 00000000..50e380d4 --- /dev/null +++ b/modelscope/models/science/unifold/msa/utils.py @@ -0,0 +1,89 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import os +from typing import Mapping, Sequence + +import json +from absl import logging + +from modelscope.models.science.unifold.data import protein + + +def get_chain_id_map( + sequences: Sequence[str], + descriptions: Sequence[str], +): + """ + Makes a mapping from PDB-format chain ID to sequence and description, + and parses the order of multi-chains + """ + unique_seqs = [] + for seq in sequences: + if seq not in unique_seqs: + unique_seqs.append(seq) + + chain_id_map = { + chain_id: { + 'descriptions': [], + 'sequence': seq + } + for chain_id, seq in zip(protein.PDB_CHAIN_IDS, unique_seqs) + } + chain_order = [] + + for seq, des in zip(sequences, descriptions): + chain_id = protein.PDB_CHAIN_IDS[unique_seqs.index(seq)] + chain_id_map[chain_id]['descriptions'].append(des) + chain_order.append(chain_id) + + return chain_id_map, chain_order + + +def divide_multi_chains( + fasta_name: str, + output_dir_base: str, + sequences: Sequence[str], + descriptions: Sequence[str], +): + """ + Divides the multi-chains fasta into several single fasta files and + records multi-chains mapping information. + """ + if len(sequences) != len(descriptions): + raise ValueError('sequences and descriptions must have equal length. ' + f'Got {len(sequences)} != {len(descriptions)}.') + if len(sequences) > protein.PDB_MAX_CHAINS: + raise ValueError( + 'Cannot process more chains than the PDB format supports. 
' + f'Got {len(sequences)} chains.') + + chain_id_map, chain_order = get_chain_id_map(sequences, descriptions) + + output_dir = os.path.join(output_dir_base, fasta_name) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + chain_id_map_path = os.path.join(output_dir, 'chain_id_map.json') + with open(chain_id_map_path, 'w') as f: + json.dump(chain_id_map, f, indent=4, sort_keys=True) + + chain_order_path = os.path.join(output_dir, 'chains.txt') + with open(chain_order_path, 'w') as f: + f.write(' '.join(chain_order)) + + logging.info('Mapping multi-chains fasta with chain order: %s', + ' '.join(chain_order)) + + temp_names = [] + temp_paths = [] + for chain_id in chain_id_map.keys(): + temp_name = fasta_name + '_{}'.format(chain_id) + temp_path = os.path.join(output_dir, temp_name + '.fasta') + des = 'chain_{}'.format(chain_id) + seq = chain_id_map[chain_id]['sequence'] + with open(temp_path, 'w') as f: + f.write('>' + des + '\n' + seq) + temp_names.append(temp_name) + temp_paths.append(temp_path) + return temp_names, temp_paths diff --git a/modelscope/msdatasets/cv/easycv_base.py b/modelscope/msdatasets/cv/easycv_base.py index a45827a3..7b6df6e0 100644 --- a/modelscope/msdatasets/cv/easycv_base.py +++ b/modelscope/msdatasets/cv/easycv_base.py @@ -26,11 +26,16 @@ class EasyCVBaseDataset(object): if self.split_config is not None: self._update_data_source(kwargs['data_source']) + def _update_data_root(self, input_dict, data_root): + for k, v in input_dict.items(): + if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: + input_dict.update( + {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) + elif isinstance(v, dict): + self._update_data_root(v, data_root) + def _update_data_source(self, data_source): data_root = next(iter(self.split_config.values())) data_root = data_root.rstrip(osp.sep) - for k, v in data_source.items(): - if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: - data_source.update( - {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) + self._update_data_root(data_source, data_root) diff --git a/modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py b/modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py new file mode 100644 index 00000000..5c1c72c1 --- /dev/null +++ b/modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .hand_2d_keypoints_dataset import Hand2DKeypointDataset + +else: + _import_structure = { + 'hand_2d_keypoints_dataset': ['Hand2DKeypointDataset'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py b/modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py new file mode 100644 index 00000000..89ee0bb8 --- /dev/null +++ b/modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
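The easycv_base change above swaps the flat loop for a recursive walk, so placeholder paths inside nested data_source dicts are rewritten as well. Below is a self-contained sketch of the same idea; the placeholder value '${data_root}' is an assumption for illustration, the real constant is the DATA_ROOT_PATTERN attribute defined on EasyCVBaseDataset.

DATA_ROOT_PATTERN = '${data_root}'   # assumed placeholder value for this sketch

def update_data_root(cfg: dict, data_root: str) -> None:
    """Recursively substitute the placeholder in every string value."""
    for key, value in cfg.items():
        if isinstance(value, str) and DATA_ROOT_PATTERN in value:
            cfg[key] = value.replace(DATA_ROOT_PATTERN, data_root)
        elif isinstance(value, dict):
            update_data_root(value, data_root)

data_source = {
    'ann_file': '${data_root}/annotations/train.json',
    'pipeline_cfg': {'img_prefix': '${data_root}/images'},   # nested dicts are now handled too
}
update_data_root(data_source, '/cache/train')
print(data_source['pipeline_cfg']['img_prefix'])   # /cache/train/images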
+from easycv.datasets.pose import \ + HandCocoWholeBodyDataset as _HandCocoWholeBodyDataset + +from modelscope.metainfo import Datasets +from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.utils.constant import Tasks + + +@TASK_DATASETS.register_module( + group_key=Tasks.hand_2d_keypoints, + module_name=Datasets.HandCocoWholeBodyDataset) +class HandCocoWholeBodyDataset(EasyCVBaseDataset, _HandCocoWholeBodyDataset): + """EasyCV dataset for human hand 2d keypoints. + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. + """ + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + *args, + **kwargs) -> None: + EasyCVBaseDataset.__init__( + self, + split_config=split_config, + preprocessor=preprocessor, + mode=mode, + args=args, + kwargs=kwargs) + _HandCocoWholeBodyDataset.__init__(self, *args, **kwargs) diff --git a/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py b/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py new file mode 100644 index 00000000..472ed2d8 --- /dev/null +++ b/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .human_wholebody_keypoint_dataset import WholeBodyCocoTopDownDataset + +else: + _import_structure = { + 'human_wholebody_keypoint_dataset': ['WholeBodyCocoTopDownDataset'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py b/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py new file mode 100644 index 00000000..fc9469f2 --- /dev/null +++ b/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.datasets.pose import \ + WholeBodyCocoTopDownDataset as _WholeBodyCocoTopDownDataset + +from modelscope.metainfo import Datasets +from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.utils.constant import Tasks + + +@TASK_DATASETS.register_module( + group_key=Tasks.human_wholebody_keypoint, + module_name=Datasets.HumanWholeBodyKeypointDataset) +class WholeBodyCocoTopDownDataset(EasyCVBaseDataset, + _WholeBodyCocoTopDownDataset): + """EasyCV dataset for human whole body 2d keypoints. + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. 
+ """ + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + *args, + **kwargs) -> None: + EasyCVBaseDataset.__init__( + self, + split_config=split_config, + preprocessor=preprocessor, + mode=mode, + args=args, + kwargs=kwargs) + _WholeBodyCocoTopDownDataset.__init__(self, *args, **kwargs) diff --git a/modelscope/msdatasets/image_denoise_data/data_utils.py b/modelscope/msdatasets/image_denoise_data/data_utils.py deleted file mode 100644 index dd735830..00000000 --- a/modelscope/msdatasets/image_denoise_data/data_utils.py +++ /dev/null @@ -1,152 +0,0 @@ -# ------------------------------------------------------------------------ -# Modified from BasicSR (https://github.com/xinntao/BasicSR) -# Copyright 2018-2020 BasicSR Authors -# ------------------------------------------------------------------------ -import os -from os import path as osp - -import cv2 -import numpy as np -import torch - -from .transforms import mod_crop - - -def img2tensor(imgs, bgr2rgb=True, float32=True): - """Numpy array to tensor. - Args: - imgs (list[ndarray] | ndarray): Input images. - bgr2rgb (bool): Whether to change bgr to rgb. - float32 (bool): Whether to change to float32. - Returns: - list[tensor] | tensor: Tensor images. If returned results only have - one element, just return tensor. - """ - - def _totensor(img, bgr2rgb, float32): - if img.shape[2] == 3 and bgr2rgb: - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = torch.from_numpy(img.transpose(2, 0, 1)) - if float32: - img = img.float() - return img - - if isinstance(imgs, list): - return [_totensor(img, bgr2rgb, float32) for img in imgs] - else: - return _totensor(imgs, bgr2rgb, float32) - - -def scandir(dir_path, keyword=None, recursive=False, full_path=False): - """Scan a directory to find the interested files. - Args: - dir_path (str): Path of the directory. - keyword (str | tuple(str), optional): File keyword that we are - interested in. Default: None. - recursive (bool, optional): If set to True, recursively scan the - directory. Default: False. - full_path (bool, optional): If set to True, include the dir_path. - Default: False. - Returns: - A generator for all the interested files with relative pathes. - """ - - if (keyword is not None) and not isinstance(keyword, (str, tuple)): - raise TypeError('"suffix" must be a string or tuple of strings') - - root = dir_path - - def _scandir(dir_path, keyword, recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - if full_path: - return_path = entry.path - else: - return_path = osp.relpath(entry.path, root) - - if keyword is None: - yield return_path - elif keyword in return_path: - yield return_path - else: - if recursive: - yield from _scandir( - entry.path, keyword=keyword, recursive=recursive) - else: - continue - - return _scandir(dir_path, keyword=keyword, recursive=recursive) - - -def padding(img_lq, img_gt, gt_size): - h, w, _ = img_lq.shape - - h_pad = max(0, gt_size - h) - w_pad = max(0, gt_size - w) - - if h_pad == 0 and w_pad == 0: - return img_lq, img_gt - - img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) - img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) - return img_lq, img_gt - - -def read_img_seq(path, require_mod_crop=False, scale=1): - """Read a sequence of images from a given folder path. - Args: - path (list[str] | str): List of image paths or image folder path. - require_mod_crop (bool): Require mod crop for each image. - Default: False. 
- scale (int): Scale factor for mod_crop. Default: 1. - Returns: - Tensor: size (t, c, h, w), RGB, [0, 1]. - """ - if isinstance(path, list): - img_paths = path - else: - img_paths = sorted(list(scandir(path, full_path=True))) - imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] - if require_mod_crop: - imgs = [mod_crop(img, scale) for img in imgs] - imgs = img2tensor(imgs, bgr2rgb=True, float32=True) - imgs = torch.stack(imgs, dim=0) - return imgs - - -def paired_paths_from_folder(folders, keys, filename_tmpl): - """Generate paired paths from folders. - Args: - folders (list[str]): A list of folder path. The order of list should - be [input_folder, gt_folder]. - keys (list[str]): A list of keys identifying folders. The order should - be in consistent with folders, e.g., ['lq', 'gt']. - filename_tmpl (str): Template for each filename. Note that the - template excludes the file extension. Usually the filename_tmpl is - for files in the input folder. - Returns: - list[str]: Returned path list. - """ - assert len(folders) == 2, ( - 'The len of folders should be 2 with [input_folder, gt_folder]. ' - f'But got {len(folders)}') - assert len(keys) == 2, ( - 'The len of keys should be 2 with [input_key, gt_key]. ' - f'But got {len(keys)}') - input_folder, gt_folder = folders - input_key, gt_key = keys - - input_paths = list(scandir(input_folder, keyword='NOISY', recursive=True)) - gt_paths = list(scandir(gt_folder, keyword='GT', recursive=True)) - assert len(input_paths) == len(gt_paths), ( - f'{input_key} and {gt_key} datasets have different number of images: ' - f'{len(input_paths)}, {len(gt_paths)}.') - paths = [] - for idx in range(len(gt_paths)): - gt_path = os.path.join(gt_folder, gt_paths[idx]) - input_path = os.path.join(input_folder, gt_path.replace('GT', 'NOISY')) - - paths.append( - dict([(f'{input_key}_path', input_path), - (f'{gt_key}_path', gt_path)])) - return paths diff --git a/modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py b/modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py deleted file mode 100644 index 96b777e6..00000000 --- a/modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -from typing import Callable, List, Optional, Tuple, Union - -import cv2 -import numpy as np -from torch.utils import data - -from .data_utils import img2tensor, padding, paired_paths_from_folder -from .transforms import augment, paired_random_crop - - -def default_loader(path): - return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 - - -class PairedImageDataset(data.Dataset): - """Paired image dataset for image restoration. - """ - - def __init__(self, opt, root, is_train): - super(PairedImageDataset, self).__init__() - self.opt = opt - self.is_train = is_train - self.gt_folder, self.lq_folder = os.path.join( - root, opt.dataroot_gt), os.path.join(root, opt.dataroot_lq) - - if opt.filename_tmpl is not None: - self.filename_tmpl = opt.filename_tmpl - else: - self.filename_tmpl = '{}' - self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], - ['lq', 'gt'], self.filename_tmpl) - - def __getitem__(self, index): - scale = self.opt.scale - - # Load gt and lq images. Dimension order: HWC; channel order: BGR; - # image range: [0, 1], float32. 
- gt_path = self.paths[index]['gt_path'] - img_gt = default_loader(gt_path) - lq_path = self.paths[index]['lq_path'] - img_lq = default_loader(lq_path) - - # augmentation for training - # if self.is_train: - gt_size = self.opt.gt_size - # padding - img_gt, img_lq = padding(img_gt, img_lq, gt_size) - - # random crop - img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale) - - # flip, rotation - img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip, - self.opt.use_rot) - - # BGR to RGB, HWC to CHW, numpy to tensor - img_gt, img_lq = img2tensor([img_gt, img_lq], - bgr2rgb=True, - float32=True) - - return { - 'input': img_lq, - 'target': img_gt, - 'input_path': lq_path, - 'target_path': gt_path - } - - def __len__(self): - return len(self.paths) - - def to_torch_dataset( - self, - columns: Union[str, List[str]] = None, - preprocessors: Union[Callable, List[Callable]] = None, - **format_kwargs, - ): - return self diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index 361b8ae0..ad900bab 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -1,6 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import math import os from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Union) @@ -17,19 +16,18 @@ from datasets.utils.file_utils import (is_relative_path, relative_to_absolute_path) from modelscope.hub.repository import DatasetRepository +from modelscope.msdatasets.task_datasets.builder import build_task_dataset +from modelscope.msdatasets.utils.dataset_builder import ExternalDataset +from modelscope.msdatasets.utils.dataset_utils import ( + get_dataset_files, get_target_dataset_structure, load_dataset_builder) +from modelscope.msdatasets.utils.download_utils import DatasetDownloadManager +from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager from modelscope.utils.config import ConfigDict from modelscope.utils.config_ds import MS_DATASETS_CACHE from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE, DEFAULT_DATASET_REVISION, DatasetFormations, DownloadMode, Hubs) from modelscope.utils.logger import get_logger -from .task_datasets.builder import build_task_dataset -from .utils.dataset_builder import ExternalDataset -from .utils.dataset_utils import (get_dataset_files, - get_target_dataset_structure, - load_dataset_builder) -from .utils.download_utils import DatasetDownloadManager -from .utils.upload_utils import DatasetUploadManager logger = get_logger() @@ -234,7 +232,6 @@ class MsDataset: # dataset organized to be compatible with hf format if dataset_formation == DatasetFormations.hf_compatible: dataset_name = dataset_scripts['.py'][0] - download_dataset = dataset_name else: raise FileNotFoundError( f"Couldn't find a dataset script at {relative_to_absolute_path(dataset_name)} " @@ -270,7 +267,8 @@ class MsDataset: raise TypeError('path must be a str or a list, but got' f' {type(dataset_name)}') - if download_dataset: + is_ci_test = os.getenv('CI_TEST') == 'True' + if download_dataset and not is_ci_test: try: api.on_dataset_download( dataset_name=download_dataset, namespace=namespace) @@ -570,15 +568,26 @@ class MsDataset: local_file_path: str, dataset_name: str, namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE, - version: Optional[str] = DEFAULT_DATASET_REVISION) -> None: - """Upload dataset file to the ModelScope Hub. Please login to the ModelScope Hub first. 
+ version: Optional[str] = DEFAULT_DATASET_REVISION, + num_processes: Optional[int] = None, + chunksize: Optional[int] = 1, + filter_hidden_files: Optional[bool] = True) -> None: + """Upload dataset file or directory to the ModelScope Hub. Please login to the ModelScope Hub first. Args: - object_name (str): The object name on ModelScope, in the form of your-dataset-name.zip - local_file_path (str): Local file to upload + object_name (str): The object name on ModelScope, in the form of your-dataset-name.zip or your-dataset-name + local_file_path (str): Local file or directory to upload dataset_name (str): Name of the dataset namespace(str, optional): Namespace of the dataset version: Optional[str]: Version of the dataset + num_processes: Optional[int]: The number of processes used for multi-process uploading. + This is only applicable when local_file_path is a directory, and we are uploading mutliple-files + insided the directory. When None provided, the number returned by os.cpu_count() is used as default. + chunksize: Optional[int]: The chunksize of objects to upload. + For very long iterables using a large value for chunksize can make the job complete much faster than + using the default value of 1. Available if local_file_path is a directory. + filter_hidden_files: Optional[bool]: Whether to filter hidden files. + Available if local_file_path is a directory. Returns: None @@ -586,7 +595,20 @@ class MsDataset: """ _upload_manager = DatasetUploadManager( dataset_name=dataset_name, namespace=namespace, version=version) - _upload_manager.upload(object_name, local_file_path) + + if os.path.isfile(local_file_path): + _upload_manager.upload( + object_name=object_name, local_file_path=local_file_path) + elif os.path.isdir(local_file_path): + _upload_manager.upload_dir( + object_dir_name=object_name, + local_dir_path=local_file_path, + num_processes=num_processes, + chunksize=chunksize, + filter_hidden_files=filter_hidden_files) + else: + raise ValueError( + f'{local_file_path} is not a valid file path or directory') @staticmethod def clone_meta(dataset_work_dir: str, @@ -653,4 +675,8 @@ class MsDataset: revision=revision, auth_token=auth_token, git_path=git_path) - _repo.push(commit_message=commit_message, branch=revision, force=force) + _repo.push( + commit_message=commit_message, + local_branch=revision, + remote_branch=revision, + force=force) diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py index e2bf5bc1..92764155 100644 --- a/modelscope/msdatasets/task_datasets/__init__.py +++ b/modelscope/msdatasets/task_datasets/__init__.py @@ -11,19 +11,24 @@ if TYPE_CHECKING: from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset from .movie_scene_segmentation import MovieSceneSegmentationDataset from .video_summarization_dataset import VideoSummarizationDataset - from .passage_ranking_dataset import PassageRankingDataset + from .image_inpainting import ImageInpaintingDataset + from .text_ranking_dataset import TextRankingDataset else: _import_structure = { 'base': ['TaskDataset'], 'builder': ['TASK_DATASETS', 'build_task_dataset'], 'torch_base_dataset': ['TorchTaskDataset'], - 'passage_ranking_dataset': ['PassageRankingDataset'], + 'text_ranking_dataset': ['TextRankingDataset'], 'veco_dataset': ['VecoDataset'], 'image_instance_segmentation_coco_dataset': ['ImageInstanceSegmentationCocoDataset'], 'video_summarization_dataset': ['VideoSummarizationDataset'], 'movie_scene_segmentation': 
['MovieSceneSegmentationDataset'], + 'image_inpainting': ['ImageInpaintingDataset'], + 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], + 'image_portrait_enhancement_dataset': + ['ImagePortraitEnhancementDataset'], } import sys diff --git a/modelscope/msdatasets/task_datasets/audio/__init__.py b/modelscope/msdatasets/task_datasets/audio/__init__.py new file mode 100644 index 00000000..c62a8d9c --- /dev/null +++ b/modelscope/msdatasets/task_datasets/audio/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .kws_farfield_dataset import KWSDataset, KWSDataLoader + +else: + _import_structure = { + 'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'], + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py b/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py new file mode 100644 index 00000000..8c518ec9 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py @@ -0,0 +1,280 @@ +""" +Used to prepare simulated data. +""" +import math +import os.path +import queue +import threading +import time + +import numpy as np +import torch + +from modelscope.utils.logger import get_logger + +logger = get_logger() + +BLOCK_DEC = 2 +BLOCK_CAT = 3 +FBANK_SIZE = 40 +LABEL_SIZE = 1 +LABEL_GAIN = 100.0 + + +class KWSDataset: + """ + dataset for keyword spotting and vad + conf_basetrain: basetrain configure file path + conf_finetune: finetune configure file path, null allowed + numworkers: no. of workers + basetrainratio: basetrain workers ratio + numclasses: no. 
of nn output classes, 2 classes to generate vad label + blockdec: block decimation + blockcat: block concatenation + """ + + def __init__(self, + conf_basetrain, + conf_finetune, + numworkers, + basetrainratio, + numclasses, + blockdec=BLOCK_CAT, + blockcat=BLOCK_CAT): + super().__init__() + self.numclasses = numclasses + self.blockdec = blockdec + self.blockcat = blockcat + self.sims_base = [] + self.sims_senior = [] + self.setup_sims(conf_basetrain, conf_finetune, numworkers, + basetrainratio) + + def release(self): + for sim in self.sims_base: + del sim + for sim in self.sims_senior: + del sim + del self.base_conf + del self.senior_conf + logger.info('KWSDataset: Released.') + + def setup_sims(self, conf_basetrain, conf_finetune, numworkers, + basetrainratio): + if not os.path.exists(conf_basetrain): + raise ValueError(f'{conf_basetrain} does not exist!') + if not os.path.exists(conf_finetune): + raise ValueError(f'{conf_finetune} does not exist!') + import py_sound_connect + logger.info('KWSDataset init SoundConnect...') + num_base = math.ceil(numworkers * basetrainratio) + num_senior = numworkers - num_base + # hold by fields to avoid python releasing conf object + self.base_conf = py_sound_connect.ConfigFile(conf_basetrain) + self.senior_conf = py_sound_connect.ConfigFile(conf_finetune) + for i in range(num_base): + fs = py_sound_connect.FeatSimuKWS(self.base_conf.params) + self.sims_base.append(fs) + for i in range(num_senior): + self.sims_senior.append( + py_sound_connect.FeatSimuKWS(self.senior_conf.params)) + logger.info('KWSDataset init SoundConnect finished.') + + def getBatch(self, id): + """ + Generate a data batch + + Args: + id: worker id + + Return: time x channel x feature, label + """ + fs = self.get_sim(id) + fs.processBatch() + # get multi-channel feature vector size + featsize = fs.featSize() + # get label vector size + labelsize = fs.labelSize() + # get minibatch size (time dimension) + # batchsize = fs.featBatchSize() + # no. 
of fe output channels + numchs = featsize // FBANK_SIZE + # get raw data + fs_feat = fs.feat() + data = np.frombuffer(fs_feat, dtype='float32') + data = data.reshape((-1, featsize + labelsize)) + + # convert float label to int + label = data[:, FBANK_SIZE * numchs:] + + if self.numclasses == 2: + # generate vad label + label[label > 0.0] = 1.0 + else: + # generate kws label + label = np.round(label * LABEL_GAIN) + label[label > self.numclasses - 1] = 0.0 + + # decimated size + size1 = int(np.ceil( + label.shape[0] / self.blockdec)) - self.blockcat + 1 + + # label decimation + label1 = np.zeros((size1, LABEL_SIZE), dtype='float32') + for tau in range(size1): + label1[tau, :] = label[(tau + self.blockcat // 2) + * self.blockdec, :] + + # feature decimation and concatenation + # time x channel x feature + featall = np.zeros((size1, numchs, FBANK_SIZE * self.blockcat), + dtype='float32') + for n in range(numchs): + feat = data[:, FBANK_SIZE * n:FBANK_SIZE * (n + 1)] + + for tau in range(size1): + for i in range(self.blockcat): + featall[tau, n, FBANK_SIZE * i:FBANK_SIZE * (i + 1)] = \ + feat[(tau + i) * self.blockdec, :] + + return torch.from_numpy(featall), torch.from_numpy(label1).long() + + def get_sim(self, id): + num_base = len(self.sims_base) + if id < num_base: + fs = self.sims_base[id] + else: + fs = self.sims_senior[id - num_base] + return fs + + +class Worker(threading.Thread): + """ + id: worker id + dataset: the dataset + pool: queue as the global data buffer + """ + + def __init__(self, id, dataset, pool): + threading.Thread.__init__(self) + + self.id = id + self.dataset = dataset + self.pool = pool + self.isrun = True + self.nn = 0 + + def run(self): + while self.isrun: + self.nn += 1 + logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:1') + # get simulated minibatch + if self.isrun: + data = self.dataset.getBatch(self.id) + logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:2') + + # put data into buffer + if self.isrun: + self.pool.put(data) + logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:3') + + logger.info('KWSDataLoader: Worker {:02d} stopped.'.format(self.id)) + + def stopWorker(self): + """ + stop the worker thread + """ + self.isrun = False + + +class KWSDataLoader: + """ + dataset: the dataset reference + batchsize: data batch size + numworkers: no. 
of workers + prefetch: prefetch factor + """ + + def __init__(self, dataset, batchsize, numworkers, prefetch=2): + self.dataset = dataset + self.batchsize = batchsize + self.datamap = {} + self.isrun = True + + # data queue + self.pool = queue.Queue(batchsize * prefetch) + + # initialize workers + self.workerlist = [] + for id in range(numworkers): + w = Worker(id, dataset, self.pool) + self.workerlist.append(w) + + def __iter__(self): + return self + + def __next__(self): + while self.isrun: + # get data from common data pool + data = self.pool.get() + self.pool.task_done() + + # group minibatches with the same shape + key = str(data[0].shape) + + batchl = self.datamap.get(key) + if batchl is None: + batchl = [] + self.datamap.update({key: batchl}) + + batchl.append(data) + + # a full data batch collected + if len(batchl) >= self.batchsize: + featbatch = [] + labelbatch = [] + + for feat, label in batchl: + featbatch.append(feat) + labelbatch.append(label) + + batchl.clear() + + feattensor = torch.stack(featbatch, dim=0) + labeltensor = torch.stack(labelbatch, dim=0) + + if feattensor.shape[-2] == 1: + logger.debug('KWSDataLoader: Basetrain batch.') + else: + logger.debug('KWSDataLoader: Finetune batch.') + + return feattensor, labeltensor + + return None, None + + def start(self): + """ + start multi-thread data loader + """ + for w in self.workerlist: + w.start() + + def stop(self): + """ + stop data loader + """ + logger.info('KWSDataLoader: Stopping...') + self.isrun = False + + for w in self.workerlist: + w.stopWorker() + + while not self.pool.empty(): + self.pool.get(block=True, timeout=0.001) + + # wait workers terminated + for w in self.workerlist: + while not self.pool.empty(): + self.pool.get(block=True, timeout=0.001) + w.join() + logger.info('KWSDataLoader: All worker stopped.') diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py b/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py new file mode 100644 index 00000000..732a1bd7 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .image_inpainting_dataset import ImageInpaintingDataset diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/aug.py b/modelscope/msdatasets/task_datasets/image_inpainting/aug.py new file mode 100644 index 00000000..445bb9b4 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_inpainting/aug.py @@ -0,0 +1,100 @@ +""" +The implementation is borrowed from LaMa, +publicly available at https://github.com/saic-mdal/lama +""" +import imgaug.augmenters as iaa +from albumentations import DualIAATransform, to_tuple + + +class IAAAffine2(DualIAATransform): + """Place a regular grid of points on the input and randomly move the neighbourhood of these point around + via affine transformations. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + p (float): probability of applying the transform. Default: 0.5. 
+ + Targets: + image, mask + """ + + def __init__( + self, + scale=(0.7, 1.3), + translate_percent=None, + translate_px=None, + rotate=0.0, + shear=(-0.1, 0.1), + order=1, + cval=0, + mode='reflect', + always_apply=False, + p=0.5, + ): + super(IAAAffine2, self).__init__(always_apply, p) + self.scale = dict(x=scale, y=scale) + self.translate_percent = to_tuple(translate_percent, 0) + self.translate_px = to_tuple(translate_px, 0) + self.rotate = to_tuple(rotate) + self.shear = dict(x=shear, y=shear) + self.order = order + self.cval = cval + self.mode = mode + + @property + def processor(self): + return iaa.Affine( + self.scale, + self.translate_percent, + self.translate_px, + self.rotate, + self.shear, + self.order, + self.cval, + self.mode, + ) + + def get_transform_init_args_names(self): + return ('scale', 'translate_percent', 'translate_px', 'rotate', + 'shear', 'order', 'cval', 'mode') + + +class IAAPerspective2(DualIAATransform): + """Perform a random four point perspective transform of the input. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + scale ((float, float): standard deviation of the normal distributions. These are used to sample + the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + """ + + def __init__(self, + scale=(0.05, 0.1), + keep_size=True, + always_apply=False, + p=0.5, + order=1, + cval=0, + mode='replicate'): + super(IAAPerspective2, self).__init__(always_apply, p) + self.scale = to_tuple(scale, 1.0) + self.keep_size = keep_size + self.cval = cval + self.mode = mode + + @property + def processor(self): + return iaa.PerspectiveTransform( + self.scale, + keep_size=self.keep_size, + mode=self.mode, + cval=self.cval) + + def get_transform_init_args_names(self): + return ('scale', 'keep_size') diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py b/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py new file mode 100644 index 00000000..057b8f88 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py @@ -0,0 +1,337 @@ +""" +Part of the implementation is borrowed and modified from LaMa, +publicly available at https://github.com/saic-mdal/lama +""" +import glob +import os +import os.path as osp +from enum import Enum + +import albumentations as A +import cv2 +import json +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from .aug import IAAAffine2, IAAPerspective2 + +LOGGER = get_logger() + + +class LinearRamp: + + def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): + self.start_value = start_value + self.end_value = end_value + self.start_iter = start_iter + self.end_iter = end_iter + + def __call__(self, i): + if i < self.start_iter: + return self.start_value + if i >= self.end_iter: + return self.end_value + part = (i - self.start_iter) / (self.end_iter - self.start_iter) + return self.start_value * (1 - part) + self.end_value * part + + +class DrawMethod(Enum): + LINE = 'line' + CIRCLE = 'circle' + SQUARE = 'square' + + +def make_random_superres_mask(shape, + 
min_step=2, + max_step=4, + min_width=1, + max_width=3): + height, width = shape + mask = np.zeros((height, width), np.float32) + step_x = np.random.randint(min_step, max_step + 1) + width_x = np.random.randint(min_width, min(step_x, max_width + 1)) + offset_x = np.random.randint(0, step_x) + + step_y = np.random.randint(min_step, max_step + 1) + width_y = np.random.randint(min_width, min(step_y, max_width + 1)) + offset_y = np.random.randint(0, step_y) + + for dy in range(width_y): + mask[offset_y + dy::step_y] = 1 + for dx in range(width_x): + mask[:, offset_x + dx::step_x] = 1 + return mask[None, ...] + + +class RandomSuperresMaskGenerator: + + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __call__(self, img, iter_i=None): + return make_random_superres_mask(img.shape[1:], **self.kwargs) + + +def make_random_rectangle_mask(shape, + margin=10, + bbox_min_size=30, + bbox_max_size=100, + min_times=0, + max_times=3): + height, width = shape + mask = np.zeros((height, width), np.float32) + bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + box_width = np.random.randint(bbox_min_size, bbox_max_size) + box_height = np.random.randint(bbox_min_size, bbox_max_size) + start_x = np.random.randint(margin, width - margin - box_width + 1) + start_y = np.random.randint(margin, height - margin - box_height + 1) + mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1 + return mask[None, ...] + + +class RandomRectangleMaskGenerator: + + def __init__(self, + margin=10, + bbox_min_size=30, + bbox_max_size=100, + min_times=0, + max_times=3, + ramp_kwargs=None): + self.margin = margin + self.bbox_min_size = bbox_min_size + self.bbox_max_size = bbox_max_size + self.min_times = min_times + self.max_times = max_times + self.ramp = LinearRamp( + **ramp_kwargs) if ramp_kwargs is not None else None + + def __call__(self, img, iter_i=None, raw_image=None): + coef = self.ramp(iter_i) if (self.ramp is not None) and ( + iter_i is not None) else 1 + cur_bbox_max_size = int(self.bbox_min_size + 1 + + (self.bbox_max_size - self.bbox_min_size) + * coef) + cur_max_times = int(self.min_times + + (self.max_times - self.min_times) * coef) + return make_random_rectangle_mask( + img.shape[1:], + margin=self.margin, + bbox_min_size=self.bbox_min_size, + bbox_max_size=cur_bbox_max_size, + min_times=self.min_times, + max_times=cur_max_times) + + +def make_random_irregular_mask(shape, + max_angle=4, + max_len=60, + max_width=20, + min_times=0, + max_times=10, + draw_method=DrawMethod.LINE): + draw_method = DrawMethod(draw_method) + + height, width = shape + mask = np.zeros((height, width), np.float32) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + start_x = np.random.randint(width) + start_y = np.random.randint(height) + for j in range(1 + np.random.randint(5)): + angle = 0.01 + np.random.randint(max_angle) + if i % 2 == 0: + angle = 2 * 3.1415926 - angle + length = 10 + np.random.randint(max_len) + brush_w = 5 + np.random.randint(max_width) + end_x = np.clip( + (start_x + length * np.sin(angle)).astype(np.int32), 0, width) + end_y = np.clip( + (start_y + length * np.cos(angle)).astype(np.int32), 0, height) + if draw_method == DrawMethod.LINE: + cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, + brush_w) + elif draw_method == DrawMethod.CIRCLE: + cv2.circle( + mask, (start_x, start_y), + radius=brush_w, + color=1., + thickness=-1) + elif draw_method == 
DrawMethod.SQUARE: + radius = brush_w // 2 + mask[start_y - radius:start_y + radius, + start_x - radius:start_x + radius] = 1 + start_x, start_y = end_x, end_y + return mask[None, ...] + + +class RandomIrregularMaskGenerator: + + def __init__(self, + max_angle=4, + max_len=60, + max_width=20, + min_times=0, + max_times=10, + ramp_kwargs=None, + draw_method=DrawMethod.LINE): + self.max_angle = max_angle + self.max_len = max_len + self.max_width = max_width + self.min_times = min_times + self.max_times = max_times + self.draw_method = draw_method + self.ramp = LinearRamp( + **ramp_kwargs) if ramp_kwargs is not None else None + + def __call__(self, img, iter_i=None, raw_image=None): + coef = self.ramp(iter_i) if (self.ramp is not None) and ( + iter_i is not None) else 1 + cur_max_len = int(max(1, self.max_len * coef)) + cur_max_width = int(max(1, self.max_width * coef)) + cur_max_times = int(self.min_times + 1 + + (self.max_times - self.min_times) * coef) + return make_random_irregular_mask( + img.shape[1:], + max_angle=self.max_angle, + max_len=cur_max_len, + max_width=cur_max_width, + min_times=self.min_times, + max_times=cur_max_times, + draw_method=self.draw_method) + + +class MixedMaskGenerator: + + def __init__(self, + irregular_proba=1 / 3, + irregular_kwargs=None, + box_proba=1 / 3, + box_kwargs=None, + segm_proba=1 / 3, + segm_kwargs=None, + squares_proba=0, + squares_kwargs=None, + superres_proba=0, + superres_kwargs=None, + outpainting_proba=0, + outpainting_kwargs=None, + invert_proba=0): + self.probas = [] + self.gens = [] + + if irregular_proba > 0: + self.probas.append(irregular_proba) + if irregular_kwargs is None: + irregular_kwargs = {} + else: + irregular_kwargs = dict(irregular_kwargs) + irregular_kwargs['draw_method'] = DrawMethod.LINE + self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs)) + + if box_proba > 0: + self.probas.append(box_proba) + if box_kwargs is None: + box_kwargs = {} + self.gens.append(RandomRectangleMaskGenerator(**box_kwargs)) + + if squares_proba > 0: + self.probas.append(squares_proba) + if squares_kwargs is None: + squares_kwargs = {} + else: + squares_kwargs = dict(squares_kwargs) + squares_kwargs['draw_method'] = DrawMethod.SQUARE + self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs)) + + if superres_proba > 0: + self.probas.append(superres_proba) + if superres_kwargs is None: + superres_kwargs = {} + self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs)) + + self.probas = np.array(self.probas, dtype='float32') + self.probas /= self.probas.sum() + self.invert_proba = invert_proba + + def __call__(self, img, iter_i=None, raw_image=None): + kind = np.random.choice(len(self.probas), p=self.probas) + gen = self.gens[kind] + result = gen(img, iter_i=iter_i, raw_image=raw_image) + if self.invert_proba > 0 and random.random() < self.invert_proba: + result = 1 - result + return result + + +def get_transforms(test_mode, out_size): + if not test_mode: + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.7, 1.3), rotate=(-40, 40), shear=(-0.1, 0.1)), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast( + brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue( + hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + else: + transform = A.Compose([ + A.PadIfNeeded(min_height=out_size, min_width=out_size), + 
A.CenterCrop(height=out_size, width=out_size), + A.ToFloat() + ]) + return transform + + +@TASK_DATASETS.register_module( + Tasks.image_inpainting, module_name=Models.image_inpainting) +class ImageInpaintingDataset(TorchTaskDataset): + + def __init__(self, **kwargs): + split_config = kwargs['split_config'] + LOGGER.info(kwargs) + mode = kwargs.get('test_mode', False) + + self.data_root = next(iter(split_config.values())) + if not osp.exists(self.data_root): + self.data_root = osp.dirname(self.data_root) + assert osp.exists(self.data_root) + mask_gen_kwargs = kwargs.get('mask_gen_kwargs', {}) + out_size = kwargs.get('out_size', 256) + self.mask_generator = MixedMaskGenerator(**mask_gen_kwargs) + self.transform = get_transforms(mode, out_size) + self.in_files = sorted( + list( + glob.glob( + osp.join(self.data_root, '**', '*.jpg'), recursive=True)) + + list( + glob.glob( + osp.join(self.data_root, '**', '*.png'), recursive=True))) + self.iter_i = 0 + + def __len__(self): + return len(self.in_files) + + def __getitem__(self, index): + path = self.in_files[index] + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = self.transform(image=img)['image'] + img = np.transpose(img, (2, 0, 1)) + # TODO: maybe generate mask before augmentations? slower, but better for segmentation-based masks + mask = self.mask_generator(img, iter_i=self.iter_i) + self.iter_i += 1 + return dict(image=img, mask=mask) diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py new file mode 100644 index 00000000..4df24fae --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_portrait_enhancement_dataset import ImagePortraitEnhancementDataset + +else: + _import_structure = { + 'image_portrait_enhancement_dataset': + ['ImagePortraitEnhancementDataset'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py new file mode 100644 index 00000000..1133d3c2 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py @@ -0,0 +1,32 @@ +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ + +import cv2 +import torch + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. 
+ """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py new file mode 100644 index 00000000..58d40778 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np + +from modelscope.metainfo import Datasets, Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from .data_utils import img2tensor + + +def default_loader(path): + return cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0 + + +@TASK_DATASETS.register_module( + Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset) +class ImagePortraitEnhancementDataset(TorchTaskDataset): + """Paired image dataset for image portrait enhancement. + """ + + def __init__(self, dataset, is_train): + self.dataset = dataset + self.gt_size = 256 + self.is_train = is_train + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + item_dict = self.dataset[index] + gt_path = item_dict['hq:FILE'] + img_gt = default_loader(gt_path) + lq_path = item_dict['lq:FILE'] + img_lq = default_loader(lq_path) + + gt_size = self.gt_size + img_gt = cv2.resize(img_gt, (gt_size, gt_size)) + img_lq = cv2.resize(img_lq, (gt_size, gt_size)) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], + bgr2rgb=True, + float32=True) + + return {'input': (img_lq - 0.5) / 0.5, 'target': (img_gt - 0.5) / 0.5} diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py new file mode 100644 index 00000000..5376cd7c --- /dev/null +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .sidd_image_denoising_dataset import SiddImageDenoisingDataset + +else: + _import_structure = { + 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py new file mode 100644 index 00000000..33fce4c8 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py @@ -0,0 +1,46 @@ +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ + +import cv2 +import torch + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def padding(img_lq, img_gt, gt_size): + h, w, _ = img_lq.shape + + h_pad = max(0, gt_size - h) + w_pad = max(0, gt_size - w) + + if h_pad == 0 and w_pad == 0: + return img_lq, img_gt + + img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + return img_lq, img_gt diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py new file mode 100644 index 00000000..3f0cdae0 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np + +from modelscope.metainfo import Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from .data_utils import img2tensor, padding +from .transforms import augment, paired_random_crop + + +def default_loader(path): + return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 + + +@TASK_DATASETS.register_module( + Tasks.image_denoising, module_name=Models.nafnet) +class SiddImageDenoisingDataset(TorchTaskDataset): + """Paired image dataset for image restoration. + """ + + def __init__(self, dataset, opt, is_train): + self.dataset = dataset + self.opt = opt + self.is_train = is_train + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. 
+ item_dict = self.dataset[index] + gt_path = item_dict['Clean Image:FILE'] + img_gt = default_loader(gt_path) + lq_path = item_dict['Noisy Image:FILE'] + img_lq = default_loader(lq_path) + + # augmentation for training + if self.is_train: + gt_size = self.opt.gt_size + # padding + img_gt, img_lq = padding(img_gt, img_lq, gt_size) + + # random crop + img_gt, img_lq = paired_random_crop( + img_gt, img_lq, gt_size, scale=1) + + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip, + self.opt.use_rot) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], + bgr2rgb=True, + float32=True) + + return {'input': img_lq, 'target': img_gt} diff --git a/modelscope/msdatasets/image_denoise_data/transforms.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py similarity index 100% rename from modelscope/msdatasets/image_denoise_data/transforms.py rename to modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py diff --git a/modelscope/msdatasets/task_datasets/passage_ranking_dataset.py b/modelscope/msdatasets/task_datasets/text_ranking_dataset.py similarity index 88% rename from modelscope/msdatasets/task_datasets/passage_ranking_dataset.py rename to modelscope/msdatasets/task_datasets/text_ranking_dataset.py index 517e0d36..54276843 100644 --- a/modelscope/msdatasets/task_datasets/passage_ranking_dataset.py +++ b/modelscope/msdatasets/task_datasets/text_ranking_dataset.py @@ -16,8 +16,8 @@ from .torch_base_dataset import TorchTaskDataset @TASK_DATASETS.register_module( - group_key=Tasks.passage_ranking, module_name=Models.bert) -class PassageRankingDataset(TorchTaskDataset): + group_key=Tasks.text_ranking, module_name=Models.bert) +class TextRankingDataset(TorchTaskDataset): def __init__(self, datasets: Union[Any, List[Any]], @@ -35,12 +35,11 @@ class PassageRankingDataset(TorchTaskDataset): 'positive_passages') self.neg_sequence = self.dataset_config.get('neg_sequence', 'negative_passages') - self.passage_text_fileds = self.dataset_config.get( - 'passage_text_fileds', ['title', 'text']) + self.text_fileds = self.dataset_config.get('text_fileds', + ['title', 'text']) self.qid_field = self.dataset_config.get('qid_field', 'query_id') if mode == ModeKeys.TRAIN: - train_config = kwargs.get('train', {}) - self.neg_samples = train_config.get('neg_samples', 4) + self.neg_samples = self.dataset_config.get('neg_sample', 4) super().__init__(datasets, mode, preprocessor, **kwargs) @@ -58,14 +57,14 @@ class PassageRankingDataset(TorchTaskDataset): pos_sequences = group[self.pos_sequence] pos_sequences = [ - ' '.join([ele[key] for key in self.passage_text_fileds]) + ' '.join([ele[key] for key in self.text_fileds]) for ele in pos_sequences ] labels.extend([1] * len(pos_sequences)) neg_sequences = group[self.neg_sequence] neg_sequences = [ - ' '.join([ele[key] for key in self.passage_text_fileds]) + ' '.join([ele[key] for key in self.text_fileds]) for ele in neg_sequences ] @@ -88,13 +87,13 @@ class PassageRankingDataset(TorchTaskDataset): pos_sequences = group[self.pos_sequence] pos_sequences = [ - ' '.join([ele[key] for key in self.passage_text_fileds]) + ' '.join([ele[key] for key in self.text_fileds]) for ele in pos_sequences ] neg_sequences = group[self.neg_sequence] neg_sequences = [ - ' '.join([ele[key] for key in self.passage_text_fileds]) + ' '.join([ele[key] for key in self.text_fileds]) for ele in neg_sequences ] diff --git a/modelscope/msdatasets/task_datasets/torch_base_dataset.py 
b/modelscope/msdatasets/task_datasets/torch_base_dataset.py index 014e4faa..4d82b741 100644 --- a/modelscope/msdatasets/task_datasets/torch_base_dataset.py +++ b/modelscope/msdatasets/task_datasets/torch_base_dataset.py @@ -19,6 +19,7 @@ class TorchTaskDataset(TaskDataset, Dataset): preprocessor=None, **kwargs): TaskDataset.__init__(self, datasets, mode, preprocessor, **kwargs) + self.trainer = None def __getitem__(self, index) -> Any: return self.prepare_sample(self._inner_dataset[index]) diff --git a/modelscope/msdatasets/task_datasets/video_summarization_dataset.py b/modelscope/msdatasets/task_datasets/video_summarization_dataset.py index 89deb7ba..34eb0450 100644 --- a/modelscope/msdatasets/task_datasets/video_summarization_dataset.py +++ b/modelscope/msdatasets/task_datasets/video_summarization_dataset.py @@ -1,3 +1,6 @@ +# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM + import os import h5py @@ -15,7 +18,7 @@ class VideoSummarizationDataset(TorchTaskDataset): self.mode = mode self.data_filename = os.path.join(root_dir, opt.dataset_file) self.split_filename = os.path.join(root_dir, opt.split_file) - self.split_index = opt.split_index # it represents the current split (varies from 0 to 4) + self.split_index = opt.split_index hdf = h5py.File(self.data_filename, 'r') self.list_frame_features, self.list_gtscores = [], [] self.list_user_summary = [] diff --git a/modelscope/msdatasets/utils/dataset_utils.py b/modelscope/msdatasets/utils/dataset_utils.py index ef42f75f..c7aa7682 100644 --- a/modelscope/msdatasets/utils/dataset_utils.py +++ b/modelscope/msdatasets/utils/dataset_utils.py @@ -6,6 +6,7 @@ from typing import Any, Mapping, Optional, Sequence, Union from datasets.builder import DatasetBuilder +from modelscope.hub.api import HubApi from modelscope.utils.constant import DEFAULT_DATASET_REVISION from modelscope.utils.logger import get_logger from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder @@ -77,6 +78,79 @@ def get_target_dataset_structure(dataset_structure: dict, return target_subset_name, target_dataset_structure +def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool, + dataset_name: str, namespace: str, + version: str) -> list: + """ + List all of objects for specific dataset. + + Args: + hub_api (class HubApi): HubApi instance. + max_limit (int): Max number of objects. + is_recursive (bool): Whether to list objects recursively. + dataset_name (str): Dataset name. + namespace (str): Namespace. + version (str): Dataset version. + Returns: + res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...] + """ + res = [] + objects = hub_api.list_oss_dataset_objects( + dataset_name=dataset_name, + namespace=namespace, + max_limit=max_limit, + is_recursive=is_recursive, + is_filter_dir=True, + revision=version) + + for item in objects: + object_key = item.get('Key') + res.append(object_key) + + return res + + +def contains_dir(file_map) -> bool: + """ + To check whether input contains at least one directory. + + Args: + file_map (dict): Structure of data files. e.g., {'train': 'train.zip', 'validation': 'val.zip'} + Returns: + True if input contains at least one directory, False otherwise. 
+ """ + res = False + for k, v in file_map.items(): + if isinstance(v, str) and not v.endswith('.zip'): + res = True + break + return res + + +def get_split_objects_map(file_map, objects): + """ + Get the map between dataset split and oss objects. + + Args: + file_map (dict): Structure of data files. e.g., {'train': 'train', 'validation': 'val'}, both of train and val + are dirs. + objects (list): List of oss objects. e.g., ['train/001/1_123.png', 'train/001/1_124.png', 'val/003/3_38.png'] + Returns: + A map of split-objects. e.g., {'train': ['train/001/1_123.png', 'train/001/1_124.png'], + 'validation':['val/003/3_38.png']} + """ + res = {} + for k, v in file_map.items(): + res[k] = [] + + for obj_key in objects: + for k, v in file_map.items(): + if obj_key.startswith(v): + res[k].append(obj_key) + + return res + + def get_dataset_files(subset_split_into: dict, dataset_name: str, namespace: str, @@ -95,14 +169,24 @@ def get_dataset_files(subset_split_into: dict, meta_map = defaultdict(dict) file_map = defaultdict(dict) args_map = defaultdict(dict) - from modelscope.hub.api import HubApi modelscope_api = HubApi() + objects = list_dataset_objects( + hub_api=modelscope_api, + max_limit=-1, + is_recursive=True, + dataset_name=dataset_name, + namespace=namespace, + version=revision) + for split, info in subset_split_into.items(): meta_map[split] = modelscope_api.get_dataset_file_url( info.get('meta', ''), dataset_name, namespace, revision) if info.get('file'): file_map[split] = info['file'] args_map[split] = info.get('args') + + if contains_dir(file_map): + file_map = get_split_objects_map(file_map, objects) return meta_map, file_map, args_map diff --git a/modelscope/msdatasets/utils/download_utils.py b/modelscope/msdatasets/utils/download_utils.py index 2e21bf50..b1c7a5ab 100644 --- a/modelscope/msdatasets/utils/download_utils.py +++ b/modelscope/msdatasets/utils/download_utils.py @@ -10,16 +10,14 @@ from .oss_utils import OssUtilities class DatasetDownloadManager(DownloadManager): - def __init__( - self, - dataset_name: str, - namespace: str, - version: str, - data_dir: Optional[str] = None, - download_config: Optional[DownloadConfig] = None, - base_path: Optional[str] = None, - record_checksums=True, - ): + def __init__(self, + dataset_name: str, + namespace: str, + version: str, + data_dir: Optional[str] = None, + download_config: Optional[DownloadConfig] = None, + base_path: Optional[str] = None, + record_checksums=True): super().__init__(dataset_name, data_dir, download_config, base_path, record_checksums) self._namespace = namespace diff --git a/modelscope/msdatasets/utils/oss_utils.py b/modelscope/msdatasets/utils/oss_utils.py index 4a403876..d7d61e89 100644 --- a/modelscope/msdatasets/utils/oss_utils.py +++ b/modelscope/msdatasets/utils/oss_utils.py @@ -50,11 +50,16 @@ class OssUtilities: progress_callback=self._percentage) return local_path - def upload(self, oss_object_name: str, local_file_path: str) -> str: + def upload(self, oss_object_name: str, local_file_path: str, + indicate_individual_progress: bool) -> str: retry_count = 0 object_key = os.path.join(self.oss_dir, oss_object_name) resumable_store = oss2.ResumableStore( root=self.upload_resumable_tmp_store) + if indicate_individual_progress: + progress_callback = self._percentage + else: + progress_callback = None while True: try: @@ -66,7 +71,7 @@ class OssUtilities: store=resumable_store, multipart_threshold=self.upload_multipart_threshold, part_size=self.upload_part_size, - progress_callback=self._percentage, + 
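A usage sketch for the two new helpers in dataset_utils.py, with the sample `file_map` and `objects` taken from their docstrings; it assumes a modelscope checkout that already contains this patch.

from modelscope.msdatasets.utils.dataset_utils import (contains_dir,
                                                       get_split_objects_map)

file_map = {'train': 'train', 'validation': 'val'}   # per-split directories, not .zip archives
objects = ['train/001/1_123.png', 'train/001/1_124.png', 'val/003/3_38.png']

if contains_dir(file_map):                           # True: no value ends with '.zip'
    file_map = get_split_objects_map(file_map, objects)

print(file_map)
# {'train': ['train/001/1_123.png', 'train/001/1_124.png'],
#  'validation': ['val/003/3_38.png']}

This is the same branch get_dataset_files now takes when a dataset's meta declares per-split directories instead of archive files.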
progress_callback=progress_callback, num_threads=self.upload_num_threads) break except Exception: diff --git a/modelscope/msdatasets/utils/upload_utils.py b/modelscope/msdatasets/utils/upload_utils.py index 4813b89f..2b4422b2 100644 --- a/modelscope/msdatasets/utils/upload_utils.py +++ b/modelscope/msdatasets/utils/upload_utils.py @@ -1,5 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os +from multiprocessing.dummy import Pool as ThreadPool + +from tqdm import tqdm + from .oss_utils import OssUtilities @@ -19,5 +24,38 @@ class DatasetUploadManager(object): def upload(self, object_name: str, local_file_path: str) -> str: object_key = self.oss_utilities.upload( - oss_object_name=object_name, local_file_path=local_file_path) + oss_object_name=object_name, + local_file_path=local_file_path, + indicate_individual_progress=True) return object_key + + def upload_dir(self, object_dir_name: str, local_dir_path: str, + num_processes: int, chunksize: int, + filter_hidden_files: bool) -> int: + + def run_upload(args): + self.oss_utilities.upload( + oss_object_name=args[0], + local_file_path=args[1], + indicate_individual_progress=False) + + files_list = [] + for root, dirs, files in os.walk(local_dir_path): + for file_name in files: + if filter_hidden_files and file_name.startswith('.'): + continue + # Concatenate directory name and relative path into a oss object key. e.g., train/001/1_1230.png + object_name = os.path.join( + object_dir_name, + root.replace(local_dir_path, '', 1).strip('/'), file_name) + + local_file_path = os.path.join(root, file_name) + files_list.append((object_name, local_file_path)) + + with ThreadPool(processes=num_processes) as pool: + result = list( + tqdm( + pool.imap(run_upload, files_list, chunksize=chunksize), + total=len(files_list))) + + return len(result) diff --git a/modelscope/outputs/__init__.py b/modelscope/outputs/__init__.py new file mode 100644 index 00000000..47e66714 --- /dev/null +++ b/modelscope/outputs/__init__.py @@ -0,0 +1,2 @@ +from .nlp.model_outputs import * # noqa +from .outputs import TASK_OUTPUTS, ModelOutputBase, OutputKeys diff --git a/modelscope/outputs/nlp/__init__.py b/modelscope/outputs/nlp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/outputs/nlp/model_outputs.py b/modelscope/outputs/nlp/model_outputs.py new file mode 100644 index 00000000..dcb37145 --- /dev/null +++ b/modelscope/outputs/nlp/model_outputs.py @@ -0,0 +1,543 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +from modelscope.outputs.outputs import ModelOutputBase + +Tensor = Union['torch.Tensor', 'tf.Tensor'] + + +@dataclass +class TextClassificationModelOutput(ModelOutputBase): + """The output class for text classification models. + + Args: + logits (`Tensor`): The logits output of the model. loss (`Tensor`, + *optional*) The loss of the model, available when training. + hidden_states (`Tensor`, *optional*) Hidden-states of the model at the + output of each layer plus the optional initial embedding outputs. + """ + + logits: Tensor = None + loss: Tensor = None + + +@dataclass +class TokenClassificationModelOutput(ModelOutputBase): + """The output class for token classification models. + logits (`Tensor`): The logits output of the model. + loss (`Tensor`, *optional*) The loss of the model, available when training. 
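The upload_dir method added to upload_utils.py earlier in this hunk walks a local directory, maps each file to an OSS object key, and fans the uploads out over a thread pool. A self-contained sketch of that pattern, with a stand-in for OssUtilities.upload so it runs without OSS credentials (tqdm is the only extra dependency; the directory path is a placeholder):

import os
from multiprocessing.dummy import Pool as ThreadPool

from tqdm import tqdm


def _upload_one(args):
    # Stand-in for OssUtilities.upload(oss_object_name, local_file_path, ...).
    object_name, local_file_path = args
    return object_name


def collect_files(object_dir_name, local_dir_path, filter_hidden_files=True):
    files_list = []
    for root, _dirs, files in os.walk(local_dir_path):
        for file_name in files:
            if filter_hidden_files and file_name.startswith('.'):
                continue
            # Same key construction as upload_dir, e.g. train/001/1_1230.png
            object_name = os.path.join(
                object_dir_name,
                root.replace(local_dir_path, '', 1).strip('/'), file_name)
            files_list.append((object_name, os.path.join(root, file_name)))
    return files_list


files_list = collect_files('train', '/tmp/my_local_images')   # placeholder directory
with ThreadPool(processes=4) as pool:
    results = list(
        tqdm(pool.imap(_upload_one, files_list, chunksize=8),
             total=len(files_list)))
print(f'uploaded {len(results)} objects')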
+ """ + + logits: Tensor = None + loss: Tensor = None + offset_mapping: Tensor = None + + +@dataclass +class FillMaskModelOutput(ModelOutputBase): + """The output class for text classification models. + + Args: + logits (`Tensor`): The logits output of the model. + loss (`Tensor`, *optional*) The loss of the model, available when training. + input_ids (`Tensor`, *optional*) The input id tensor fed into the model. + hidden_states (`Tensor`, *optional*) Hidden-states of the model at the + output of each layer plus the optional initial embedding outputs. + """ + + logits: Tensor = None + loss: Tensor = None + input_ids: Tensor = None + hidden_states: Tensor = None + + +@dataclass +class TokenClassifierOutput(ModelOutputBase): + """ + Base class for outputs of token classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when + `labels` is provided) : + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, + config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the sentence. + Selected in the range ``[0, sequence_length - 1]``. + + """ + + loss: Tensor = None + logits: Tensor = None + hidden_states: Tensor = None + attentions: Tensor = None + offset_mapping: Tensor = None + + +@dataclass +class TokenClassifierWithPredictionsOutput(ModelOutputBase): + """ + Base class for outputs of token classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when + `labels` is provided) : + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, + config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the sentence. + Selected in the range ``[0, sequence_length - 1]``. + predictions: A PyTorch tensor of the best tag sequence for each batch of shape + (nbest, batch_size, seq_length) + + """ + + loss: Tensor = None + logits: Tensor = None + hidden_states: Tensor = None + attentions: Tensor = None + offset_mapping: Tensor = None + predictions: Tensor = None + + +@dataclass +class BaseModelOutput(ModelOutputBase): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + last_hidden_state: Tensor = None + hidden_states: Optional[Tuple[Tensor]] = None + attentions: Optional[Tuple[Tensor]] = None + + +@dataclass +class BackboneModelOutput(ModelOutputBase): + """The output class for text classification models. + + Args: + last_hidden_state (`Tensor`, *optional*): Sequence of hidden-states at + the output of the last layer of the model. + pooler_output (`Tensor`, *optional*) The tensor of the pooled hidden state. + hidden_states (`Tensor`, *optional*) Hidden-states of the model at + the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Tensor = None + pooler_output: Tensor = None + hidden_states: Tensor = None + + +@dataclass +class AttentionBackboneModelOutput(BackboneModelOutput): + """The output class for backbones of attention based models. + + Args: + attentions (`tuple(Tensor)`, *optional* Attentions weights after the + attention softmax, used to compute the weighted average in the + self-attention heads. + """ + attentions: Tensor = None + past_key_values: Tensor = None + cross_attentions: Tensor = None + + +@dataclass +class AttentionTextClassificationModelOutput(TextClassificationModelOutput): + """The output class for backbones of attention based models. + + Args: + attentions (`tuple(Tensor)`, *optional* Attentions weights after the + attention softmax, used to compute the weighted average in the + self-attention heads. + """ + attentions: Tensor = None + hidden_states: Tensor = None + + +@dataclass +class AttentionTokenClassificationModelOutput(TokenClassificationModelOutput): + """The output class for backbones of attention based models. 
+ + Args: + attentions (`tuple(Tensor)`, *optional* Attentions weights after the attention softmax, + used to compute the weighted average in the self-attention heads. + """ + attentions: Tensor = None + hidden_states: Tensor = None + + +@dataclass +class AttentionFillMaskModelOutput(FillMaskModelOutput): + """The output class for the fill mask and attention based models. + + Args: + attentions (`tuple(Tensor)`, *optional* Attentions weights after the + attention softmax, used to compute the weighted average in the + self-attention heads. + """ + attentions: Tensor = None + + +@dataclass +class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutputBase): + """ + Base class for model's outputs that also contains a pooling of the last + hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, + hidden_size)`): + Last layer hidden-state of the first token of the sequence + (classification token) after further processing through the layers + used for the auxiliary pretraining task. E.g. for BERT-family of + models, this returns the classification token after processing + through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction + (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` and `config.add_cross_attention=True` is passed + or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the + attention softmax, used to compute the weighted average in the + cross-attention heads. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned + when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, + with each tuple having 2 tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, + embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the + self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that + can be used (see `past_key_values` input) to speed up sequential + decoding. 
+ """ + + last_hidden_state: Tensor = None + pooler_output: Tensor = None + hidden_states: Tensor = None + past_key_values: Tensor = None + attentions: Tensor = None + cross_attentions: Tensor = None + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(ModelOutputBase): + """ + Base class for model's outputs that may also contain a past key/values (to + speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + model. + + If `past_key_values` is used only the last hidden-state of the + sequences of shape `(batch_size, 1, hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned + when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, + with each tuple having 2 tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, + embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the + self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that + can be used (see `past_key_values` input) to speed up sequential + decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` and `config.add_cross_attention=True` is passed + or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the + attention softmax, used to compute the weighted average in the + cross-attention heads. + """ + + last_hidden_state: Tensor = None + past_key_values: Tensor = None + hidden_states: Tensor = None + attentions: Tensor = None + cross_attentions: Tensor = None + + +@dataclass +class Seq2SeqModelOutput(ModelOutputBase): + """ + Base class for model encoder's outputs that also contains : pre-computed + hidden states that can speed up sequential decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + decoder of the model. + + If `past_key_values` is used only the last hidden-state of the + sequences of shape `(batch_size, 1, hidden_size)` is output. 
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned + when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, + with each tuple having 2 tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, + embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the + self-attention blocks and in the cross-attention blocks) that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the + optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used + to compute the weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the + attention softmax, used to compute the weighted average in the + cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the + encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the + optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used + to compute the weighted average in the self-attention heads. 
+ """ + + last_hidden_state: Tensor = None + past_key_values: Optional[Tuple[Tuple[Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tensor]] = None + decoder_attentions: Optional[Tuple[Tensor]] = None + cross_attentions: Optional[Tuple[Tensor]] = None + encoder_last_hidden_state: Optional[Tensor] = None + encoder_hidden_states: Optional[Tuple[Tensor]] = None + encoder_attentions: Optional[Tuple[Tensor]] = None + + +@dataclass +class Seq2SeqLMOutput(ModelOutputBase): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when + `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, + config.vocab_size)`): + Prediction scores of the language modeling head (scores for each + vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned + when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, + with each tuple having 2 tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, + embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the + self-attention blocks and in the cross-attention blocks) that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the + initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used + to compute the weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the + attention softmax, used to compute the weighted average in the + cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the + encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the + initial embedding outputs. 
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used + to compute the weighted average in the self-attention heads. + """ + + loss: Optional[Tensor] = None + logits: Tensor = None + past_key_values: Optional[Tuple[Tuple[Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tensor]] = None + decoder_attentions: Optional[Tuple[Tensor]] = None + cross_attentions: Optional[Tuple[Tensor]] = None + encoder_last_hidden_state: Optional[Tensor] = None + encoder_hidden_states: Optional[Tuple[Tensor]] = None + encoder_attentions: Optional[Tuple[Tensor]] = None diff --git a/modelscope/outputs.py b/modelscope/outputs/outputs.py similarity index 78% rename from modelscope/outputs.py rename to modelscope/outputs/outputs.py index d8d2458a..cbdeede4 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs/outputs.py @@ -1,4 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from collections import OrderedDict, namedtuple +from dataclasses import dataclass, fields from modelscope.utils.constant import Tasks @@ -36,10 +38,15 @@ class OutputKeys(object): UUID = 'uuid' WORD = 'word' KWS_LIST = 'kws_list' + SQL_STRING = 'sql_string' + SQL_QUERY = 'sql_query' HISTORY = 'history' + QUERT_RESULT = 'query_result' TIMESTAMPS = 'timestamps' - SPLIT_VIDEO_NUM = 'split_video_num' - SPLIT_META_LIST = 'split_meta_list' + SHOT_NUM = 'shot_num' + SCENE_NUM = 'scene_num' + SCENE_META_LIST = 'scene_meta_list' + SHOT_META_LIST = 'shot_meta_list' TASK_OUTPUTS = { @@ -87,6 +94,25 @@ TASK_OUTPUTS = { Tasks.face_detection: [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], + # card detection result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # "keypoints": [ + # [x1, y1, x2, y2, x3, y3, x4, y4], + # [x1, y1, x2, y2, x3, y3, x4, y4], + # [x1, y1, x2, y2, x3, y3, x4, y4], + # [x1, y1, x2, y2, x3, y3, x4, y4], + # ], + # } + Tasks.card_detection: + [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], + # facial expression recognition result for single sample # { # "scores": [0.9, 0.1, 0.02, 0.02, 0.02, 0.02, 0.02], @@ -141,6 +167,32 @@ TASK_OUTPUTS = { Tasks.image_object_detection: [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES], + # video object detection result for single sample + # { + + # "scores": [[0.8, 0.25, 0.05, 0.05], [0.9, 0.1, 0.05, 0.05]] + # "labels": [["person", "traffic light", "car", "bus"], + # ["person", "traffic light", "car", "bus"]] + # "boxes": + # [ + # [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ] + # ], + + # } + Tasks.video_object_detection: + [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES], + # instance segmentation result for single sample # { # "scores": [0.9, 0.1, 0.05, 0.05], @@ -175,6 +227,7 @@ TASK_OUTPUTS = { Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG], Tasks.crowd_counting: [OutputKeys.SCORES, OutputKeys.OUTPUT_IMG], + Tasks.image_inpainting: [OutputKeys.OUTPUT_IMG], # image generation task result for a single image # 
{"output_img": np.array with shape (h, w, 3)} @@ -182,6 +235,7 @@ TASK_OUTPUTS = { Tasks.image_to_image_translation: [OutputKeys.OUTPUT_IMG], Tasks.image_style_transfer: [OutputKeys.OUTPUT_IMG], Tasks.image_portrait_stylization: [OutputKeys.OUTPUT_IMG], + Tasks.image_body_reshaping: [OutputKeys.OUTPUT_IMG], # live category recognition result for single video # { @@ -198,7 +252,7 @@ TASK_OUTPUTS = { # human body keypoints detection result for single sample # { - # "poses": [ + # "keypoints": [ # [[x, y]*15], # [[x, y]*15], # [[x, y]*15] @@ -215,11 +269,11 @@ TASK_OUTPUTS = { # ] # } Tasks.body_2d_keypoints: - [OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES], + [OutputKeys.KEYPOINTS, OutputKeys.SCORES, OutputKeys.BOXES], # 3D human body keypoints detection result for single sample # { - # "poses": [ # 3d pose coordinate in camera coordinate + # "keypoints": [ # 3d pose coordinate in camera coordinate # [[x, y, z]*17], # joints of per image # [[x, y, z]*17], # ... @@ -233,7 +287,7 @@ TASK_OUTPUTS = { # and is only avaialbe when the "render" option is enabled. # } Tasks.body_3d_keypoints: - [OutputKeys.POSES, OutputKeys.TIMESTAMPS, OutputKeys.OUTPUT_VIDEO], + [OutputKeys.KEYPOINTS, OutputKeys.TIMESTAMPS, OutputKeys.OUTPUT_VIDEO], # 2D hand keypoints result for single sample # { @@ -309,19 +363,67 @@ TASK_OUTPUTS = { Tasks.shop_segmentation: [OutputKeys.MASKS], # movide scene segmentation result for a single video # { - # "split_video_num":3, - # "split_meta_list": + # "shot_num":15, + # "shot_meta_list": + # [ + # { + # "frame": [start_frame, end_frame], + # "timestamps": [start_timestamp, end_timestamp] # ['00:00:01.133', '00:00:02.245'] + # + # } + # ] + # "scene_num":3, + # "scene_meta_list": # [ # { # "shot": [0,1,2], # "frame": [start_frame, end_frame], - # "timestamp": [start_timestamp, end_timestamp] # ['00:00:01.133', '00:00:02.245'] + # "timestamps": [start_timestamp, end_timestamp] # ['00:00:01.133', '00:00:02.245'] # } # ] # # } - Tasks.movie_scene_segmentation: - [OutputKeys.SPLIT_VIDEO_NUM, OutputKeys.SPLIT_META_LIST], + Tasks.movie_scene_segmentation: [ + OutputKeys.SHOT_NUM, OutputKeys.SHOT_META_LIST, OutputKeys.SCENE_NUM, + OutputKeys.SCENE_META_LIST + ], + + # human whole body keypoints detection result for single sample + # { + # "keypoints": [ + # [[x, y]*133], + # [[x, y]*133], + # [[x, y]*133] + # ] + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ] + # } + Tasks.human_wholebody_keypoint: [OutputKeys.KEYPOINTS, OutputKeys.BOXES], + + # video summarization result for a single video + # { + # "output": + # [ + # { + # "frame": [start_frame, end_frame] + # "timestamps": [start_time, end_time] + # }, + # { + # "frame": [start_frame, end_frame] + # "timestamps": [start_time, end_time] + # } + # ] + # } + Tasks.video_summarization: [OutputKeys.OUTPUT], + + # referring video object segmentation result for a single video + # { + # "masks": [np.array # 2D array with shape [height, width]] + # } + Tasks.referring_video_object_segmentation: [OutputKeys.MASKS], # ============ nlp tasks =================== @@ -388,7 +490,6 @@ TASK_OUTPUTS = { # ] # } Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS], - Tasks.part_of_speech: [OutputKeys.OUTPUT, OutputKeys.LABELS], # TODO @wenmeng.zwm support list of result check # named entity recognition result for single sample @@ -399,6 +500,7 @@ TASK_OUTPUTS = { # ] # } Tasks.named_entity_recognition: [OutputKeys.OUTPUT], + Tasks.part_of_speech: [OutputKeys.OUTPUT], # text_error_correction 
result for a single sample # { @@ -406,7 +508,7 @@ TASK_OUTPUTS = { # } Tasks.text_error_correction: [OutputKeys.OUTPUT], Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES], - Tasks.passage_ranking: [OutputKeys.SCORES], + Tasks.text_ranking: [OutputKeys.SCORES], # text generation result for single sample # { @@ -506,18 +608,12 @@ TASK_OUTPUTS = { # } Tasks.task_oriented_conversation: [OutputKeys.OUTPUT], - # conversational text-to-sql result for single sample - # { - # "text": "SELECT shop.Name FROM shop." - # } - Tasks.conversational_text_to_sql: [OutputKeys.TEXT], - # table-question-answering result for single sample # { # "sql": "SELECT shop.Name FROM shop." # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} # } - Tasks.table_question_answering: [OutputKeys.OUTPUT, OutputKeys.HISTORY], + Tasks.table_question_answering: [OutputKeys.OUTPUT], # ============ audio tasks =================== # asr result for single sample @@ -558,6 +654,7 @@ TASK_OUTPUTS = { # "caption": "this is an image caption text." # } Tasks.image_captioning: [OutputKeys.CAPTION], + Tasks.ocr_recognition: [OutputKeys.TEXT], # visual grounding result for single sample # { @@ -605,8 +702,9 @@ TASK_OUTPUTS = { # "text_embedding": np.array with shape [1, D], # "similarity": float # } - Tasks.multi_modal_similarity: - [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES], + Tasks.multi_modal_similarity: [ + OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES + ], # VQA result for a sample # {"text": "this is a text answser. "} @@ -664,12 +762,13 @@ TASK_OUTPUTS = { # } Tasks.hand_static: [OutputKeys.OUTPUT], - # 'output': [ - # [2, 75, 287, 240, 510, 0.8335018754005432], - # [1, 127, 83, 332, 366, 0.9175254702568054], - # [0, 0, 0, 367, 639, 0.9693422317504883]] + # { 'labels': [2, 1, 0], + # 'boxes':[[[78, 282, 240, 504], [127, 87, 332, 370], [0, 0, 367, 639]] + # 'scores':[0.8202137351036072, 0.8987470269203186, 0.9679114818572998] # } - Tasks.face_human_hand_detection: [OutputKeys.OUTPUT], + Tasks.face_human_hand_detection: [ + OutputKeys.LABELS, OutputKeys.BOXES, OutputKeys.SCORES + ], # { # {'output': 'Happiness', 'boxes': (203, 104, 663, 564)} @@ -683,3 +782,60 @@ TASK_OUTPUTS = { # } Tasks.product_segmentation: [OutputKeys.MASKS], } + + +class ModelOutputBase(list): + + def __post_init__(self): + self.reconstruct() + self.post_init = True + + def reconstruct(self): + # Low performance, but low frequency. 
+ self.clear() + for idx, key in enumerate(self.keys()): + self.append(getattr(self, key)) + + def __getitem__(self, item): + if isinstance(item, str): + if hasattr(self, item): + return getattr(self, item) + elif isinstance(item, (int, slice)): + return super().__getitem__(item) + raise IndexError(f'No Index {item} found in the dataclass.') + + def __setitem__(self, key, value): + if isinstance(key, str): + if key in [f.name for f in fields(self)]: + if key not in self.keys(): + super().__setattr__(key, value) + self.reconstruct() + elif id(getattr(self, key)) != id(value): + super().__setattr__(key, value) + super().__setitem__(self.keys().index(key), value) + else: + super().__setattr__(key, value) + elif isinstance(key, int): + super().__setitem__(key, value) + key_name = self.keys()[key] + super().__setattr__(key_name, value) + + def __setattr__(self, key, value): + if getattr(self, 'post_init', False): + return self.__setitem__(key, value) + else: + return super().__setattr__(key, value) + + def keys(self): + return [ + f.name for f in fields(self) if getattr(self, f.name) is not None + ] + + def items(self): + return self.to_dict().items() + + def to_dict(self): + output = OrderedDict() + for key in self.keys(): + output[key] = getattr(self, key) + return output diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py new file mode 100644 index 00000000..13560229 --- /dev/null +++ b/modelscope/pipeline_inputs.py @@ -0,0 +1,244 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np +from PIL import Image + +from modelscope.models.base.base_head import Input +from modelscope.utils.constant import Tasks + + +class InputKeys(object): + IMAGE = 'image' + TEXT = 'text' + VIDEO = 'video' + + +class InputType(object): + IMAGE = 'image' + TEXT = 'text' + AUDIO = 'audio' + VIDEO = 'video' + BOX = 'box' + DICT = 'dict' + LIST = 'list' + INT = 'int' + + +INPUT_TYPE = { + InputType.IMAGE: (str, np.ndarray, Image.Image), + InputType.TEXT: str, + InputType.AUDIO: (str, bytes, np.ndarray), + InputType.VIDEO: (str, np.ndarray, cv2.VideoCapture), + InputType.BOX: (list, np.ndarray), + InputType.DICT: (dict, type(None)), + InputType.LIST: (list, type(None)), + InputType.INT: int, +} + + +def check_input_type(input_type, input): + expected_type = INPUT_TYPE[input_type] + assert isinstance(input, expected_type), \ + f'invalid input type for {input_type}, expected {expected_type} but got {type(input)}\n {input}' + + +TASK_INPUTS = { + # if task input is single var, value is InputType + # if task input is a tuple, value is tuple of InputType + # if task input is a dict, value is a dict of InputType, where key + # equals the one needed in pipeline input dict + # if task input is a list, value is a set of input format, in which + # each elements corresponds to one input format as described above. 
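Before the TASK_INPUTS table continues below, a quick illustration of the ModelOutputBase behaviour defined above: output dataclasses can be read both by field name and by position, and fields left as None are skipped. The snippet assumes a modelscope checkout with this patch; plain lists stand in for tensors.

from modelscope.outputs import TextClassificationModelOutput

out = TextClassificationModelOutput(logits=[0.1, 0.9])   # loss left as None

print(out.keys())        # ['logits'] -- None-valued fields are hidden
print(out['logits'])     # key access ...
print(out[0])            # ... and positional access return the same value
print(out.to_dict())     # OrderedDict([('logits', [0.1, 0.9])])

out['loss'] = 0.3        # assigning a declared field makes it visible again
print(out.keys())        # ['logits', 'loss']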
+ # ============ vision tasks =================== + Tasks.ocr_detection: + InputType.IMAGE, + Tasks.ocr_recognition: + InputType.IMAGE, + Tasks.face_2d_keypoints: + InputType.IMAGE, + Tasks.face_detection: + InputType.IMAGE, + Tasks.facial_expression_recognition: + InputType.IMAGE, + Tasks.face_recognition: + InputType.IMAGE, + Tasks.human_detection: + InputType.IMAGE, + Tasks.face_image_generation: + InputType.INT, + Tasks.image_classification: + InputType.IMAGE, + Tasks.image_object_detection: + InputType.IMAGE, + Tasks.image_segmentation: + InputType.IMAGE, + Tasks.portrait_matting: + InputType.IMAGE, + + # image editing task result for a single image + Tasks.skin_retouching: + InputType.IMAGE, + Tasks.image_super_resolution: + InputType.IMAGE, + Tasks.image_colorization: + InputType.IMAGE, + Tasks.image_color_enhancement: + InputType.IMAGE, + Tasks.image_denoising: + InputType.IMAGE, + Tasks.image_portrait_enhancement: + InputType.IMAGE, + Tasks.crowd_counting: + InputType.IMAGE, + Tasks.image_inpainting: { + 'img': InputType.IMAGE, + 'mask': InputType.IMAGE, + }, + + # image generation task result for a single image + Tasks.image_to_image_generation: + InputType.IMAGE, + Tasks.image_to_image_translation: + InputType.IMAGE, + Tasks.image_style_transfer: { + 'content': InputType.IMAGE, + 'style': InputType.IMAGE, + }, + Tasks.image_portrait_stylization: + InputType.IMAGE, + Tasks.live_category: + InputType.VIDEO, + Tasks.action_recognition: + InputType.VIDEO, + Tasks.body_2d_keypoints: + InputType.IMAGE, + Tasks.body_3d_keypoints: + InputType.VIDEO, + Tasks.hand_2d_keypoints: + InputType.IMAGE, + Tasks.video_single_object_tracking: (InputType.VIDEO, InputType.BOX), + Tasks.video_category: + InputType.VIDEO, + Tasks.product_retrieval_embedding: + InputType.IMAGE, + Tasks.video_embedding: + InputType.VIDEO, + Tasks.virtual_try_on: (InputType.IMAGE, InputType.IMAGE, InputType.IMAGE), + Tasks.text_driven_segmentation: { + InputKeys.IMAGE: InputType.IMAGE, + InputKeys.TEXT: InputType.TEXT + }, + Tasks.shop_segmentation: + InputType.IMAGE, + Tasks.movie_scene_segmentation: + InputType.VIDEO, + + # ============ nlp tasks =================== + Tasks.text_classification: [ + InputType.TEXT, + (InputType.TEXT, InputType.TEXT), + { + 'text': InputType.TEXT, + 'text2': InputType.TEXT + }, + ], + Tasks.sentence_similarity: (InputType.TEXT, InputType.TEXT), + Tasks.nli: (InputType.TEXT, InputType.TEXT), + Tasks.sentiment_classification: + InputType.TEXT, + Tasks.zero_shot_classification: + InputType.TEXT, + Tasks.relation_extraction: + InputType.TEXT, + Tasks.translation: + InputType.TEXT, + Tasks.word_segmentation: [InputType.TEXT, { + 'text': InputType.TEXT, + }], + Tasks.part_of_speech: + InputType.TEXT, + Tasks.named_entity_recognition: + InputType.TEXT, + Tasks.text_error_correction: + InputType.TEXT, + Tasks.sentence_embedding: { + 'source_sentence': InputType.LIST, + 'sentences_to_compare': InputType.LIST, + }, + Tasks.text_ranking: (InputType.TEXT, InputType.TEXT), + Tasks.text_generation: + InputType.TEXT, + Tasks.fill_mask: + InputType.TEXT, + Tasks.task_oriented_conversation: { + 'user_input': InputType.TEXT, + 'history': InputType.DICT, + }, + Tasks.table_question_answering: { + 'question': InputType.TEXT, + 'history_sql': InputType.DICT, + }, + Tasks.faq_question_answering: { + 'query_set': InputType.LIST, + 'support_set': InputType.LIST, + }, + + # ============ audio tasks =================== + Tasks.auto_speech_recognition: + InputType.AUDIO, + Tasks.speech_signal_process: + 
InputType.AUDIO, + Tasks.acoustic_echo_cancellation: { + 'nearend_mic': InputType.AUDIO, + 'farend_speech': InputType.AUDIO + }, + Tasks.acoustic_noise_suppression: + InputType.AUDIO, + Tasks.text_to_speech: + InputType.TEXT, + Tasks.keyword_spotting: + InputType.AUDIO, + + # ============ multi-modal tasks =================== + Tasks.image_captioning: [InputType.IMAGE, { + 'image': InputType.IMAGE, + }], + Tasks.visual_grounding: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.text_to_image_synthesis: { + 'text': InputType.TEXT, + }, + Tasks.multi_modal_embedding: { + 'img': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.generative_multi_modal_embedding: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.multi_modal_similarity: { + 'img': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.visual_question_answering: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.visual_entailment: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT, + 'text2': InputType.TEXT, + }, + Tasks.action_detection: + InputType.VIDEO, + Tasks.image_reid_person: + InputType.IMAGE, + Tasks.video_inpainting: { + 'video_input_path': InputType.TEXT, + 'video_output_path': InputType.TEXT, + 'mask_path': InputType.TEXT, + } +} diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py index 282d1184..6a4864bf 100644 --- a/modelscope/pipelines/audio/asr_inference_pipeline.py +++ b/modelscope/pipelines/audio/asr_inference_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict, List, Sequence, Tuple, Union import yaml @@ -46,22 +47,28 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): if isinstance(audio_in, str): # load pcm data from url if audio_in is url str - self.audio_in = load_bytes_from_url(audio_in) + self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in) elif isinstance(audio_in, bytes): # load pcm data from wav data if audio_in is wave format - self.audio_in = extract_pcm_from_wav(audio_in) + self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in) else: self.audio_in = audio_in + # set the sample_rate of audio_in if checking_audio_fs is valid + if checking_audio_fs is not None: + self.audio_fs = checking_audio_fs + if recog_type is None or audio_format is None: self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( audio_in=self.audio_in, recog_type=recog_type, audio_format=audio_format) - if hasattr(asr_utils, 'sample_rate_checking') and audio_fs is None: - self.audio_fs = asr_utils.sample_rate_checking( + if hasattr(asr_utils, 'sample_rate_checking'): + checking_audio_fs = asr_utils.sample_rate_checking( self.audio_in, self.audio_format) + if checking_audio_fs is not None: + self.audio_fs = checking_audio_fs if self.preprocessor is None: self.preprocessor = WavToScp() @@ -79,7 +86,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): logger.info(f"Decoding with {inputs['audio_format']} files ...") - data_cmd: Sequence[Tuple[str, str]] + data_cmd: Sequence[Tuple[str, str, str]] if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm': data_cmd = ['speech', 'sound'] elif inputs['audio_format'] == 'kaldi_ark': @@ -87,6 +94,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): elif inputs['audio_format'] == 'tfrecord': data_cmd = ['speech', 'tfrecord'] + if inputs.__contains__('mvn_file'): + data_cmd.append(inputs['mvn_file']) + # generate asr inference command cmd = { 
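Stepping back to the new pipeline_inputs.py module above: TASK_INPUTS maps a task either to a single InputType, to a tuple/dict of types, or to a list of accepted formats, and check_input_type enforces the mapping with an isinstance assertion. A short usage sketch (assumes the patched SDK and its cv dependencies are installed):

from modelscope.pipeline_inputs import (TASK_INPUTS, InputType,
                                        check_input_type)
from modelscope.utils.constant import Tasks

# Single-input tasks map straight to one InputType.
check_input_type(InputType.TEXT, 'ModelScope makes inference easy.')   # passes

# Multi-input tasks declare tuples or dicts of expected types.
print(TASK_INPUTS[Tasks.sentence_similarity])   # ('text', 'text')
print(TASK_INPUTS[Tasks.visual_grounding])      # {'image': 'image', 'text': 'text'}

# A mismatch trips the assertion inside check_input_type.
try:
    check_input_type(InputType.TEXT, 123)
except AssertionError as err:
    print('rejected:', err)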
'model_type': inputs['model_type'], diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index 866b8d0b..db6fc65d 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os from typing import Any, Dict, List, Union @@ -36,6 +37,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): **kwargs) -> Dict[str, Any]: if 'keywords' in kwargs.keys(): self.keywords = kwargs['keywords'] + if isinstance(self.keywords, str): + word_list = [] + word = {} + word['keyword'] = self.keywords + word_list.append(word) + self.keywords = word_list else: self.keywords = None @@ -44,10 +51,10 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): if isinstance(audio_in, str): # load pcm data from url if audio_in is url str - audio_in = load_bytes_from_url(audio_in) + audio_in, audio_fs = load_bytes_from_url(audio_in) elif isinstance(audio_in, bytes): # load pcm data from wav data if audio_in is wave format - audio_in = extract_pcm_from_wav(audio_in) + audio_in, audio_fs = extract_pcm_from_wav(audio_in) output = self.preprocessor.forward(self.model.forward(), audio_in) output = self.forward(output) @@ -95,6 +102,9 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): pos_list=pos_kws_list, neg_list=neg_kws_list) + if 'kws_list' not in rst_dict: + rst_dict['kws_list'] = [] + return rst_dict def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index c5db2b57..bca80502 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -13,6 +13,7 @@ import numpy as np from modelscope.models.base import Model from modelscope.msdatasets import MsDataset from modelscope.outputs import TASK_OUTPUTS +from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type from modelscope.preprocessors import Preprocessor from modelscope.utils.config import Config from modelscope.utils.constant import Frameworks, ModelFile @@ -32,7 +33,7 @@ if is_tf_available(): Tensor = Union['torch.Tensor', 'tf.Tensor'] Input = Union[str, tuple, MsDataset, 'Image.Image', 'numpy.ndarray'] -InputModel = Union[str, Model] +InputModel = Union[str, Model, 'torch.nn.Module'] logger = get_logger() @@ -48,13 +49,7 @@ class Pipeline(ABC): return Model.from_pretrained( model, model_prefetched=True, device=self.device_name) if is_model(model) else model - elif isinstance(model, Model): - return model else: - if model and not isinstance(model, str): - raise ValueError( - f'model type for single model is either str or Model, but got type {type(model)}' - ) return model def initiate_multiple_models(self, input_models: List[InputModel]): @@ -138,12 +133,10 @@ class Pipeline(ABC): def _get_framework(self) -> str: frameworks = [] for m in self.models: - if isinstance(m, Model): - model_dir = m.model_dir - else: - assert isinstance(m, - str), 'model should be either str or Model.' 
+ if isinstance(m, str): model_dir = m + else: + model_dir = m.model_dir cfg_file = osp.join(model_dir, ModelFile.CONFIGURATION) cfg = Config.from_file(cfg_file) frameworks.append(cfg.framework) @@ -210,7 +203,7 @@ class Pipeline(ABC): preprocess_params = kwargs.get('preprocess_params', {}) forward_params = kwargs.get('forward_params', {}) postprocess_params = kwargs.get('postprocess_params', {}) - + self._check_input(input) out = self.preprocess(input, **preprocess_params) with device_placement(self.framework, self.device_name): if self.framework == Frameworks.torch: @@ -225,6 +218,46 @@ class Pipeline(ABC): self._check_output(out) return out + def _check_input(self, input): + task_name = self.group_key + if task_name in TASK_INPUTS: + input_type = TASK_INPUTS[task_name] + + # if multiple input formats are defined, we first + # found the one that match input data and check + if isinstance(input_type, list): + matched_type = None + for t in input_type: + if isinstance(input, (dict, tuple)): + if type(t) == type(input): + matched_type = t + break + elif isinstance(t, str): + matched_type = t + break + if matched_type is None: + err_msg = 'input data format for current pipeline should be one of following: \n' + for t in input_type: + err_msg += f'{t}\n' + raise ValueError(err_msg) + else: + input_type = matched_type + + if isinstance(input_type, str): + check_input_type(input_type, input) + elif isinstance(input_type, tuple): + for t, input_ele in zip(input_type, input): + check_input_type(t, input_ele) + elif isinstance(input_type, dict): + for k in input_type.keys(): + # allow single input for multi-modal models + if k in input: + check_input_type(input_type[k], input[k]) + else: + raise ValueError(f'invalid input_type definition {input_type}') + else: + logger.warning(f'task {task_name} input definition is missing') + def _check_output(self, input): # this attribute is dynamically attached by registry # when cls is registered in registry using task name @@ -346,10 +379,13 @@ class DistributedPipeline(Pipeline): def _instantiate_one(cls, rank, model_dir, **kwargs): """Instantiate one model piece. - @param rank: The model rank. - @param model_dir: The model_dir in the node. - @param kwargs: Any extra args. - @return: None. The model handler should be kept in the class field. + Args: + rank: The model rank. + model_dir: The model_dir in the node. + kwargs: Any extra args. + + Returns: + None. The model handler should be kept in the class field. """ pass @@ -369,8 +405,11 @@ class DistributedPipeline(Pipeline): Use the model handler kept in the class field to forward. - @param inputs: The inputs after the preprocessing. - @return: The forward results. + Args: + inputs: The inputs after the preprocessing. + + Returns: + The forward results. 
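The _check_input hook added to Pipeline.__call__ above resolves the TASK_INPUTS entry for the pipeline's task and validates the caller's input against it; when several formats are listed, the first one whose container type matches is used. A paraphrase of that matching logic, for illustration only (not part of the SDK API), applied to the text_classification entry:

from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type
from modelscope.utils.constant import Tasks


def resolve_and_check(task_name, data):
    """Illustrative paraphrase of Pipeline._check_input."""
    input_type = TASK_INPUTS[task_name]
    if isinstance(input_type, list):                 # several accepted formats
        matched = None
        for t in input_type:
            if isinstance(data, (dict, tuple)):
                if type(t) == type(data):
                    matched = t
                    break
            elif isinstance(t, str):                 # plain single input
                matched = t
                break
        if matched is None:
            raise ValueError(f'unsupported input format for {task_name}')
        input_type = matched
    if isinstance(input_type, str):
        check_input_type(input_type, data)
    elif isinstance(input_type, tuple):
        for t, ele in zip(input_type, data):
            check_input_type(t, ele)
    elif isinstance(input_type, dict):
        for key in input_type:
            if key in data:                          # single input is allowed
                check_input_type(input_type[key], data[key])


resolve_and_check(Tasks.text_classification, 'great movie!')                   # single text
resolve_and_check(Tasks.text_classification, ('premise text', 'hypothesis'))   # sentence pair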
""" pass @@ -388,10 +427,12 @@ def collate_fn(data, device): """ from torch.utils.data.dataloader import default_collate - from modelscope.preprocessors import InputFeatures + from modelscope.preprocessors.nlp import InputFeatures if isinstance(data, dict) or isinstance(data, Mapping): return type(data)({k: collate_fn(v, device) for k, v in data.items()}) elif isinstance(data, (tuple, list)): + if 0 == len(data): + return torch.Tensor([]) if isinstance(data[0], (int, float)): return default_collate(data).to(device) else: diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 7fa66b5f..498c9ed8 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -20,17 +20,22 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.sentence_embedding: (Pipelines.sentence_embedding, 'damo/nlp_corom_sentence-embedding_english-base'), - Tasks.passage_ranking: (Pipelines.passage_ranking, - 'damo/nlp_corom_passage-ranking_english-base'), + Tasks.text_ranking: (Pipelines.text_ranking, + 'damo/nlp_corom_passage-ranking_english-base'), Tasks.word_segmentation: (Pipelines.word_segmentation, 'damo/nlp_structbert_word-segmentation_chinese-base'), + Tasks.part_of_speech: (Pipelines.part_of_speech, + 'damo/nlp_structbert_part-of-speech_chinese-base'), Tasks.token_classification: (Pipelines.part_of_speech, 'damo/nlp_structbert_part-of-speech_chinese-base'), Tasks.named_entity_recognition: (Pipelines.named_entity_recognition, 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), + Tasks.relation_extraction: + (Pipelines.relation_extraction, + 'damo/nlp_bert_relation-extraction_chinese-base'), Tasks.information_extraction: (Pipelines.relation_extraction, 'damo/nlp_bert_relation-extraction_chinese-base'), @@ -64,9 +69,6 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/nlp_space_dialog-modeling'), Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, 'damo/nlp_space_dialog-state-tracking'), - Tasks.conversational_text_to_sql: - (Pipelines.conversational_text_to_sql, - 'damo/nlp_star_conversational-text-to-sql'), Tasks.table_question_answering: (Pipelines.table_question_answering_pipeline, 'damo/nlp-convai-text2sql-pretrain-cn'), @@ -118,6 +120,11 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'), Tasks.face_detection: (Pipelines.face_detection, 'damo/cv_resnet_facedetection_scrfd10gkps'), + Tasks.card_detection: (Pipelines.card_detection, + 'damo/cv_resnet_carddetection_scrfd34gkps'), + Tasks.face_detection: + (Pipelines.face_detection, + 'damo/cv_resnet101_face-detection_cvpr22papermogface'), Tasks.face_recognition: (Pipelines.face_recognition, 'damo/cv_ir101_facerecognition_cfglint'), Tasks.facial_expression_recognition: @@ -156,6 +163,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.image_classification: (Pipelines.daily_image_classification, 'damo/cv_vit-base_image-classification_Dailylife-labels'), + Tasks.image_object_detection: + (Pipelines.image_object_detection_auto, + 'damo/cv_yolox_image-object-detection-auto'), Tasks.ocr_recognition: (Pipelines.ocr_recognition, 'damo/cv_convnextTiny_ocr-recognition-general_damo'), @@ -179,8 +189,13 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), Tasks.shop_segmentation: (Pipelines.shop_segmentation, 'damo/cv_vitb16_segmentation_shop-seg'), + Tasks.image_inpainting: (Pipelines.image_inpainting, + 'damo/cv_fft_inpainting_lama'), Tasks.video_inpainting: (Pipelines.video_inpainting, 'damo/cv_video-inpainting'), + Tasks.human_wholebody_keypoint: + 
(Pipelines.human_wholebody_keypoint, + 'damo/cv_hrnetw48_human-wholebody-keypoint_image'), Tasks.hand_static: (Pipelines.hand_static, 'damo/cv_mobileface_hand-static'), Tasks.face_human_hand_detection: @@ -189,6 +204,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'), Tasks.product_segmentation: (Pipelines.product_segmentation, 'damo/cv_F3Net_product-segmentation'), + Tasks.referring_video_object_segmentation: + (Pipelines.referring_video_object_segmentation, + 'damo/cv_swin-t_referring_video-object-segmentation'), } @@ -267,9 +285,6 @@ def pipeline(task: str = None, if task is None and pipeline_name is None: raise ValueError('task or pipeline_name is required') - assert isinstance(model, (type(None), str, Model, list)), \ - f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' - model = normalize_model_input(model, model_revision) if pipeline_name is None: # get default pipeline for this task @@ -286,8 +301,7 @@ def pipeline(task: str = None, else: # used for test case, when model is str and is not hub path pipeline_name = get_pipeline_by_model_name(task, model) - elif isinstance(model, Model) or \ - (isinstance(model, list) and isinstance(model[0], Model)): + elif model is not None: # get pipeline info from Model object first_model = model[0] if isinstance(model, list) else model if not hasattr(first_model, 'pipeline'): diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 55bad09a..97cd8761 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -35,6 +35,7 @@ if TYPE_CHECKING: from .image_super_resolution_pipeline import ImageSuperResolutionPipeline from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline + from .image_inpainting_pipeline import ImageInpaintingPipeline from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline from .realtime_object_detection_pipeline import RealtimeObjectDetectionPipeline from .live_category_pipeline import LiveCategoryPipeline @@ -45,7 +46,10 @@ if TYPE_CHECKING: from .video_category_pipeline import VideoCategoryPipeline from .virtual_try_on_pipeline import VirtualTryonPipeline from .shop_segmentation_pipleline import ShopSegmentationPipeline - from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline + from .easycv_pipelines import (EasyCVDetectionPipeline, + EasyCVSegmentationPipeline, + Face2DKeypointsPipeline, + HumanWholebodyKeypointsPipeline) from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline from .mog_face_detection_pipeline import MogFaceDetectionPipeline @@ -54,6 +58,7 @@ if TYPE_CHECKING: from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipelin from .hand_static_pipeline import HandStaticPipeline + from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline else: _import_structure = { @@ -99,6 +104,7 @@ else: 'live_category_pipeline': ['LiveCategoryPipeline'], 'image_to_image_generation_pipeline': ['Image2ImageGenerationPipeline'], + 'image_inpainting_pipeline': ['ImageInpaintingPipeline'], 'ocr_detection_pipeline': ['OCRDetectionPipeline'], 
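On the collate_fn change in pipelines/base.py above: an empty tuple or list now collates to an empty tensor instead of falling through to default_collate. A small sketch of the observable behaviour (assumes the patched SDK and torch are installed):

from modelscope.pipelines.base import collate_fn

batch = {
    'input_ids': [101, 2023, 102],   # numeric lists still go through default_collate
    'empty': [],                     # newly handled: becomes torch.Tensor([])
}
out = collate_fn(batch, 'cpu')
print(out['input_ids'])              # tensor([ 101, 2023,  102])
print(out['empty'])                  # tensor([])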
'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], @@ -107,8 +113,10 @@ else: 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], 'shop_segmentation_pipleline': ['ShopSegmentationPipeline'], 'easycv_pipeline': [ - 'EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline', - 'Face2DKeypointsPipeline' + 'EasyCVDetectionPipeline', + 'EasyCVSegmentationPipeline', + 'Face2DKeypointsPipeline', + 'HumanWholebodyKeypointsPipeline', ], 'text_driven_segmentation_pipeline': ['TextDrivenSegmentationPipeline'], @@ -121,6 +129,9 @@ else: ['FacialExpressionRecognitionPipeline'], 'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'], 'hand_static_pipeline': ['HandStaticPipeline'], + 'referring_video_object_segmentation_pipeline': [ + 'ReferringVideoObjectSegmentationPipeline' + ], } import sys diff --git a/modelscope/pipelines/cv/animal_recognition_pipeline.py b/modelscope/pipelines/cv/animal_recognition_pipeline.py index fad14680..671a5b4c 100644 --- a/modelscope/pipelines/cv/animal_recognition_pipeline.py +++ b/modelscope/pipelines/cv/animal_recognition_pipeline.py @@ -113,9 +113,8 @@ class AnimalRecognitionPipeline(Pipeline): label_mapping = f.readlines() score = torch.max(inputs['outputs']) inputs = { - OutputKeys.SCORES: - score.item(), + OutputKeys.SCORES: [score.item()], OutputKeys.LABELS: - label_mapping[inputs['outputs'].argmax()].split('\t')[1] + [label_mapping[inputs['outputs'].argmax()].split('\t')[1]] } return inputs diff --git a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py index d6afbae4..bc2e975d 100644 --- a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py @@ -73,7 +73,7 @@ class Body2DKeypointsPipeline(Pipeline): if input[0] is None or input[1] is None: return { OutputKeys.BOXES: [], - OutputKeys.POSES: [], + OutputKeys.KEYPOINTS: [], OutputKeys.SCORES: [] } @@ -83,7 +83,7 @@ class Body2DKeypointsPipeline(Pipeline): result_boxes.append([box[0][0], box[0][1], box[1][0], box[1][1]]) return { OutputKeys.BOXES: result_boxes, - OutputKeys.POSES: poses, + OutputKeys.KEYPOINTS: poses, OutputKeys.SCORES: scores } diff --git a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py index 474c0e54..8522ceff 100644 --- a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py @@ -132,12 +132,7 @@ class Body3DKeypointsPipeline(Pipeline): device='gpu' if torch.cuda.is_available() else 'cpu') def preprocess(self, input: Input) -> Dict[str, Any]: - video_url = input.get('input_video') - self.output_video_path = input.get('output_video_path') - if self.output_video_path is None: - self.output_video_path = tempfile.NamedTemporaryFile( - suffix='.mp4').name - + video_url = input video_frames = self.read_video_frames(video_url) if 0 == len(video_frames): res = {'success': False, 'msg': 'get video frame failed.'} @@ -148,9 +143,16 @@ class Body3DKeypointsPipeline(Pipeline): max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints for i, frame in enumerate(video_frames): kps_2d = self.human_body_2d_kps_detector(frame) + if [] == kps_2d.get('boxes'): + res = { + 'success': False, + 'msg': f'fail to detect person at image frame {i}' + } + return res + box = kps_2d['boxes'][ 0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using 
first detected bbox - pose = kps_2d['poses'][0] # keypoints: [15, 2] + pose = kps_2d['keypoints'][0] # keypoints: [15, 2] score = kps_2d['scores'][0] # keypoints: [15, 2] all_2d_poses.append(pose) all_boxes_with_socre.append( @@ -185,7 +187,15 @@ class Body3DKeypointsPipeline(Pipeline): return res def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: - res = {OutputKeys.POSES: [], OutputKeys.TIMESTAMPS: []} + output_video_path = kwargs.get('output_video', None) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name + + res = { + OutputKeys.KEYPOINTS: [], + OutputKeys.TIMESTAMPS: [], + OutputKeys.OUTPUT_VIDEO: output_video_path + } if not input['success']: pass @@ -195,10 +205,10 @@ class Body3DKeypointsPipeline(Pipeline): 0] # [frame_num, joint_num, joint_dim] if 'render' in self.keypoint_model_3d.cfg.keys(): - self.render_prediction(pred_3d_pose) - res[OutputKeys.OUTPUT_VIDEO] = self.output_video_path + self.render_prediction(pred_3d_pose, output_video_path) + res[OutputKeys.OUTPUT_VIDEO] = output_video_path - res[OutputKeys.POSES] = pred_3d_pose + res[OutputKeys.KEYPOINTS] = pred_3d_pose res[OutputKeys.TIMESTAMPS] = self.timestamps return res @@ -252,12 +262,12 @@ class Body3DKeypointsPipeline(Pipeline): cap.release() return frames - def render_prediction(self, pose3d_cam_rr): + def render_prediction(self, pose3d_cam_rr, output_video_path): """render predict result 3d poses. Args: pose3d_cam_rr (nd.array): [frame_num, joint_num, joint_dim], 3d pose joints - + output_video_path (str): output path for video Returns: """ frame_num = pose3d_cam_rr.shape[0] @@ -359,4 +369,4 @@ class Body3DKeypointsPipeline(Pipeline): # save mp4 Writer = writers['ffmpeg'] writer = Writer(fps=self.fps, metadata={}, bitrate=4096) - ani.save(self.output_video_path, writer=writer) + ani.save(output_video_path, writer=writer) diff --git a/modelscope/pipelines/cv/card_detection_pipeline.py b/modelscope/pipelines/cv/card_detection_pipeline.py new file mode 100644 index 00000000..00b18024 --- /dev/null +++ b/modelscope/pipelines/cv/card_detection_pipeline.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Pipelines +from modelscope.pipelines.builder import PIPELINES +from modelscope.pipelines.cv.face_detection_pipeline import \ + FaceDetectionPipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.card_detection, module_name=Pipelines.card_detection) +class CardDetectionPipeline(FaceDetectionPipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a card detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + thr = 0.45 # card/face detect use different threshold + super().__init__(model=model, score_thr=thr, **kwargs) diff --git a/modelscope/pipelines/cv/crowd_counting_pipeline.py b/modelscope/pipelines/cv/crowd_counting_pipeline.py index 3143825b..93fffdf2 100644 --- a/modelscope/pipelines/cv/crowd_counting_pipeline.py +++ b/modelscope/pipelines/cv/crowd_counting_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
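The Body3DKeypointsPipeline hunks above move the rendered-video path out of the input dict and into a `postprocess` keyword argument, and rename the POSES output key to KEYPOINTS. A minimal usage sketch under those assumptions; the task constant, the reliance on the registered default model, and the forwarding of `output_video` from the pipeline call into `postprocess` are assumptions of this sketch, not guarantees made by the diff:

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # assumes the base Pipeline forwards extra kwargs through to postprocess
    body_3d = pipeline(Tasks.body_3d_keypoints)
    result = body_3d('dance.mp4', output_video='rendered.mp4')  # placeholder paths

    poses = result[OutputKeys.KEYPOINTS]         # [frame_num, joint_num, joint_dim] on success
    timestamps = result[OutputKeys.TIMESTAMPS]
    video_out = result[OutputKeys.OUTPUT_VIDEO]  # rendered clip path (a temp file if not given)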
import math from typing import Any, Dict diff --git a/modelscope/pipelines/cv/easycv_pipelines/__init__.py b/modelscope/pipelines/cv/easycv_pipelines/__init__.py index 4f149130..e0209b85 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/__init__.py +++ b/modelscope/pipelines/cv/easycv_pipelines/__init__.py @@ -7,11 +7,14 @@ if TYPE_CHECKING: from .detection_pipeline import EasyCVDetectionPipeline from .segmentation_pipeline import EasyCVSegmentationPipeline from .face_2d_keypoints_pipeline import Face2DKeypointsPipeline + from .human_wholebody_keypoint_pipeline import HumanWholebodyKeypointsPipeline else: _import_structure = { 'detection_pipeline': ['EasyCVDetectionPipeline'], 'segmentation_pipeline': ['EasyCVSegmentationPipeline'], - 'face_2d_keypoints_pipeline': ['Face2DKeypointsPipeline'] + 'face_2d_keypoints_pipeline': ['Face2DKeypointsPipeline'], + 'human_wholebody_keypoint_pipeline': + ['HumanWholebodyKeypointsPipeline'], } import sys diff --git a/modelscope/pipelines/cv/easycv_pipelines/base.py b/modelscope/pipelines/cv/easycv_pipelines/base.py index 8aea1146..c130aea0 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/base.py +++ b/modelscope/pipelines/cv/easycv_pipelines/base.py @@ -4,7 +4,9 @@ import os import os.path as osp from typing import Any +import numpy as np from easycv.utils.ms_utils import EasyCVMeta +from PIL import ImageFile from modelscope.hub.snapshot_download import snapshot_download from modelscope.pipelines.util import is_official_hub_path @@ -94,5 +96,19 @@ class EasyCVPipeline(object): return easycv_config + def _is_single_inputs(self, inputs): + if isinstance(inputs, str) or (isinstance(inputs, list) + and len(inputs) == 1) or isinstance( + inputs, np.ndarray) or isinstance( + inputs, ImageFile.ImageFile): + return True + + return False + def __call__(self, inputs) -> Any: - return self.predict_op(inputs) + outputs = self.predict_op(inputs) + + if self._is_single_inputs(inputs): + outputs = outputs[0] + + return outputs diff --git a/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py index 32365102..a1173bc4 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py +++ b/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py @@ -1,16 +1,28 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any + from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys from modelscope.pipelines.builder import PIPELINES -from modelscope.utils.constant import Tasks +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.cv.image_utils import \ + show_image_object_detection_auto_result from .base import EasyCVPipeline @PIPELINES.register_module( Tasks.image_object_detection, module_name=Pipelines.easycv_detection) +@PIPELINES.register_module( + Tasks.image_object_detection, + module_name=Pipelines.image_object_detection_auto) class EasyCVDetectionPipeline(EasyCVPipeline): """Pipeline for easycv detection task.""" - def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): + def __init__(self, + model: str, + model_file_pattern=ModelFile.TORCH_MODEL_FILE, + *args, + **kwargs): """ model (str): model id on modelscope hub or local model path. model_file_pattern (str): model file pattern. 
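The `_is_single_inputs` helper added to the EasyCV base pipeline above lets `__call__` return a single result dict for single-image inputs instead of a one-element list. A standalone sketch of that convention with a stubbed predictor (the PIL `ImageFile` branch is omitted for brevity):

    import numpy as np

    def is_single_inputs(inputs):
        # a path, a one-element list or a numpy image all count as "single"
        return isinstance(inputs, (str, np.ndarray)) or (
            isinstance(inputs, list) and len(inputs) == 1)

    def fake_predict_op(inputs):
        batch = inputs if isinstance(inputs, list) else [inputs]
        return [{'detection_boxes': [], 'detection_scores': []} for _ in batch]

    out = fake_predict_op('demo.jpg')
    if is_single_inputs('demo.jpg'):
        out = out[0]                                      # a dict, mirroring EasyCVPipeline.__call__
    batch_out = fake_predict_op(['a.jpg', 'b.jpg'])       # stays a list of dicts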
@@ -21,3 +33,31 @@ class EasyCVDetectionPipeline(EasyCVPipeline): model_file_pattern=model_file_pattern, *args, **kwargs) + + def show_result(self, img_path, result, save_path=None): + show_image_object_detection_auto_result(img_path, result, save_path) + + def __call__(self, inputs) -> Any: + outputs = self.predict_op(inputs) + + scores = [] + labels = [] + boxes = [] + for output in outputs: + for score, label, box in zip(output['detection_scores'], + output['detection_classes'], + output['detection_boxes']): + scores.append(score) + labels.append(self.cfg.CLASSES[label]) + boxes.append([b for b in box]) + + results = [{ + OutputKeys.SCORES: scores, + OutputKeys.LABELS: labels, + OutputKeys.BOXES: boxes + } for output in outputs] + + if self._is_single_inputs(inputs): + results = results[0] + + return results diff --git a/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py index 7c32e0fc..b48d013e 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py @@ -40,4 +40,7 @@ class Face2DKeypointsPipeline(EasyCVPipeline): OutputKeys.POSES: output['pose'] } for output in outputs] + if self._is_single_inputs(inputs): + results = results[0] + return results diff --git a/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py new file mode 100644 index 00000000..936accbf --- /dev/null +++ b/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py @@ -0,0 +1,68 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path +from typing import Any + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from .base import EasyCVPipeline + + +@PIPELINES.register_module( + Tasks.human_wholebody_keypoint, + module_name=Pipelines.human_wholebody_keypoint) +class HumanWholebodyKeypointsPipeline(EasyCVPipeline): + """Pipeline for human wholebody 2d keypoints detection.""" + + def __init__(self, + model: str, + model_file_pattern=ModelFile.TORCH_MODEL_FILE, + *args, + **kwargs): + """ + model (str): model id on modelscope hub or local model path. + model_file_pattern (str): model file pattern. 
+ """ + self.model_dir = model + super(HumanWholebodyKeypointsPipeline, self).__init__( + model=model, + model_file_pattern=model_file_pattern, + *args, + **kwargs) + + def _build_predict_op(self, **kwargs): + """Build EasyCV predictor.""" + from easycv.predictors.builder import build_predictor + detection_predictor_type = self.cfg['DETECTION']['type'] + detection_model_path = os.path.join( + self.model_dir, self.cfg['DETECTION']['model_path']) + detection_cfg_file = os.path.join(self.model_dir, + self.cfg['DETECTION']['config_file']) + detection_score_threshold = self.cfg['DETECTION']['score_threshold'] + self.cfg.pipeline.predictor_config[ + 'detection_predictor_config'] = dict( + type=detection_predictor_type, + model_path=detection_model_path, + config_file=detection_cfg_file, + score_threshold=detection_score_threshold) + easycv_config = self._to_easycv_config() + pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, { + 'model_path': self.model_path, + 'config_file': easycv_config, + **kwargs + }) + return pipeline_op + + def __call__(self, inputs) -> Any: + outputs = self.predict_op(inputs) + + results = [{ + OutputKeys.KEYPOINTS: output['keypoints'], + OutputKeys.BOXES: output['boxes'] + } for output in outputs] + + if self._is_single_inputs(inputs): + results = results[0] + + return results diff --git a/modelscope/pipelines/cv/face_detection_pipeline.py b/modelscope/pipelines/cv/face_detection_pipeline.py index eff5b70f..608567a4 100644 --- a/modelscope/pipelines/cv/face_detection_pipeline.py +++ b/modelscope/pipelines/cv/face_detection_pipeline.py @@ -8,6 +8,7 @@ import PIL import torch from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_detection import ScrfdDetect from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES @@ -29,27 +30,8 @@ class FaceDetectionPipeline(Pipeline): model: model id on modelscope hub. 
""" super().__init__(model=model, **kwargs) - from mmcv import Config - from mmcv.parallel import MMDataParallel - from mmcv.runner import load_checkpoint - from mmdet.models import build_detector - from modelscope.models.cv.face_detection.mmdet_patch.datasets import RetinaFaceDataset - from modelscope.models.cv.face_detection.mmdet_patch.datasets.pipelines import RandomSquareCrop - from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e - from modelscope.models.cv.face_detection.mmdet_patch.models.dense_heads import SCRFDHead - from modelscope.models.cv.face_detection.mmdet_patch.models.detectors import SCRFD - cfg = Config.fromfile(osp.join(model, 'mmcv_scrfd_10g_bnkps.py')) - detector = build_detector( - cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) - ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE) - logger.info(f'loading model from {ckpt_path}') - device = torch.device( - f'cuda:{0}' if torch.cuda.is_available() else 'cpu') - load_checkpoint(detector, ckpt_path, map_location=device) - detector = MMDataParallel(detector, device_ids=[0]) - detector.eval() + detector = ScrfdDetect(model_dir=model, **kwargs) self.detector = detector - logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: img = LoadImage.convert_to_ndarray(input) @@ -85,22 +67,7 @@ class FaceDetectionPipeline(Pipeline): return result def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - - result = self.detector( - return_loss=False, - rescale=True, - img=[input['img'][0].unsqueeze(0)], - img_metas=[[dict(input['img_metas'][0].data)]]) - assert result is not None - result = result[0][0] - bboxes = result[:, :4].tolist() - kpss = result[:, 5:].tolist() - scores = result[:, 4].tolist() - return { - OutputKeys.SCORES: scores, - OutputKeys.BOXES: bboxes, - OutputKeys.KEYPOINTS: kpss - } + return self.detector(input) def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs diff --git a/modelscope/pipelines/cv/face_emotion_pipeline.py b/modelscope/pipelines/cv/face_emotion_pipeline.py index 249493b6..9d9aa6ee 100644 --- a/modelscope/pipelines/cv/face_emotion_pipeline.py +++ b/modelscope/pipelines/cv/face_emotion_pipeline.py @@ -1,11 +1,14 @@ # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
from typing import Any, Dict +import numpy as np + from modelscope.metainfo import Pipelines from modelscope.models.cv.face_emotion import emotion_infer from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger @@ -28,10 +31,11 @@ class FaceEmotionPipeline(Pipeline): logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: - return input + img = LoadImage.convert_to_ndarray(input['img_path']) + return img def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - result, bbox = emotion_infer.inference(input['img_path'], self.model, + result, bbox = emotion_infer.inference(input, self.model, self.face_model) return {OutputKeys.OUTPUT: result, OutputKeys.BOXES: bbox} diff --git a/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py b/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py index d9f214c9..d41a14dd 100644 --- a/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py +++ b/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py @@ -2,11 +2,14 @@ from typing import Any, Dict +import numpy as np + from modelscope.metainfo import Pipelines from modelscope.models.cv.face_human_hand_detection import det_infer from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -29,14 +32,19 @@ class NanoDettForFaceHumanHandDetectionPipeline(Pipeline): logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: - return input + img = LoadImage.convert_to_ndarray(input['input_path']) + return img def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - result = det_infer.inference(self.model, self.device, - input['input_path']) - logger.info(result) - return {OutputKeys.OUTPUT: result} + cls_list, bbox_list, score_list = det_infer.inference( + self.model, self.device, input) + logger.info(cls_list, bbox_list, score_list) + return { + OutputKeys.LABELS: cls_list, + OutputKeys.BOXES: bbox_list, + OutputKeys.SCORES: score_list + } def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs diff --git a/modelscope/pipelines/cv/face_image_generation_pipeline.py b/modelscope/pipelines/cv/face_image_generation_pipeline.py index 405c9a4b..1b4e2e8a 100644 --- a/modelscope/pipelines/cv/face_image_generation_pipeline.py +++ b/modelscope/pipelines/cv/face_image_generation_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
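The face-emotion and face-human-hand hunks above decode the image in `preprocess` with `LoadImage.convert_to_ndarray` and, for the hand detector, replace the single OUTPUT blob with separate LABELS/BOXES/SCORES lists. A usage sketch under those assumptions (the input dict key follows the preprocess code above; the image path is a placeholder):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    detector = pipeline(Tasks.face_human_hand_detection)
    res = detector({'input_path': 'scene.jpg'})

    for label, box, score in zip(res[OutputKeys.LABELS],
                                 res[OutputKeys.BOXES],
                                 res[OutputKeys.SCORES]):
        print(label, box, score)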
import os from typing import Any, Dict @@ -60,6 +61,8 @@ class FaceImageGenerationPipeline(Pipeline): return input def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + if isinstance(input, str): + input = int(input) assert isinstance(input, int) torch.manual_seed(input) torch.cuda.manual_seed(input) diff --git a/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py b/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py index 1b1f13d1..3c85ae62 100644 --- a/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py +++ b/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py @@ -45,6 +45,9 @@ class FacialExpressionRecognitionPipeline(Pipeline): # face detect pipeline det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' + self.map_list = [ + 'Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral' + ] self.face_detection = pipeline( Tasks.face_detection, model=det_model_id) @@ -119,11 +122,7 @@ class FacialExpressionRecognitionPipeline(Pipeline): result = self.fer(input) assert result is not None scores = result[0].tolist() - labels = result[1].tolist() - return { - OutputKeys.SCORES: scores, - OutputKeys.LABELS: labels, - } + return {OutputKeys.SCORES: scores, OutputKeys.LABELS: self.map_list} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs diff --git a/modelscope/pipelines/cv/general_recognition_pipeline.py b/modelscope/pipelines/cv/general_recognition_pipeline.py index 07222086..80f6f88a 100644 --- a/modelscope/pipelines/cv/general_recognition_pipeline.py +++ b/modelscope/pipelines/cv/general_recognition_pipeline.py @@ -114,9 +114,8 @@ class GeneralRecognitionPipeline(Pipeline): label_mapping = f.readlines() score = torch.max(inputs['outputs']) inputs = { - OutputKeys.SCORES: - score.item(), + OutputKeys.SCORES: [score.item()], OutputKeys.LABELS: - label_mapping[inputs['outputs'].argmax()].split('\t')[1] + [label_mapping[inputs['outputs'].argmax()].split('\t')[1]] } return inputs diff --git a/modelscope/pipelines/cv/hand_static_pipeline.py b/modelscope/pipelines/cv/hand_static_pipeline.py index 1219c873..c020b7aa 100644 --- a/modelscope/pipelines/cv/hand_static_pipeline.py +++ b/modelscope/pipelines/cv/hand_static_pipeline.py @@ -1,11 +1,14 @@ # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
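The recognition hunks above wrap the top-1 score and label in lists, so SCORES and LABELS are always indexable sequences; the facial-expression pipeline instead returns the per-class score vector together with a fixed seven-emotion label map. A small sketch of reading results under that convention (the values shown are illustrative only):

    import numpy as np
    from modelscope.outputs import OutputKeys

    # general/animal recognition: one-element lists
    res = {OutputKeys.SCORES: [0.93], OutputKeys.LABELS: ['golden retriever']}
    top_label, top_score = res[OutputKeys.LABELS][0], res[OutputKeys.SCORES][0]

    # facial expression recognition: per-class scores plus the fixed map_list
    fer = {OutputKeys.SCORES: [0.01, 0.02, 0.03, 0.80, 0.05, 0.04, 0.05],
           OutputKeys.LABELS: ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']}
    idx = int(np.argmax(fer[OutputKeys.SCORES]))
    print(fer[OutputKeys.LABELS][idx])           # 'Happy'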
from typing import Any, Dict +import numpy as np + from modelscope.metainfo import Pipelines from modelscope.models.cv.hand_static import hand_model from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -27,10 +30,11 @@ class HandStaticPipeline(Pipeline): logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: - return input + img = LoadImage.convert_to_ndarray(input['img_path']) + return img def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - result = hand_model.infer(input['img_path'], self.model, self.device) + result = hand_model.infer(input, self.model, self.device) return {OutputKeys.OUTPUT: result} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py b/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py index 5e4cd4c6..21af2f75 100644 --- a/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py +++ b/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import math import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/image_body_reshaping_pipeline.py b/modelscope/pipelines/cv/image_body_reshaping_pipeline.py new file mode 100644 index 00000000..c3600eb5 --- /dev/null +++ b/modelscope/pipelines/cv/image_body_reshaping_pipeline.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_body_reshaping, module_name=Pipelines.image_body_reshaping) +class ImageBodyReshapingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image body reshaping pipeline for prediction + Args: + model: model id on modelscope hub. 
+ """ + super().__init__(model=model, **kwargs) + logger.info('body reshaping model init done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + output = self.model.inference(input['img']) + result = {'outputs': output} + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output_img = inputs['outputs'] + return {OutputKeys.OUTPUT_IMG: output_img} diff --git a/modelscope/pipelines/cv/image_classification_pipeline.py b/modelscope/pipelines/cv/image_classification_pipeline.py index 49467eab..69dbd1fb 100644 --- a/modelscope/pipelines/cv/image_classification_pipeline.py +++ b/modelscope/pipelines/cv/image_classification_pipeline.py @@ -13,6 +13,7 @@ from modelscope.pipelines.base import Input, Model, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import OfaPreprocessor, Preprocessor, load_image from modelscope.utils.constant import Tasks +from modelscope.utils.device import get_device from modelscope.utils.logger import get_logger logger = get_logger() @@ -36,6 +37,7 @@ class ImageClassificationPipeline(Pipeline): else: raise NotImplementedError pipe_model.model.eval() + pipe_model.to(get_device()) if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) diff --git a/modelscope/pipelines/cv/image_color_enhance_pipeline.py b/modelscope/pipelines/cv/image_color_enhance_pipeline.py index 40777d60..3a4cf8bc 100644 --- a/modelscope/pipelines/cv/image_color_enhance_pipeline.py +++ b/modelscope/pipelines/cv/image_color_enhance_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict, Optional, Union import torch @@ -54,5 +55,5 @@ class ImageColorEnhancePipeline(Pipeline): def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: output_img = (inputs['outputs'].squeeze(0) * 255.).type( - torch.uint8).cpu().permute(1, 2, 0).numpy() + torch.uint8).cpu().permute(1, 2, 0).numpy()[:, :, ::-1] return {OutputKeys.OUTPUT_IMG: output_img} diff --git a/modelscope/pipelines/cv/image_colorization_pipeline.py b/modelscope/pipelines/cv/image_colorization_pipeline.py index 0fea729d..cd385024 100644 --- a/modelscope/pipelines/cv/image_colorization_pipeline.py +++ b/modelscope/pipelines/cv/image_colorization_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict import cv2 diff --git a/modelscope/pipelines/cv/image_denoise_pipeline.py b/modelscope/pipelines/cv/image_denoise_pipeline.py index 64aa3bc9..34ac1e81 100644 --- a/modelscope/pipelines/cv/image_denoise_pipeline.py +++ b/modelscope/pipelines/cv/image_denoise_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Any, Dict, Optional, Union import torch @@ -104,4 +105,4 @@ class ImageDenoisePipeline(Pipeline): def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute( 1, 2, 0).numpy().astype('uint8') - return {OutputKeys.OUTPUT_IMG: output_img} + return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]} diff --git a/modelscope/pipelines/cv/image_detection_pipeline.py b/modelscope/pipelines/cv/image_detection_pipeline.py index f5554ca2..08633c35 100644 --- a/modelscope/pipelines/cv/image_detection_pipeline.py +++ b/modelscope/pipelines/cv/image_detection_pipeline.py @@ -43,11 +43,15 @@ class ImageDetectionPipeline(Pipeline): bboxes, scores, labels = self.model.postprocess(inputs['data']) if bboxes is None: - return None + outputs = { + OutputKeys.SCORES: [], + OutputKeys.LABELS: [], + OutputKeys.BOXES: [] + } + return outputs outputs = { OutputKeys.SCORES: scores, OutputKeys.LABELS: labels, OutputKeys.BOXES: bboxes } - return outputs diff --git a/modelscope/pipelines/cv/image_inpainting_pipeline.py b/modelscope/pipelines/cv/image_inpainting_pipeline.py new file mode 100644 index 00000000..aff9788d --- /dev/null +++ b/modelscope/pipelines/cv/image_inpainting_pipeline.py @@ -0,0 +1,147 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch +import torch.nn as nn +from torch.utils.data._utils.collate import default_collate + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_inpainting import FFTInpainting +from modelscope.models.cv.image_inpainting.refinement import refine_predict +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_inpainting, module_name=Pipelines.image_inpainting) +class ImageInpaintingPipeline(Pipeline): + + def __init__(self, + model: str, + pad_out_to_modulo=8, + refine=False, + **kwargs): + """ + model: model id on modelscope hub. 
+ """ + assert isinstance(model, str), 'model must be a single str' + super().__init__(model=model, auto_collate=False, **kwargs) + self.refine = refine + logger.info(f'loading model from dir {model}') + self.infer_model = FFTInpainting(model, predict_only=True) + if not self.refine: + self.infer_model.to(self.device) + self.infer_model.eval() + logger.info(f'loading model done, refinement is set to {self.refine}') + self.pad_out_to_modulo = pad_out_to_modulo + + def move_to_device(self, obj, device): + if isinstance(obj, nn.Module): + return obj.to(device) + if torch.is_tensor(obj): + return obj.to(device) + if isinstance(obj, (tuple, list)): + return [self.move_to_device(el, device) for el in obj] + if isinstance(obj, dict): + return { + name: self.move_to_device(val, device) + for name, val in obj.items() + } + raise ValueError(f'Unexpected type {type(obj)}') + + def transforms(self, img): + if img.ndim == 3: + img = np.transpose(img, (2, 0, 1)) + out_img = img.astype('float32') / 255 + return out_img + + def ceil_modulo(self, x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + def pad_img_to_modulo(self, img, mod): + channels, height, width = img.shape + out_height = self.ceil_modulo(height, mod) + out_width = self.ceil_modulo(width, mod) + return np.pad( + img, ((0, 0), (0, out_height - height), (0, out_width - width)), + mode='symmetric') + + def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: + if isinstance(input['img'], str): + image_name, mask_name = input['img'], input['mask'] + img = LoadImage.convert_to_ndarray(image_name) + img = self.transforms(img) + mask = np.array(LoadImage(mode='L')(mask_name)['img']) + mask = self.transforms(mask) + elif isinstance(input['img'], PIL.Image.Image): + img = input['img'] + img = self.transforms(np.array(img)) + mask = input['mask'].convert('L') + mask = self.transforms(np.array(mask)) + else: + raise TypeError( + 'input should be either str or PIL.Image, and both inputs should have the same type' + ) + result = dict(image=img, mask=mask[None, ...]) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['unpad_to_size'] = result['image'].shape[1:] + result['image'] = self.pad_img_to_modulo(result['image'], + self.pad_out_to_modulo) + result['mask'] = self.pad_img_to_modulo(result['mask'], + self.pad_out_to_modulo) + + # Since Pipeline use default torch.no_grad() for performing forward func. + # We conduct inference here in case of doing training for refinement. 
+ result = self.perform_inference(result) + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return {OutputKeys.OUTPUT_IMG: input} + + def perform_inference(self, data): + batch = default_collate([data]) + if self.refine: + assert 'unpad_to_size' in batch, 'Unpadded size is required for the refinement' + assert 'cuda' in str(self.device), 'GPU is required for refinement' + gpu_ids = str(self.device).split(':')[-1] + cur_res = refine_predict( + batch, + self.infer_model, + gpu_ids=gpu_ids, + modulo=self.pad_out_to_modulo, + n_iters=15, + lr=0.002, + min_side=512, + max_scales=3, + px_budget=900000) + cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy() + else: + with torch.no_grad(): + batch = self.move_to_device(batch, self.device) + batch['mask'] = (batch['mask'] > 0) * 1 + batch = self.infer_model(batch) + cur_res = batch['inpainted'][0].permute( + 1, 2, 0).detach().cpu().numpy() + unpad_to_size = batch.get('unpad_to_size', None) + if unpad_to_size is not None: + orig_height, orig_width = unpad_to_size + cur_res = cur_res[:orig_height, :orig_width] + + cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8') + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) + return cur_res + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index d7b7fc3c..fb5d8f8b 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py b/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py index 87e692e8..3eec6526 100644 --- a/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py +++ b/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import math from typing import Any, Dict diff --git a/modelscope/pipelines/cv/image_reid_person_pipeline.py b/modelscope/pipelines/cv/image_reid_person_pipeline.py index 64674a65..9f60142a 100644 --- a/modelscope/pipelines/cv/image_reid_person_pipeline.py +++ b/modelscope/pipelines/cv/image_reid_person_pipeline.py @@ -53,6 +53,7 @@ class ImageReidPersonPipeline(Pipeline): def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: img = input['img'] img_embedding = self.model(img) + img_embedding = img_embedding.detach().cpu().numpy() return {OutputKeys.IMG_EMBEDDING: img_embedding} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/pipelines/cv/image_super_resolution_pipeline.py b/modelscope/pipelines/cv/image_super_resolution_pipeline.py index 657acc41..ca8f3209 100644 --- a/modelscope/pipelines/cv/image_super_resolution_pipeline.py +++ b/modelscope/pipelines/cv/image_super_resolution_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict import cv2 diff --git a/modelscope/pipelines/cv/image_to_image_generate_pipeline.py b/modelscope/pipelines/cv/image_to_image_generate_pipeline.py index 2a3881e7..4f0121dd 100644 --- a/modelscope/pipelines/cv/image_to_image_generate_pipeline.py +++ b/modelscope/pipelines/cv/image_to_image_generate_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
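The inpainting pipeline above pads both the image and the mask so height and width become multiples of `pad_out_to_modulo` (8 by default) and records `unpad_to_size` so the result can be cropped back after inference. A standalone sketch of that padding arithmetic:

    import numpy as np

    def ceil_modulo(x, mod):
        return x if x % mod == 0 else (x // mod + 1) * mod

    def pad_img_to_modulo(img, mod):
        # img is channel-first (C, H, W), matching the pipeline's transforms()
        _, height, width = img.shape
        out_h, out_w = ceil_modulo(height, mod), ceil_modulo(width, mod)
        return np.pad(img, ((0, 0), (0, out_h - height), (0, out_w - width)),
                      mode='symmetric')

    img = np.random.rand(3, 250, 333).astype('float32')
    padded = pad_img_to_modulo(img, 8)
    print(padded.shape)                          # (3, 256, 336); cropped back to (3, 250, 333) afterwards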
import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/image_to_image_translation_pipeline.py b/modelscope/pipelines/cv/image_to_image_translation_pipeline.py index 78901c9b..e5f853ca 100644 --- a/modelscope/pipelines/cv/image_to_image_translation_pipeline.py +++ b/modelscope/pipelines/cv/image_to_image_translation_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import io import os.path as osp import sys diff --git a/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py b/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py index 6704e4c0..3fffc546 100644 --- a/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py +++ b/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py @@ -60,9 +60,12 @@ class MovieSceneSegmentationPipeline(Pipeline): def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: data = {'input_video_pth': self.input_video_pth, 'feat': inputs} - video_num, meta_lst = self.model.postprocess(data) + scene_num, scene_meta_lst, shot_num, shot_meta_lst = self.model.postprocess( + data) result = { - OutputKeys.SPLIT_VIDEO_NUM: video_num, - OutputKeys.SPLIT_META_LIST: meta_lst + OutputKeys.SHOT_NUM: shot_num, + OutputKeys.SHOT_META_LIST: shot_meta_lst, + OutputKeys.SCENE_NUM: scene_num, + OutputKeys.SCENE_META_LIST: scene_meta_lst } return result diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py index 07231efa..292ec2c5 100644 --- a/modelscope/pipelines/cv/ocr_detection_pipeline.py +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/ocr_recognition_pipeline.py b/modelscope/pipelines/cv/ocr_recognition_pipeline.py index c20d020c..e81467a1 100644 --- a/modelscope/pipelines/cv/ocr_recognition_pipeline.py +++ b/modelscope/pipelines/cv/ocr_recognition_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import math import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py b/modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py index cf5e2fe1..6ecff7ef 100644 --- a/modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py +++ b/modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import torch import torch.nn as nn diff --git a/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py b/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py index d03ff405..2c2d5b00 100644 --- a/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py +++ b/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py @@ -1,3 +1,5 @@ +# Part of the implementation is borrowed and modified from SegLink, +# publicly available at https://github.com/bgshih/seglink import tensorflow as tf from . import ops, resnet18_v1, resnet_utils diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py index c2059107..c0e30616 100644 --- a/modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py @@ -1,11 +1,5 @@ -""" Contains various versions of ConvNext Networks. 
-ConvNext Networks (ConvNext) were proposed in: - Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell and Saining Xie - A ConvNet for the 2020s. CVPR 2022. -Compared to https://github.com/facebookresearch/ConvNeXt, -we obtain different ConvNext variants by changing the network depth, width, -feature number, and downsample ratio. -""" +# Part of the implementation is borrowed and modified from ConvNext, +# publicly available at https://github.com/facebookresearch/ConvNeXt import torch import torch.nn as nn import torch.nn.functional as F diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py index f54c0e78..555b1e42 100644 --- a/modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py @@ -1,7 +1,5 @@ -'''Referenced from rwightman's pytorch-image-models(timm). -Github: https://github.com/rwightman/pytorch-image-models -We use some modules and modify the parameters according to our network. -''' +# Part of the implementation is borrowed and modified from timm, +# publicly available at https://github.com/rwightman/pytorch-image-models import collections.abc import logging import math diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py index e7d96574..5ce3aeca 100644 --- a/modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py @@ -1,10 +1,5 @@ -""" Contains various versions of ViTSTR. -ViTSTR were proposed in: - Rowel Atienza - Vision transformer for fast and efficient scene text recognition. ICDAR 2021. -Compared to https://github.com/roatienza/deep-text-recognition-benchmark, -we obtain different ViTSTR variants by changing the network patch_size and in_chans. -""" +# Part of the implementation is borrowed and modified from ViTSTR, +# publicly available at https://github.com/roatienza/deep-text-recognition-benchmark from __future__ import absolute_import, division, print_function import logging from copy import deepcopy diff --git a/modelscope/pipelines/cv/ocr_utils/ops.py b/modelscope/pipelines/cv/ocr_utils/ops.py index 09807b10..a36838a6 100644 --- a/modelscope/pipelines/cv/ocr_utils/ops.py +++ b/modelscope/pipelines/cv/ocr_utils/ops.py @@ -1,3 +1,5 @@ +# Part of the implementation is borrowed and modified from SegLink, +# publicly available at https://github.com/bgshih/seglink import math import os import shutil diff --git a/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py b/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py index 7930c5a3..85f9faca 100644 --- a/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py +++ b/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py @@ -1,3 +1,17 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== """Contains definitions for the original form of Residual Networks. The 'v1' residual networks (ResNets) implemented in this module were proposed by: diff --git a/modelscope/pipelines/cv/ocr_utils/resnet_utils.py b/modelscope/pipelines/cv/ocr_utils/resnet_utils.py index 0a9af224..2ccbd038 100644 --- a/modelscope/pipelines/cv/ocr_utils/resnet_utils.py +++ b/modelscope/pipelines/cv/ocr_utils/resnet_utils.py @@ -1,3 +1,17 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== """Contains building blocks for various versions of Residual Networks. Residual networks (ResNets) were proposed in: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun diff --git a/modelscope/pipelines/cv/ocr_utils/utils.py b/modelscope/pipelines/cv/ocr_utils/utils.py index be8e3371..1d0fb297 100644 --- a/modelscope/pipelines/cv/ocr_utils/utils.py +++ b/modelscope/pipelines/cv/ocr_utils/utils.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import cv2 import numpy as np diff --git a/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py b/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py index 2614983b..0164a998 100644 --- a/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py +++ b/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
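The movie-scene-segmentation postprocess changed earlier in this section now reports shots and scenes separately instead of a single split count. A hedged sketch of consuming that layout (the default model comes from the task registration in builder.py; the meta schema is whatever the model's postprocess emits):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    segmenter = pipeline(Tasks.movie_scene_segmentation)
    res = segmenter('movie.mp4')                 # placeholder video path

    print('shots :', res[OutputKeys.SHOT_NUM])
    print('scenes:', res[OutputKeys.SCENE_NUM])
    for scene_meta in res[OutputKeys.SCENE_META_LIST]:
        print(scene_meta)                        # per-scene boundaries as defined by the model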
import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/product_segmentation_pipeline.py b/modelscope/pipelines/cv/product_segmentation_pipeline.py index 244b01d7..3b1b2381 100644 --- a/modelscope/pipelines/cv/product_segmentation_pipeline.py +++ b/modelscope/pipelines/cv/product_segmentation_pipeline.py @@ -2,11 +2,14 @@ from typing import Any, Dict +import numpy as np + from modelscope.metainfo import Pipelines from modelscope.models.cv.product_segmentation import seg_infer from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -28,12 +31,13 @@ class F3NetForProductSegmentationPipeline(Pipeline): logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: - return input + img = LoadImage.convert_to_ndarray(input['input_path']) + img = img.astype(np.float32) + return img def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - mask = seg_infer.inference(self.model, self.device, - input['input_path']) + mask = seg_infer.inference(self.model, self.device, input) return {OutputKeys.MASKS: mask} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py b/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py new file mode 100644 index 00000000..073fad66 --- /dev/null +++ b/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py @@ -0,0 +1,61 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict, List, Union + +import cv2 +import json +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.realtime_object_detection import \ + RealtimeVideoDetector +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Model, Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_object_detection, + module_name=Pipelines.realtime_video_object_detection) +class RealtimeVideoObjectDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + super().__init__(model=model, **kwargs) + self.model = RealtimeVideoDetector(model) + + def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: + return input + + def forward(self, input: Input) -> Dict[Tensor, Dict[str, np.ndarray]]: + self.video_path = input + # Processing the whole video and return results for each frame + forward_output = self.model.inference_video(self.video_path) + return {'forward_output': forward_output} + + def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]], + **kwargs) -> str: + forward_output = input['forward_output'] + + scores, boxes, labels, timestamps = [], [], [], [] + for result in forward_output: + box, score, label, timestamp = result + scores.append(score) + boxes.append(box) + labels.append(label) + timestamps.append(timestamp) + + return { + OutputKeys.BOXES: boxes, + OutputKeys.SCORES: scores, + OutputKeys.LABELS: 
labels, + OutputKeys.TIMESTAMPS: timestamps, + } diff --git a/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py b/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py new file mode 100644 index 00000000..d264b386 --- /dev/null +++ b/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py @@ -0,0 +1,193 @@ +# The implementation here is modified based on MTTR, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/mttr2021/MTTR +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional as F +from einops import rearrange +from moviepy.editor import AudioFileClip, ImageSequenceClip, VideoFileClip +from PIL import Image, ImageDraw, ImageFont, ImageOps +from tqdm import tqdm + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.referring_video_object_segmentation, + module_name=Pipelines.referring_video_object_segmentation) +class ReferringVideoObjectSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """use `model` to create a referring video object segmentation pipeline for prediction + + Args: + model: model id on modelscope hub + """ + _device = kwargs.pop('device', 'gpu') + if torch.cuda.is_available() and _device == 'gpu': + self.device = 'gpu' + else: + self.device = 'cpu' + super().__init__(model=model, device=self.device, **kwargs) + + logger.info('Load model done!') + + def preprocess(self, input: Input) -> Dict[str, Any]: + """ + + Args: + input: path of the input video + + """ + assert isinstance(input, tuple) and len( + input + ) == 4, 'error - input type must be tuple and input length must be 4' + self.input_video_pth, text_queries, start_pt, end_pt = input + + assert 0 < end_pt - start_pt <= 10, 'error - the subclip length must be 0-10 seconds long' + assert 1 <= len( + text_queries) <= 2, 'error - 1-2 input text queries are expected' + + # extract the relevant subclip: + self.input_clip_pth = 'input_clip.mp4' + with VideoFileClip(self.input_video_pth) as video: + subclip = video.subclip(start_pt, end_pt) + subclip.write_videofile(self.input_clip_pth) + + self.window_length = 24 # length of window during inference + self.window_overlap = 6 # overlap (in frames) between consecutive windows + + self.video, audio, self.meta = torchvision.io.read_video( + filename=self.input_clip_pth) + self.video = rearrange(self.video, 't h w c -> t c h w') + + input_video = F.resize(self.video, size=360, max_size=640) + if self.device_name == 'gpu': + input_video = input_video.cuda() + + input_video = input_video.to(torch.float).div_(255) + input_video = F.normalize( + input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + video_metadata = { + 'resized_frame_size': input_video.shape[-2:], + 'original_frame_size': self.video.shape[-2:] + } + + # partition the clip into overlapping windows of frames: + windows = [ + input_video[i:i + self.window_length] + for i in range(0, len(input_video), self.window_length + - self.window_overlap) + ] + # clean up the text queries: + self.text_queries = [' '.join(q.lower().split()) for q in text_queries] + + result = { + 
'text_queries': self.text_queries, + 'windows': windows, + 'video_metadata': video_metadata + } + + return result + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + pred_masks_per_query = [] + t, _, h, w = self.video.shape + for text_query in tqdm(input['text_queries'], desc='text queries'): + pred_masks = torch.zeros(size=(t, 1, h, w)) + for i, window in enumerate( + tqdm(input['windows'], desc='windows')): + + window_masks = self.model.inference( + window=window, + text_query=text_query, + metadata=input['video_metadata']) + + win_start_idx = i * ( + self.window_length - self.window_overlap) + pred_masks[win_start_idx:win_start_idx + + self.window_length] = window_masks + pred_masks_per_query.append(pred_masks) + return pred_masks_per_query + + def postprocess(self, inputs) -> Dict[str, Any]: + if self.model.cfg.pipeline.save_masked_video: + # RGB colors for instance masks: + light_blue = (41, 171, 226) + purple = (237, 30, 121) + dark_green = (35, 161, 90) + orange = (255, 148, 59) + colors = np.array([light_blue, purple, dark_green, orange]) + + # width (in pixels) of the black strip above the video on which the text queries will be displayed: + text_border_height_per_query = 36 + + video_np = rearrange(self.video, + 't c h w -> t h w c').numpy() / 255.0 + + # del video + pred_masks_per_frame = rearrange( + torch.stack(inputs), 'q t 1 h w -> t q h w').numpy() + masked_video = [] + for vid_frame, frame_masks in tqdm( + zip(video_np, pred_masks_per_frame), + total=len(video_np), + desc='applying masks...'): + # apply the masks: + for inst_mask, color in zip(frame_masks, colors): + vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0) + vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8)) + # visualize the text queries: + vid_frame = ImageOps.expand( + vid_frame, + border=(0, len(self.text_queries) + * text_border_height_per_query, 0, 0)) + W, H = vid_frame.size + draw = ImageDraw.Draw(vid_frame) + font = ImageFont.truetype(font='DejaVuSansMono.ttf', size=30) + for i, (text_query, color) in enumerate( + zip(self.text_queries, colors), start=1): + w, h = draw.textsize(text_query, font=font) + draw.text(((W - w) / 2, + (text_border_height_per_query * i) - h - 3), + text_query, + fill=tuple(color) + (255, ), + font=font) + masked_video.append(np.array(vid_frame)) + print(type(vid_frame)) + print(type(masked_video[0])) + print(masked_video[0].shape) + # generate and save the output clip: + + assert self.model.cfg.pipeline.output_path + output_clip_path = self.model.cfg.pipeline.output_path + clip = ImageSequenceClip( + sequence=masked_video, fps=self.meta['video_fps']) + clip = clip.set_audio(AudioFileClip(self.input_clip_pth)) + clip.write_videofile( + output_clip_path, fps=self.meta['video_fps'], audio=True) + del masked_video + + result = {OutputKeys.MASKS: inputs} + return result + + +def apply_mask(image, mask, color, transparency=0.7): + mask = mask[..., np.newaxis].repeat(repeats=3, axis=2) + mask = mask * transparency + color_matrix = np.ones(image.shape, dtype=np.float) * color + out_image = color_matrix * mask + image * (1.0 - mask) + return out_image diff --git a/modelscope/pipelines/cv/tinynas_detection_pipeline.py b/modelscope/pipelines/cv/tinynas_detection_pipeline.py index b2063629..d35d4d36 100644 --- a/modelscope/pipelines/cv/tinynas_detection_pipeline.py +++ b/modelscope/pipelines/cv/tinynas_detection_pipeline.py @@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline from 
modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import LoadImage from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import \ + show_image_object_detection_auto_result from modelscope.utils.logger import get_logger logger = get_logger() @@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline): bboxes, scores, labels = self.model.postprocess(inputs['data']) if bboxes is None: - return None - outputs = { - OutputKeys.SCORES: scores, - OutputKeys.LABELS: labels, - OutputKeys.BOXES: bboxes - } + outputs = { + OutputKeys.SCORES: [], + OutputKeys.LABELS: [], + OutputKeys.BOXES: [] + } + else: + outputs = { + OutputKeys.SCORES: scores, + OutputKeys.LABELS: labels, + OutputKeys.BOXES: bboxes + } return outputs + + def show_result(self, img_path, result, save_path=None): + show_image_object_detection_auto_result(img_path, result, save_path) diff --git a/modelscope/pipelines/cv/video_summarization_pipeline.py b/modelscope/pipelines/cv/video_summarization_pipeline.py index 001780e1..e4fe206d 100644 --- a/modelscope/pipelines/cv/video_summarization_pipeline.py +++ b/modelscope/pipelines/cv/video_summarization_pipeline.py @@ -1,3 +1,6 @@ +# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM + import os.path as osp from typing import Any, Dict @@ -7,7 +10,8 @@ import torch from tqdm import tqdm from modelscope.metainfo import Pipelines -from modelscope.models.cv.video_summarization import PGLVideoSummarization +from modelscope.models.cv.video_summarization import (PGLVideoSummarization, + summary_format) from modelscope.models.cv.video_summarization.base_model import bvlc_googlenet from modelscope.models.cv.video_summarization.summarizer import ( generate_summary, get_change_points) @@ -56,6 +60,8 @@ class VideoSummarizationPipeline(Pipeline): frames = [] picks = [] cap = cv2.VideoCapture(input) + self.fps = cap.get(cv2.CAP_PROP_FPS) + self.frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) frame_idx = 0 while (cap.isOpened()): ret, frame = cap.read() @@ -88,7 +94,9 @@ class VideoSummarizationPipeline(Pipeline): summary = self.inference(frame_features, input['n_frame'], input['picks'], change_points) - return {OutputKeys.OUTPUT: summary} + output = summary_format(summary, self.fps) + + return {OutputKeys.OUTPUT: output} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs diff --git a/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py index 76011be0..d3f15c23 100644 --- a/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py +++ b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py @@ -11,6 +11,8 @@ from modelscope.utils.logger import get_logger logger = get_logger() +@PIPELINES.register_module( + Tasks.image_text_retrieval, module_name=Pipelines.multi_modal_embedding) @PIPELINES.register_module( Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding) class MultiModalEmbeddingPipeline(Pipeline): diff --git a/modelscope/pipelines/multi_modal/ocr_recognition_pipeline.py b/modelscope/pipelines/multi_modal/ocr_recognition_pipeline.py new file mode 100644 index 00000000..c61b38f3 --- /dev/null +++ b/modelscope/pipelines/multi_modal/ocr_recognition_pipeline.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
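The referring-video-object-segmentation pipeline added above runs inference over fixed-length frame windows that overlap by a few frames, then writes each window's masks back into a full-length tensor (later windows overwrite the overlapped region). A standalone sketch of that partition-and-stitch step, using the same window length (24) and overlap (6):

    import torch

    window_length, window_overlap = 24, 6
    stride = window_length - window_overlap

    video = torch.zeros(60, 3, 360, 640)         # (t, c, h, w) dummy clip
    windows = [video[i:i + window_length] for i in range(0, len(video), stride)]

    pred_masks = torch.zeros(len(video), 1, 360, 640)
    for i, window in enumerate(windows):
        window_masks = torch.ones(len(window), 1, 360, 640)  # stand-in for model.inference(...)
        start = i * stride
        pred_masks[start:start + window_length] = window_masks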
+from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal import OfaForAllTasks +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.ocr_recognition, module_name=Pipelines.ofa_ocr_recognition) +class OcrRecognitionPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a ocr recognition pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None: + if isinstance(pipe_model, OfaForAllTasks): + preprocessor = OfaPreprocessor(pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return super().forward(inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 5267b5b2..7b726308 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -4,23 +4,26 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: + from .automatic_post_editing_pipeline import AutomaticPostEditingPipeline from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline from .table_question_answering_pipeline import TableQuestionAnsweringPipeline from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline from .dialog_modeling_pipeline import DialogModelingPipeline from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline from .document_segmentation_pipeline import DocumentSegmentationPipeline + from .fasttext_sequence_classification_pipeline import FasttextSequenceClassificationPipeline from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline from .feature_extraction_pipeline import FeatureExtractionPipeline from .fill_mask_pipeline import FillMaskPipeline - from .fill_mask_ponet_pipeline import FillMaskPonetPipeline from .information_extraction_pipeline import InformationExtractionPipeline - from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline - from .passage_ranking_pipeline import PassageRankingPipeline + from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline, \ + NamedEntityRecognitionThaiPipeline, \ + NamedEntityRecognitionVietPipeline + from .text_ranking_pipeline import TextRankingPipeline from .sentence_embedding_pipeline import SentenceEmbeddingPipeline - from .sequence_classification_pipeline import SequenceClassificationPipeline - from .summarization_pipeline import SummarizationPipeline from .text_classification_pipeline 
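A minimal, hedged usage sketch for the new OFA-based OcrRecognitionPipeline; the model id is a placeholder and a local image path of a cropped text line is assumed to be an accepted input.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ocr = pipeline(Tasks.ocr_recognition, model='<ofa-ocr-recognition-model-id>')
print(ocr('path/to/cropped_text_line.jpg'))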
import TextClassificationPipeline + from .summarization_pipeline import SummarizationPipeline + from .translation_quality_estimation_pipeline import TranslationQualityEstimationPipeline from .text_error_correction_pipeline import TextErrorCorrectionPipeline from .text_generation_pipeline import TextGenerationPipeline from .text2text_generation_pipeline import Text2TextGenerationPipeline @@ -28,38 +31,50 @@ if TYPE_CHECKING: from .translation_pipeline import TranslationPipeline from .word_segmentation_pipeline import WordSegmentationPipeline from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline + from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ + WordSegmentationThaiPipeline else: _import_structure = { + 'automatic_post_editing_pipeline': ['AutomaticPostEditingPipeline'], 'conversational_text_to_sql_pipeline': ['ConversationalTextToSqlPipeline'], - 'table_question_answering_pipeline': - ['TableQuestionAnsweringPipeline'], 'dialog_intent_prediction_pipeline': ['DialogIntentPredictionPipeline'], 'dialog_modeling_pipeline': ['DialogModelingPipeline'], 'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'], + 'domain_classification_pipeline': + ['FasttextSequenceClassificationPipeline'], 'document_segmentation_pipeline': ['DocumentSegmentationPipeline'], 'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'], 'feature_extraction_pipeline': ['FeatureExtractionPipeline'], 'fill_mask_pipeline': ['FillMaskPipeline'], - 'fill_mask_ponet_pipeline': ['FillMaskPoNetPipeline'], 'information_extraction_pipeline': ['InformationExtractionPipeline'], - 'named_entity_recognition_pipeline': - ['NamedEntityRecognitionPipeline'], - 'passage_ranking_pipeline': ['PassageRankingPipeline'], + 'named_entity_recognition_pipeline': [ + 'NamedEntityRecognitionPipeline', + 'NamedEntityRecognitionThaiPipeline', + 'NamedEntityRecognitionVietPipeline' + ], + 'text_ranking_pipeline': ['TextRankingPipeline'], 'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'], - 'sequence_classification_pipeline': ['SequenceClassificationPipeline'], 'summarization_pipeline': ['SummarizationPipeline'], + 'table_question_answering_pipeline': + ['TableQuestionAnsweringPipeline'], 'text_classification_pipeline': ['TextClassificationPipeline'], 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'], 'text_generation_pipeline': ['TextGenerationPipeline'], 'text2text_generation_pipeline': ['Text2TextGenerationPipeline'], 'token_classification_pipeline': ['TokenClassificationPipeline'], 'translation_pipeline': ['TranslationPipeline'], + 'translation_quality_estimation_pipeline': + ['TranslationQualityEstimationPipeline'], 'word_segmentation_pipeline': ['WordSegmentationPipeline'], 'zero_shot_classification_pipeline': ['ZeroShotClassificationPipeline'], + 'multilingual_word_segmentation_pipeline': [ + 'MultilingualWordSegmentationPipeline', + 'WordSegmentationThaiPipeline' + ], } import sys diff --git a/modelscope/pipelines/nlp/automatic_post_editing_pipeline.py b/modelscope/pipelines/nlp/automatic_post_editing_pipeline.py new file mode 100644 index 00000000..83968586 --- /dev/null +++ b/modelscope/pipelines/nlp/automatic_post_editing_pipeline.py @@ -0,0 +1,158 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import os +from html import unescape +from typing import Any, Dict + +import jieba +import numpy as np +import tensorflow as tf +from sacremoses import (MosesDetokenizer, MosesDetruecaser, + MosesPunctNormalizer, MosesTokenizer, MosesTruecaser) +from sentencepiece import SentencePieceProcessor +from tensorflow.contrib.seq2seq.python.ops import beam_search_ops + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config, ConfigFields +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + tf.disable_eager_execution() + +logger = get_logger() + +__all__ = ['AutomaticPostEditingPipeline'] + + +@PIPELINES.register_module( + Tasks.translation, module_name=Pipelines.automatic_post_editing) +class AutomaticPostEditingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """Build an automatic post editing pipeline with a model dir. + + @param model: Model path for saved pb file + """ + super().__init__(model=model, **kwargs) + export_dir = model + self.cfg = Config.from_file( + os.path.join(export_dir, ModelFile.CONFIGURATION)) + joint_vocab_file = os.path.join( + export_dir, self.cfg[ConfigFields.preprocessor]['vocab']) + self.vocab = dict([(w.strip(), i) for i, w in enumerate( + open(joint_vocab_file, 'r', encoding='utf8'))]) + self.vocab_reverse = dict([(i, w.strip()) for i, w in enumerate( + open(joint_vocab_file, 'r', encoding='utf8'))]) + self.unk_id = self.cfg[ConfigFields.preprocessor].get('unk_id', -1) + strip_unk = self.cfg.get(ConfigFields.postprocessor, + {}).get('strip_unk', True) + self.unk_token = '' if strip_unk else self.cfg.get( + ConfigFields.postprocessor, {}).get('unk_token', '') + if self.unk_id == -1: + self.unk_id = len(self.vocab) - 1 + tf.reset_default_graph() + tf_config = tf.ConfigProto(allow_soft_placement=True) + tf_config.gpu_options.allow_growth = True + self._session = tf.Session(config=tf_config) + tf.saved_model.loader.load( + self._session, [tf.python.saved_model.tag_constants.SERVING], + export_dir) + default_graph = tf.get_default_graph() + self.input_src_id_placeholder = default_graph.get_tensor_by_name( + 'Placeholder:0') + self.input_src_len_placeholder = default_graph.get_tensor_by_name( + 'Placeholder_1:0') + self.input_mt_id_placeholder = default_graph.get_tensor_by_name( + 'Placeholder_2:0') + self.input_mt_len_placeholder = default_graph.get_tensor_by_name( + 'Placeholder_3:0') + output_id_beam = default_graph.get_tensor_by_name( + 'enc2enc/decoder/transpose:0') + output_len_beam = default_graph.get_tensor_by_name( + 'enc2enc/decoder/Minimum:0') + output_id = tf.cast( + tf.map_fn(lambda x: x[0], output_id_beam), dtype=tf.int64) + output_len = tf.map_fn(lambda x: x[0], output_len_beam) + self.output = {'output_ids': output_id, 'output_lens': output_len} + init = tf.global_variables_initializer() + local_init = tf.local_variables_initializer() + self._session.run([init, local_init]) + tf.saved_model.loader.load( + self._session, [tf.python.saved_model.tag_constants.SERVING], + export_dir) + + # preprocess + self._src_lang = self.cfg[ConfigFields.preprocessor]['src_lang'] + self._tgt_lang = self.cfg[ConfigFields.preprocessor]['tgt_lang'] + tok_escape = self.cfg[ConfigFields.preprocessor].get( + 'tokenize_escape', False) + 
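The session setup in AutomaticPostEditingPipeline follows the usual TF1 SavedModel pattern: restore the graph into a session, then look input and output tensors up by name. A generic, self-contained sketch, with the export directory and tensor name as placeholders:

import tensorflow as tf

if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.disable_eager_execution()

config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# Restore the exported graph, then fetch tensors by their graph names:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], '/path/to/export_dir')
src_ids = tf.get_default_graph().get_tensor_by_name('Placeholder:0')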
src_tokenizer = MosesTokenizer(lang=self._src_lang) + mt_tokenizer = MosesTokenizer(lang=self._tgt_lang) + truecase_model = os.path.join( + export_dir, self.cfg[ConfigFields.preprocessor]['truecaser']) + truecaser = MosesTruecaser(load_from=truecase_model) + sp_model = os.path.join( + export_dir, self.cfg[ConfigFields.preprocessor]['sentencepiece']) + sp = SentencePieceProcessor() + sp.load(sp_model) + + self.src_preprocess = lambda x: ' '.join( + sp.encode_as_pieces( + truecaser.truecase( + src_tokenizer.tokenize( + x, return_str=True, escape=tok_escape), + return_str=True))) + self.mt_preprocess = lambda x: ' '.join( + sp.encode_as_pieces( + truecaser.truecase( + mt_tokenizer.tokenize( + x, return_str=True, escape=tok_escape), + return_str=True))) + + # post process, de-bpe, de-truecase, detok + detruecaser = MosesDetruecaser() + detokenizer = MosesDetokenizer(lang=self._tgt_lang) + self.postprocess_fun = lambda x: detokenizer.detokenize( + detruecaser.detruecase( + x.replace(' ▁', '@@').replace(' ', '').replace('@@', ' '). + strip()[1:], + return_str=True).split()) + + def preprocess(self, input: str) -> Dict[str, Any]: + src, mt = input.split('\005', 1) + src_sp, mt_sp = self.src_preprocess(src), self.mt_preprocess(mt) + input_src_ids = np.array( + [[self.vocab.get(w, self.unk_id) for w in src_sp.strip().split()]]) + input_mt_ids = np.array( + [[self.vocab.get(w, self.unk_id) for w in mt_sp.strip().split()]]) + input_src_lens = [len(x) for x in input_src_ids] + input_mt_lens = [len(x) for x in input_mt_ids] + feed_dict = { + self.input_src_id_placeholder: input_src_ids, + self.input_mt_id_placeholder: input_mt_ids, + self.input_src_len_placeholder: input_src_lens, + self.input_mt_len_placeholder: input_mt_lens + } + return feed_dict + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with self._session.as_default(): + sess_outputs = self._session.run(self.output, feed_dict=input) + return sess_outputs + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output_ids, output_len = inputs['output_ids'][0], inputs[ + 'output_lens'][0] + output_ids = output_ids[:output_len - 1] # -1 for + output_tokens = ' '.join([ + self.vocab_reverse.get(wid, self.unk_token) for wid in output_ids + ]) + post_editing_output = self.postprocess_fun(output_tokens) + result = {OutputKeys.TRANSLATION: post_editing_output} + return result diff --git a/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py b/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py index c46e8c81..48df0c40 100644 --- a/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py +++ b/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py @@ -11,15 +11,13 @@ from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import ConversationalTextToSqlPreprocessor -from modelscope.preprocessors.star.fields import (SubPreprocessor, - process_tables) from modelscope.utils.constant import Tasks __all__ = ['ConversationalTextToSqlPipeline'] @PIPELINES.register_module( - Tasks.conversational_text_to_sql, + Tasks.table_question_answering, module_name=Pipelines.conversational_text_to_sql) class ConversationalTextToSqlPipeline(Pipeline): @@ -39,17 +37,6 @@ class ConversationalTextToSqlPipeline(Pipeline): if preprocessor is None: preprocessor = ConversationalTextToSqlPreprocessor(model.model_dir) - preprocessor.device = 'cuda' if \ - ('device' not in kwargs or kwargs['device'] 
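A hedged usage sketch for the automatic post-editing pipeline: preprocess() above splits the raw input on '\x05', so the source sentence and the MT hypothesis are packed into one string. The model id is a placeholder; which pipeline class is instantiated for Tasks.translation still depends on the model's configuration.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ape = pipeline(Tasks.translation, model='<automatic-post-editing-model-id>')
# source sentence + '\x05' + machine-translated hypothesis:
print(ape('This is a test sentence .\x05Dies ist ein Test Satz .')['translation'])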
== 'gpu') \ - and torch.cuda.is_available() else 'cpu' - use_device = True if preprocessor.device == 'cuda' else False - preprocessor.processor = \ - SubPreprocessor(model_dir=model.model_dir, - db_content=True, - use_gpu=use_device) - preprocessor.output_tables = \ - process_tables(preprocessor.processor, - preprocessor.tables) super().__init__(model=model, preprocessor=preprocessor, **kwargs) def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: @@ -62,7 +49,7 @@ class ConversationalTextToSqlPipeline(Pipeline): Dict[str, str]: the prediction results """ sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db']) - result = {OutputKeys.TEXT: sql} + result = {OutputKeys.OUTPUT: {OutputKeys.TEXT: sql}} return result def _collate_fn(self, data): diff --git a/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py b/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py index 79d32ace..9520c06f 100644 --- a/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Union from modelscope.metainfo import Pipelines from modelscope.models import Model -from modelscope.models.nlp import SpaceForDialogStateTracking +from modelscope.models.nlp import SpaceForDST from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES @@ -20,7 +20,7 @@ __all__ = ['DialogStateTrackingPipeline'] class DialogStateTrackingPipeline(Pipeline): def __init__(self, - model: Union[SpaceForDialogStateTracking, str], + model: Union[SpaceForDST, str], preprocessor: DialogStateTrackingPreprocessor = None, **kwargs): """use `model` and `preprocessor` to create a dialog state tracking pipeline for @@ -33,8 +33,7 @@ class DialogStateTrackingPipeline(Pipeline): """ model = model if isinstance( - model, - SpaceForDialogStateTracking) else Model.from_pretrained(model) + model, SpaceForDST) else Model.from_pretrained(model) self.model = model if preprocessor is None: preprocessor = DialogStateTrackingPreprocessor(model.model_dir) diff --git a/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py new file mode 100644 index 00000000..325d3303 --- /dev/null +++ b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py @@ -0,0 +1,54 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.nlp.gpt3.distributed_gpt3 import DistributedGPT3 +from modelscope.pipelines.base import DistributedPipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import TextGenerationJiebaPreprocessor +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.text_generation, module_name=Pipelines.gpt3_generation) +class DistributedGPT3Pipeline(DistributedPipeline): + """This class is used to instantiate the gpt3 model. 
+ """ + + model = None + + def __init__(self, model, preprocessor=None, **kwargs): + if preprocessor is None: + preprocessor = TextGenerationJiebaPreprocessor(model) + super().__init__(model, preprocessor=preprocessor, **kwargs) + assert hasattr(preprocessor, 'tokenizer') + + @classmethod + def _instantiate_one(cls, rank, model_dir, **kwargs): + cls.model = DistributedGPT3(model_dir, rank, **kwargs) + cls.model.eval() + + @classmethod + def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: + tokens = inputs['inputs']['input_ids'].cuda( + torch.cuda.current_device()) + return cls.model.generate(tokens) + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + from modelscope.outputs import OutputKeys + return { + OutputKeys.TEXT: + self.preprocessor.tokenizer.detokenize(inputs[0].tolist()) + } diff --git a/modelscope/pipelines/nlp/distributed_plug_pipeline.py b/modelscope/pipelines/nlp/distributed_plug_pipeline.py index e5c05e86..8499f7ff 100644 --- a/modelscope/pipelines/nlp/distributed_plug_pipeline.py +++ b/modelscope/pipelines/nlp/distributed_plug_pipeline.py @@ -27,7 +27,8 @@ class DistributedPlugPipeline(DistributedPipeline): **kwargs): """Create a plug pipeline instance. - @param model: The model_id of plug(damo/nlp_plug_text-generation_27B). + Args: + model: The model_id of plug(damo/nlp_plug_text-generation_27B). The default path to damo/nlp_plug_text-generation_27B can be obtained by function get_cache_dir("damo/nlp_plug_text-generation_27B"), the model should be downloaded to this path before calling this class by model_id. @@ -52,11 +53,11 @@ class DistributedPlugPipeline(DistributedPipeline): |_ mp_rank_05_model_states.pt |_ mp_rank_06_model_states.pt |_ mp_rank_07_model_states.pt - @param preprocessor: The optional preprocessor, if not passed in, a TextGenerationPreprocessor will + preprocessor: The optional preprocessor, if not passed in, a TextGenerationPreprocessor will be used as default. - @param first_sequence: The first_sequence key name if the input format is a dict. - @param kwargs: - sequence_length: The input sequence_length. + first_sequence: The first_sequence key name if the input format is a dict. + kwargs: + sequence_length: The input sequence_length. 
""" if preprocessor is None: preprocessor = TextGenerationPreprocessor( diff --git a/modelscope/pipelines/nlp/faq_question_answering_pipeline.py b/modelscope/pipelines/nlp/faq_question_answering_pipeline.py index 1d46d8fd..fd614e91 100644 --- a/modelscope/pipelines/nlp/faq_question_answering_pipeline.py +++ b/modelscope/pipelines/nlp/faq_question_answering_pipeline.py @@ -2,15 +2,12 @@ from typing import Any, Dict, Union -import torch - from modelscope.metainfo import Pipelines from modelscope.models import Model -from modelscope.models.nlp import SbertForFaqQuestionAnswering from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor +from modelscope.preprocessors import Preprocessor from modelscope.utils.constant import Tasks __all__ = ['FaqQuestionAnsweringPipeline'] @@ -21,19 +18,19 @@ __all__ = ['FaqQuestionAnsweringPipeline'] class FaqQuestionAnsweringPipeline(Pipeline): def __init__(self, - model: Union[str, SbertForFaqQuestionAnswering], - preprocessor: FaqQuestionAnsweringPreprocessor = None, + model: Union[str, Model], + preprocessor: Preprocessor = None, **kwargs): - model = model if isinstance( - model, - SbertForFaqQuestionAnswering) else Model.from_pretrained(model) - model.eval() + model = Model.from_pretrained(model) if isinstance(model, + str) else model if preprocessor is None: - preprocessor = FaqQuestionAnsweringPreprocessor( + preprocessor = Preprocessor.from_pretrained( model.model_dir, **kwargs) - self.preprocessor = preprocessor - super(FaqQuestionAnsweringPipeline, self).__init__( - model=model, preprocessor=preprocessor, **kwargs) + if preprocessor is None: + from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor + preprocessor = FaqQuestionAnsweringPreprocessor( + model.model_dir, **kwargs) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) def _sanitize_parameters(self, **pipeline_parameters): return pipeline_parameters, pipeline_parameters, pipeline_parameters @@ -46,8 +43,7 @@ class FaqQuestionAnsweringPipeline(Pipeline): def forward(self, inputs: [list, Dict[str, Any]], **forward_params) -> Dict[str, Any]: - with torch.no_grad(): - return self.model(inputs) + return self.model(inputs) def postprocess(self, inputs: [list, Dict[str, Any]], **postprocess_params) -> Dict[str, Any]: diff --git a/modelscope/pipelines/nlp/fasttext_sequence_classification_pipeline.py b/modelscope/pipelines/nlp/fasttext_sequence_classification_pipeline.py new file mode 100644 index 00000000..f10af88f --- /dev/null +++ b/modelscope/pipelines/nlp/fasttext_sequence_classification_pipeline.py @@ -0,0 +1,69 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import os +from typing import Any, Dict, Union + +import numpy as np +import sentencepiece +from fasttext import load_model +from fasttext.FastText import _FastText + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['FasttextSequenceClassificationPipeline'] + + +def sentencepiece_tokenize(sp_model, sent): + tokens = [] + for t in sp_model.EncodeAsPieces(sent): + s = t.strip() + if s: + tokens.append(s) + return ' '.join(tokens) + + +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.domain_classification) +class FasttextSequenceClassificationPipeline(Pipeline): + + def __init__(self, model: Union[str, _FastText], **kwargs): + """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + + Args: + model: a model directory including model.bin and spm.model + preprocessor (SequenceClassificationPreprocessor): a preprocessor instance + """ + super().__init__(model=model) + model_file = os.path.join(model, ModelFile.TORCH_MODEL_BIN_FILE) + spm_file = os.path.join(model, 'sentencepiece.model') + assert os.path.isdir(model) and os.path.exists(model_file) and os.path.exists(spm_file), \ + '`model` should be a directory contains `model.bin` and `sentencepiece.model`' + self.model = load_model(model_file) + self.spm = sentencepiece.SentencePieceProcessor() + self.spm.Load(spm_file) + + def preprocess(self, inputs: str) -> Dict[str, Any]: + text = inputs.strip() + text_sp = sentencepiece_tokenize(self.spm, text) + return {'text_sp': text_sp, 'text': text} + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + topk = inputs.get('topk', -1) + label, probs = self.model.predict(inputs['text_sp'], k=topk) + label = [x.replace('__label__', '') for x in label] + result = { + OutputKeys.LABEL: label[0], + OutputKeys.SCORE: probs[0], + OutputKeys.LABELS: label, + OutputKeys.SCORES: probs + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/nlp/feature_extraction_pipeline.py b/modelscope/pipelines/nlp/feature_extraction_pipeline.py index 3af0c28d..e94e4337 100644 --- a/modelscope/pipelines/nlp/feature_extraction_pipeline.py +++ b/modelscope/pipelines/nlp/feature_extraction_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os from typing import Any, Dict, Optional, Union diff --git a/modelscope/pipelines/nlp/fill_mask_pipeline.py b/modelscope/pipelines/nlp/fill_mask_pipeline.py index 3d515e2d..0f3446e6 100644 --- a/modelscope/pipelines/nlp/fill_mask_pipeline.py +++ b/modelscope/pipelines/nlp/fill_mask_pipeline.py @@ -1,145 +1,103 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
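A hedged usage sketch for FasttextSequenceClassificationPipeline: per the assertion in __init__, `model` must be a local directory containing the fastText binary saved as pytorch_model.bin plus sentencepiece.model; the directory path below is a placeholder.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

clf = pipeline(Tasks.text_classification, model='/path/to/domain_classification_model_dir')
print(clf('The central bank raised interest rates again this quarter.'))
# -> {'label': ..., 'score': ..., 'labels': [...], 'scores': [...]}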
-import os from typing import Any, Dict, Optional, Union -import torch +import numpy as np from modelscope.metainfo import Pipelines from modelscope.models import Model from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline, Tensor from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import NLPPreprocessor, Preprocessor -from modelscope.utils.config import Config -from modelscope.utils.constant import ModelFile, Tasks +from modelscope.preprocessors import Preprocessor +from modelscope.utils.constant import Tasks __all__ = ['FillMaskPipeline'] -_type_map = { - 'veco': 'roberta', - 'sbert': 'bert', -} @PIPELINES.register_module(Tasks.fill_mask, module_name=Pipelines.fill_mask) +@PIPELINES.register_module( + Tasks.fill_mask, module_name=Pipelines.fill_mask_ponet) class FillMaskPipeline(Pipeline): def __init__(self, model: Union[Model, str], preprocessor: Optional[Preprocessor] = None, - first_sequence='sentence', + first_sequence: str = 'sentence', **kwargs): - """Use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction + """The inference pipeline for all the fill mask sub-tasks. Args: - model (str or Model): Supply either a local model dir which supported mlm task, or a - mlm model id from the model hub, or a torch model instance. - preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for - the model if supplied. - first_sequence: The key to read the sentence in. - sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. - - NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' + model (`str` or `Model` or module instance): A model instance or a model local dir + or a model id in the model hub. + preprocessor (`Preprocessor`, `optional`): A Preprocessor instance. + first_sequence (`str`, `optional`): The key to read the sentence in. + sequence_length (`int`, `optional`): Max sequence length in the user's custom scenario, default 128. + + NOTE1: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' param will have no effect. - Example: + Example1: >>> from modelscope.pipelines import pipeline >>> pipeline_ins = pipeline('fill-mask', model='damo/nlp_structbert_fill-mask_english-large') >>> input = 'Everything in [MASK] you call reality is really [MASK] a reflection of your [MASK].' >>> print(pipeline_ins(input)) + Example2: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model='damo/nlp_ponet_fill-mask_english-base') + >>> input = 'Everything in [MASK] you call reality is really [MASK] a reflection of your [MASK].' + >>> print(pipeline_ins(input)) NOTE2: Please pay attention to the model's special tokens. If bert based model(bert, structbert, etc.) is used, the mask token is '[MASK]'. If the xlm-roberta(xlm-roberta, veco, etc.) based model is used, the mask token is ''. To view other examples plese check the tests/pipelines/test_fill_mask.py. 
""" - fill_mask_model = model if isinstance( - model, Model) else Model.from_pretrained(model) + + fill_mask_model = Model.from_pretrained(model) if isinstance( + model, str) else model if preprocessor is None: - preprocessor = NLPPreprocessor( + preprocessor = Preprocessor.from_pretrained( fill_mask_model.model_dir, first_sequence=first_sequence, second_sequence=None, sequence_length=kwargs.pop('sequence_length', 128)) fill_mask_model.eval() + assert hasattr( + preprocessor, 'mask_id' + ), 'The input preprocessor should have the mask_id attribute.' super().__init__( model=fill_mask_model, preprocessor=preprocessor, **kwargs) - self.preprocessor = preprocessor - self.config = Config.from_file( - os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION)) - self.tokenizer = preprocessor.tokenizer - self.mask_id = {'roberta': 250001, 'bert': 103, 'deberta_v2': 4} - - self.rep_map = { - 'bert': { - '[unused0]': '', - '[PAD]': '', - '[unused1]': '', - r' +': ' ', - '[SEP]': '', - '[unused2]': '', - '[CLS]': '', - '[UNK]': '' - }, - 'roberta': { - r' +': ' ', - '': '', - '': '', - '': '', - '': '', - '': ' ' - }, - 'deberta_v2': { - '[PAD]': '', - r' +': ' ', - '[SEP]': '', - '[CLS]': '', - '[UNK]': '' - }, - } - def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: - with torch.no_grad(): - return self.model(**inputs, **forward_params) + return self.model(**inputs, **forward_params) def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: """process the prediction results Args: - inputs (Dict[str, Any]): _description_ - + inputs (Dict[str, Any]): The model outputs. + The output should follow some rules: + 1. Values can be retrieved by keys(dict-like, or the __getitem__ method is overriden) + 2. 'logits' and 'input_ids' key exists. + Models in modelscope will return the output dataclass `modelscope.outputs.FillMaskModelOutput`. Returns: Dict[str, str]: the prediction results """ - import numpy as np logits = inputs[OutputKeys.LOGITS].detach().cpu().numpy() input_ids = inputs[OutputKeys.INPUT_IDS].detach().cpu().numpy() pred_ids = np.argmax(logits, axis=-1) - if hasattr(self.model.config, 'backbone'): - model_type = self.model.config.backbone.type - else: - model_type = self.model.config.model_type - process_type = model_type if model_type in self.mask_id else _type_map[ - model_type] - rst_ids = np.where(input_ids == self.mask_id[process_type], pred_ids, + rst_ids = np.where(input_ids == self.preprocessor.mask_id, pred_ids, input_ids) - def rep_tokens(string, rep_map): - for k, v in rep_map.items(): - string = string.replace(k, v) - return string.strip() - pred_strings = [] for ids in rst_ids: # batch - if 'language' in self.config.model and self.config.model.language == 'zh': - pred_string = self.tokenizer.convert_ids_to_tokens(ids) - pred_string = ''.join(pred_string) - else: - pred_string = self.tokenizer.decode(ids) - pred_string = rep_tokens(pred_string, self.rep_map[process_type]) + pred_string = self.preprocessor.decode( + ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=True) pred_strings.append(pred_string) return {OutputKeys.TEXT: pred_strings} diff --git a/modelscope/pipelines/nlp/fill_mask_ponet_pipeline.py b/modelscope/pipelines/nlp/fill_mask_ponet_pipeline.py deleted file mode 100644 index 9770fc38..00000000 --- a/modelscope/pipelines/nlp/fill_mask_ponet_pipeline.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
- -import os -from typing import Any, Dict, Optional, Union - -import torch - -from modelscope.metainfo import Pipelines -from modelscope.models import Model -from modelscope.outputs import OutputKeys -from modelscope.pipelines.base import Pipeline, Tensor -from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import FillMaskPoNetPreprocessor, Preprocessor -from modelscope.utils.config import Config -from modelscope.utils.constant import ModelFile, Tasks - -__all__ = ['FillMaskPonetPipeline'] -_type_map = {'ponet': 'bert'} - - -@PIPELINES.register_module( - Tasks.fill_mask, module_name=Pipelines.fill_mask_ponet) -class FillMaskPonetPipeline(Pipeline): - - def __init__(self, - model: Union[Model, str], - preprocessor: Optional[Preprocessor] = None, - first_sequence='sentence', - **kwargs): - """Use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction - - Args: - model (str or Model): Supply either a local model dir which supported fill-mask task, - or a fill-mask model id from the model hub, or a torch model instance. - preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for - the model if supplied. - first_sequence: The key to read the sentence in. - - NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' - param will have no effect. - - Example: - >>> from modelscope.pipelines import pipeline - >>> pipeline_ins = pipeline( - 'fill-mask', model='damo/nlp_ponet_fill-mask_english-base') - >>> input = 'Everything in [MASK] you call reality is really [MASK] a reflection of your [MASK].' - >>> print(pipeline_ins(input)) - - NOTE2: Please pay attention to the model's special tokens. - If bert based model(bert, structbert, etc.) is used, the mask token is '[MASK]'. - If the xlm-roberta(xlm-roberta, veco, etc.) based model is used, the mask token is ''. - To view other examples plese check the tests/pipelines/test_fill_mask.py. 
- """ - fill_mask_model = model if isinstance( - model, Model) else Model.from_pretrained(model) - - self.config = Config.from_file( - os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION)) - - if preprocessor is None: - preprocessor = FillMaskPoNetPreprocessor( - fill_mask_model.model_dir, - first_sequence=first_sequence, - second_sequence=None, - sequence_length=kwargs.pop('sequence_length', 512)) - - fill_mask_model.eval() - super().__init__( - model=fill_mask_model, preprocessor=preprocessor, **kwargs) - - self.preprocessor = preprocessor - - self.tokenizer = preprocessor.tokenizer - self.mask_id = {'roberta': 250001, 'bert': 103} - - self.rep_map = { - 'bert': { - '[unused0]': '', - '[PAD]': '', - '[unused1]': '', - r' +': ' ', - '[SEP]': '', - '[unused2]': '', - '[CLS]': '', - '[UNK]': '' - }, - 'roberta': { - r' +': ' ', - '': '', - '': '', - '': '', - '': '', - '': ' ' - } - } - - def forward(self, inputs: Dict[str, Any], - **forward_params) -> Dict[str, Any]: - with torch.no_grad(): - return self.model(**inputs, **forward_params) - - def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: - """process the prediction results - - Args: - inputs (Dict[str, Any]): _description_ - - Returns: - Dict[str, str]: the prediction results - """ - import numpy as np - logits = inputs[OutputKeys.LOGITS].detach().cpu().numpy() - input_ids = inputs[OutputKeys.INPUT_IDS].detach().cpu().numpy() - pred_ids = np.argmax(logits, axis=-1) - model_type = self.model.config.model_type - process_type = model_type if model_type in self.mask_id else _type_map[ - model_type] - rst_ids = np.where(input_ids == self.mask_id[process_type], pred_ids, - input_ids) - - def rep_tokens(string, rep_map): - for k, v in rep_map.items(): - string = string.replace(k, v) - return string.strip() - - pred_strings = [] - for ids in rst_ids: # batch - if 'language' in self.config.model and self.config.model.language == 'zh': - pred_string = self.tokenizer.convert_ids_to_tokens(ids) - pred_string = ''.join(pred_string) - else: - pred_string = self.tokenizer.decode(ids) - pred_string = rep_tokens(pred_string, self.rep_map[process_type]) - pred_strings.append(pred_string) - - return {OutputKeys.TEXT: pred_strings} diff --git a/modelscope/pipelines/nlp/information_extraction_pipeline.py b/modelscope/pipelines/nlp/information_extraction_pipeline.py index 763e941c..8ac85f43 100644 --- a/modelscope/pipelines/nlp/information_extraction_pipeline.py +++ b/modelscope/pipelines/nlp/information_extraction_pipeline.py @@ -17,6 +17,8 @@ __all__ = ['InformationExtractionPipeline'] @PIPELINES.register_module( Tasks.information_extraction, module_name=Pipelines.relation_extraction) +@PIPELINES.register_module( + Tasks.relation_extraction, module_name=Pipelines.relation_extraction) class InformationExtractionPipeline(Pipeline): def __init__(self, diff --git a/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py b/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py new file mode 100644 index 00000000..56c3a041 --- /dev/null +++ b/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py @@ -0,0 +1,125 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (Preprocessor, + TokenClassificationPreprocessor, + WordSegmentationPreprocessorThai) +from modelscope.utils.constant import Tasks + +__all__ = [ + 'MultilingualWordSegmentationPipeline', 'WordSegmentationThaiPipeline' +] + + +@PIPELINES.register_module( + Tasks.word_segmentation, + module_name=Pipelines.multilingual_word_segmentation) +class MultilingualWordSegmentationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """Use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction + + Args: + model (str or Model): Supply either a local model dir which supported word segmentation task, or a + model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value. + + To view other examples plese check the tests/pipelines/test_multilingual_word_segmentation.py. + """ + + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = TokenClassificationPreprocessor( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.tokenizer = preprocessor.tokenizer + self.config = model.config + assert len(self.config.id2label) > 0 + self.id2label = self.config.id2label + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + text = inputs.pop(OutputKeys.TEXT) + with torch.no_grad(): + return { + **super().forward(inputs, **forward_params), OutputKeys.TEXT: + text + } + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + text = inputs['text'] + offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] + labels = [ + self.id2label[x] + for x in inputs['predictions'].squeeze(0).cpu().numpy() + ] + entities = [] + entity = {} + for label, offsets in zip(labels, offset_mapping): + if label[0] in 'BS': + if entity: + entity['span'] = text[entity['start']:entity['end']] + entities.append(entity) + entity = { + 'type': label[2:], + 'start': offsets[0], + 'end': offsets[1] + } + if label[0] in 'IES': + if entity: + entity['end'] = offsets[1] + if label[0] in 'ES': + if entity: + entity['span'] = text[entity['start']:entity['end']] + entities.append(entity) + entity = {} + if entity: + entity['span'] = text[entity['start']:entity['end']] + entities.append(entity) + + word_segments = [entity['span'] for entity in entities] + outputs = {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} + + return outputs + + +@PIPELINES.register_module( + Tasks.word_segmentation, module_name=Pipelines.word_segmentation_thai) +class WordSegmentationThaiPipeline(MultilingualWordSegmentationPipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + 
preprocessor = WordSegmentationPreprocessorThai( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + outputs = super().postprocess(inputs, **postprocess_params) + word_segments = outputs[OutputKeys.OUTPUT] + word_segments = [seg.replace(' ', '') for seg in word_segments] + + return {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} diff --git a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py index 7275feca..fdcf9e0f 100644 --- a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py +++ b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py @@ -9,11 +9,17 @@ from modelscope.models import Model from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import (Preprocessor, +from modelscope.preprocessors import (NERPreprocessorThai, NERPreprocessorViet, + Preprocessor, TokenClassificationPreprocessor) from modelscope.utils.constant import Tasks +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) -__all__ = ['NamedEntityRecognitionPipeline'] +__all__ = [ + 'NamedEntityRecognitionPipeline', 'NamedEntityRecognitionThaiPipeline', + 'NamedEntityRecognitionVietPipeline' +] @PIPELINES.register_module( @@ -59,37 +65,104 @@ class NamedEntityRecognitionPipeline(Pipeline): def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: + text = inputs.pop(OutputKeys.TEXT) with torch.no_grad(): - return super().forward(inputs, **forward_params) + return { + **self.model(**inputs, **forward_params), OutputKeys.TEXT: text + } def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): should be tensors from model + + Returns: + Dict[str, str]: the prediction results + """ text = inputs['text'] + if OutputKeys.PREDICTIONS not in inputs: + logits = inputs[OutputKeys.LOGITS] + predictions = torch.argmax(logits[0], dim=-1) + else: + predictions = inputs[OutputKeys.PREDICTIONS].squeeze( + 0).cpu().numpy() + predictions = torch_nested_numpify(torch_nested_detach(predictions)) offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] - labels = [self.id2label[x] for x in inputs['predicts']] - entities = [] - entity = {} + + labels = [self.id2label[x] for x in predictions] + chunks = [] + chunk = {} for label, offsets in zip(labels, offset_mapping): if label[0] in 'BS': - if entity: - entity['span'] = text[entity['start']:entity['end']] - entities.append(entity) - entity = { + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = { 'type': label[2:], 'start': offsets[0], 'end': offsets[1] } if label[0] in 'IES': - if entity: - entity['end'] = offsets[1] + if chunk: + chunk['end'] = offsets[1] + if label[0] in 'ES': - if entity: - entity['span'] = text[entity['start']:entity['end']] - entities.append(entity) - entity = {} - if entity: - entity['span'] = text[entity['start']:entity['end']] - entities.append(entity) - outputs = {OutputKeys.OUTPUT: entities} + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = {} + + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + 
chunks.append(chunk) + + # for cws output + if len(chunks) > 0 and chunks[0]['type'] == 'cws': + spans = [ + chunk['span'] for chunk in chunks if chunk['span'].strip() + ] + seg_result = ' '.join(spans) + outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} + # for ner outpus + else: + outputs = {OutputKeys.OUTPUT: chunks} return outputs + + +@PIPELINES.register_module( + Tasks.named_entity_recognition, + module_name=Pipelines.named_entity_recognition_thai) +class NamedEntityRecognitionThaiPipeline(NamedEntityRecognitionPipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = NERPreprocessorThai( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + +@PIPELINES.register_module( + Tasks.named_entity_recognition, + module_name=Pipelines.named_entity_recognition_viet) +class NamedEntityRecognitionVietPipeline(NamedEntityRecognitionPipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = NERPreprocessorViet( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) diff --git a/modelscope/pipelines/nlp/sentence_embedding_pipeline.py b/modelscope/pipelines/nlp/sentence_embedding_pipeline.py index 16dedb2e..cfa5c2f1 100644 --- a/modelscope/pipelines/nlp/sentence_embedding_pipeline.py +++ b/modelscope/pipelines/nlp/sentence_embedding_pipeline.py @@ -2,15 +2,14 @@ from typing import Any, Dict, Optional, Union -import torch +import numpy as np from modelscope.metainfo import Pipelines from modelscope.models import Model from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import (Preprocessor, - SentenceEmbeddingPreprocessor) +from modelscope.preprocessors import Preprocessor from modelscope.utils.constant import Tasks __all__ = ['SentenceEmbeddingPipeline'] @@ -33,20 +32,18 @@ class SentenceEmbeddingPipeline(Pipeline): the model if supplied. sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. 
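The multilingual word-segmentation and NER postprocess() methods above share the same BIES span decoding; a small self-contained walk-through with made-up labels and character offsets (the 'cws' type is the one the NER pipeline later joins with spaces):

text = '今天天气很好'
labels = ['B-cws', 'E-cws', 'B-cws', 'E-cws', 'S-cws', 'S-cws']
offset_mapping = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]

chunks, chunk = [], {}
for label, offsets in zip(labels, offset_mapping):
    if label[0] in 'BS':                      # a new span starts here
        if chunk:
            chunk['span'] = text[chunk['start']:chunk['end']]
            chunks.append(chunk)
        chunk = {'type': label[2:], 'start': offsets[0], 'end': offsets[1]}
    if label[0] in 'IES':                     # the current span extends to this token
        if chunk:
            chunk['end'] = offsets[1]
    if label[0] in 'ES':                      # the current span ends here
        if chunk:
            chunk['span'] = text[chunk['start']:chunk['end']]
            chunks.append(chunk)
        chunk = {}
if chunk:
    chunk['span'] = text[chunk['start']:chunk['end']]
    chunks.append(chunk)

print([c['span'] for c in chunks])  # ['今天', '天气', '很', '好']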
""" - model = model if isinstance(model, - Model) else Model.from_pretrained(model) + model = Model.from_pretrained(model) if isinstance(model, + str) else model if preprocessor is None: - preprocessor = SentenceEmbeddingPreprocessor( + preprocessor = Preprocessor.from_pretrained( model.model_dir if isinstance(model, Model) else model, first_sequence=first_sequence, sequence_length=kwargs.pop('sequence_length', 128)) - model.eval() super().__init__(model=model, preprocessor=preprocessor, **kwargs) def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: - with torch.no_grad(): - return {**self.model(inputs, **forward_params)} + return self.model(**inputs, **forward_params) def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """process the prediction results @@ -57,6 +54,11 @@ class SentenceEmbeddingPipeline(Pipeline): Returns: Dict[str, Any]: the predicted text representation """ - embs = inputs[OutputKeys.TEXT_EMBEDDING] - scores = inputs[OutputKeys.SCORES] + embs = inputs['last_hidden_state'][:, 0].cpu().numpy() + num_sent = embs.shape[0] + if num_sent >= 2: + scores = np.dot(embs[0:1, ], np.transpose(embs[1:, ], + (1, 0))).tolist()[0] + else: + scores = [] return {OutputKeys.TEXT_EMBEDDING: embs, OutputKeys.SCORES: scores} diff --git a/modelscope/pipelines/nlp/sequence_classification_pipeline.py b/modelscope/pipelines/nlp/sequence_classification_pipeline.py deleted file mode 100644 index 8d0e1dcd..00000000 --- a/modelscope/pipelines/nlp/sequence_classification_pipeline.py +++ /dev/null @@ -1,83 +0,0 @@ -from typing import Any, Dict, Union - -import numpy as np -import torch - -from modelscope.metainfo import Pipelines -from modelscope.models.base import Model -from modelscope.outputs import OutputKeys -from modelscope.pipelines.base import Pipeline -from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import (Preprocessor, - SequenceClassificationPreprocessor) -from modelscope.utils.constant import Tasks - - -@PIPELINES.register_module( - Tasks.text_classification, module_name=Pipelines.sentiment_analysis) -@PIPELINES.register_module(Tasks.nli, module_name=Pipelines.nli) -@PIPELINES.register_module( - Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity) -@PIPELINES.register_module( - Tasks.text_classification, module_name=Pipelines.sentiment_classification) -class SequenceClassificationPipeline(Pipeline): - - def __init__(self, - model: Union[Model, str], - preprocessor: Preprocessor = None, - **kwargs): - """This is the base class for all the sequence classification sub-tasks. - - Args: - model (str or Model): A model instance or a model local dir or a model id in the model hub. - preprocessor (Preprocessor): a preprocessor instance, must not be None. 
- """ - assert isinstance(model, str) or isinstance(model, Model), \ - 'model must be a single str or Model' - model = model if isinstance(model, - Model) else Model.from_pretrained(model) - first_sequence = kwargs.pop('first_sequence', 'first_sequence') - second_sequence = kwargs.pop('second_sequence', None) - - if preprocessor is None: - preprocessor = SequenceClassificationPreprocessor( - model.model_dir if isinstance(model, Model) else model, - first_sequence=first_sequence, - second_sequence=second_sequence, - sequence_length=kwargs.pop('sequence_length', 512)) - - assert preprocessor is not None - model.eval() - super().__init__(model=model, preprocessor=preprocessor, **kwargs) - self.id2label = kwargs.get('id2label') - if self.id2label is None and hasattr(self.preprocessor, 'id2label'): - self.id2label = self.preprocessor.id2label - assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \ - 'as a parameter or make sure the preprocessor has the attribute.' - - def forward(self, inputs: Dict[str, Any], - **forward_params) -> Dict[str, Any]: - with torch.no_grad(): - return self.model(**inputs, **forward_params) - - def postprocess(self, - inputs: Dict[str, Any], - topk: int = 5) -> Dict[str, str]: - """process the prediction results - - Args: - inputs (Dict[str, Any]): _description_ - topk (int): The topk probs to take - Returns: - Dict[str, str]: the prediction results - """ - - probs = inputs[OutputKeys.PROBABILITIES][0] - num_classes = probs.shape[0] - topk = min(topk, num_classes) - top_indices = np.argpartition(probs, -topk)[-topk:] - cls_ids = top_indices[np.argsort(probs[top_indices])] - probs = probs[cls_ids].tolist() - - cls_names = [self.id2label[cid] for cid in cls_ids] - return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} diff --git a/modelscope/pipelines/nlp/summarization_pipeline.py b/modelscope/pipelines/nlp/summarization_pipeline.py index 7a91eff1..30dd4b30 100644 --- a/modelscope/pipelines/nlp/summarization_pipeline.py +++ b/modelscope/pipelines/nlp/summarization_pipeline.py @@ -13,7 +13,7 @@ logger = get_logger() @PIPELINES.register_module( - Tasks.summarization, module_name=Pipelines.text_generation) + Tasks.text_summarization, module_name=Pipelines.text_generation) class SummarizationPipeline(Pipeline): def __init__(self, diff --git a/modelscope/pipelines/nlp/table_question_answering_pipeline.py b/modelscope/pipelines/nlp/table_question_answering_pipeline.py index 96bfbc34..b75a8153 100644 --- a/modelscope/pipelines/nlp/table_question_answering_pipeline.py +++ b/modelscope/pipelines/nlp/table_question_answering_pipeline.py @@ -2,6 +2,8 @@ import os from typing import Any, Dict, Union +import json +import torch from transformers import BertTokenizer from modelscope.metainfo import Pipelines @@ -11,9 +13,13 @@ from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import TableQuestionAnsweringPreprocessor -from modelscope.preprocessors.star3.fields.database import Database -from modelscope.preprocessors.star3.fields.struct import Constant, SQLQuery +from modelscope.preprocessors.nlp.space_T_cn.fields.database import Database +from modelscope.preprocessors.nlp.space_T_cn.fields.struct import (Constant, + SQLQuery) from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() __all__ = ['TableQuestionAnsweringPipeline'] @@ -63,6 
+69,7 @@ class TableQuestionAnsweringPipeline(Pipeline): self.max_where_num = constant.max_where_num self.col_type_dict = constant.col_type_dict self.schema_link_dict = constant.schema_link_dict + self.limit_dict = constant.limit_dict super().__init__(model=model, preprocessor=preprocessor, **kwargs) @@ -70,6 +77,7 @@ class TableQuestionAnsweringPipeline(Pipeline): action = self.action_ops[result['action']] headers = table['header_name'] current_sql = result['sql'] + current_sql['from'] = [table['table_id']] if history_sql is None: return current_sql @@ -214,10 +222,11 @@ class TableQuestionAnsweringPipeline(Pipeline): else: return current_sql - def sql_dict_to_str(self, result, table): + def sql_dict_to_str(self, result, tables): """ convert sql struct to string """ + table = tables[result['sql']['from'][0]] header_names = table['header_name'] + ['空列'] header_ids = table['header_id'] + ['null'] sql = result['sql'] @@ -230,14 +239,22 @@ class TableQuestionAnsweringPipeline(Pipeline): str_sel_list.append(header_name) sql_sel_list.append(header_id) else: - str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( ' - + header_name + ' )') - sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( ' - + header_id + ' )') + str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' + + header_name + ')') + sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' + + header_id + ')') str_cond_list, sql_cond_list = [], [] + where_conds, orderby_conds = [], [] for cond in sql['conds']: + if cond[1] in [4, 5]: + orderby_conds.append(cond) + else: + where_conds.append(cond) + for cond in where_conds: header_name = header_names[cond[0]] + if header_name == '空列': + continue header_id = '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]]) op = self.cond_ops[cond[1]] value = cond[2] @@ -245,15 +262,55 @@ class TableQuestionAnsweringPipeline(Pipeline): + '" )') sql_cond_list.append('( ' + header_id + ' ' + op + ' "' + value + '" )') + cond_str = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' ' + str_where_conds = cond_str.join(str_cond_list) + sql_where_conds = cond_str.join(sql_cond_list) + if len(orderby_conds) != 0: + str_orderby_column = ', '.join( + [header_names[cond[0]] for cond in orderby_conds]) + sql_orderby_column = ', '.join([ + '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]]) + for cond in orderby_conds + ]) + str_orderby_op = self.cond_ops[orderby_conds[0][1]] + str_orderby = '%s %s' % (str_orderby_column, str_orderby_op) + sql_orderby = '%s %s' % (sql_orderby_column, str_orderby_op) + limit_key = orderby_conds[0][2] + is_in, limit_num = False, -1 + for key in self.limit_dict: + if key in limit_key: + is_in = True + limit_num = self.limit_dict[key] + break + if is_in: + str_orderby += ' LIMIT %d' % (limit_num) + sql_orderby += ' LIMIT %d' % (limit_num) + else: + str_orderby = '' + + if len(str_cond_list) != 0 and len(str_orderby) != 0: + final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( + ', '.join(str_sel_list), table['table_name'], str_where_conds, + str_orderby) + final_sql = 'SELECT %s FROM `%s` WHERE %s ORDER BY %s' % ( + ', '.join(sql_sel_list), table['table_id'], sql_where_conds, + sql_orderby) + elif len(str_cond_list) != 0: + final_str = 'SELECT %s FROM %s WHERE %s' % ( + ', '.join(str_sel_list), table['table_name'], str_where_conds) + final_sql = 'SELECT %s FROM `%s` WHERE %s' % ( + ', '.join(sql_sel_list), table['table_id'], sql_where_conds) + elif len(str_orderby) != 0: + final_str = 'SELECT %s FROM %s ORDER BY %s' % ( + ', '.join(str_sel_list), table['table_name'], 
str_orderby) + final_sql = 'SELECT %s FROM `%s` ORDER BY %s' % ( + ', '.join(sql_sel_list), table['table_id'], sql_orderby) + else: + final_str = 'SELECT %s FROM %s' % (', '.join(str_sel_list), + table['table_name']) + final_sql = 'SELECT %s FROM `%s`' % (', '.join(sql_sel_list), + table['table_id']) - cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' ' - - final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(str_sel_list), - table['table_name'], - cond.join(str_cond_list)) - final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(sql_sel_list), - table['table_id'], - cond.join(sql_cond_list)) sql = SQLQuery( string=final_str, query=final_sql, sql_result=result['sql']) @@ -270,14 +327,47 @@ class TableQuestionAnsweringPipeline(Pipeline): """ result = inputs['result'] history_sql = inputs['history_sql'] - result['sql'] = self.post_process_multi_turn( - history_sql=history_sql, - result=result, - table=self.db.tables[result['table_id']]) - sql = self.sql_dict_to_str( - result=result, table=self.db.tables[result['table_id']]) - output = {OutputKeys.OUTPUT: sql, OutputKeys.HISTORY: result['sql']} - return output + try: + result['sql'] = self.post_process_multi_turn( + history_sql=history_sql, + result=result, + table=self.db.tables[result['table_id']]) + except Exception: + result['sql'] = history_sql + sql = self.sql_dict_to_str(result=result, tables=self.db.tables) + + # add sqlite + if self.db.is_use_sqlite: + try: + cursor = self.db.connection_obj.cursor().execute(sql.query) + header_ids, header_names = [], [] + for description in cursor.description: + header_names.append(self.db.tables[result['table_id']] + ['headerid2name'].get( + description[0], description[0])) + header_ids.append(description[0]) + rows = [] + for res in cursor.fetchall(): + rows.append(list(res)) + tabledata = { + 'header_id': header_ids, + 'header_name': header_names, + 'rows': rows + } + except Exception as e: + logger.error(e) + tabledata = {'header_id': [], 'header_name': [], 'rows': []} + else: + tabledata = {'header_id': [], 'header_name': [], 'rows': []} + + output = { + OutputKeys.SQL_STRING: sql.string, + OutputKeys.SQL_QUERY: sql.query, + OutputKeys.HISTORY: result['sql'], + OutputKeys.QUERT_RESULT: tabledata, + } + + return {OutputKeys.OUTPUT: output} def _collate_fn(self, data): return data diff --git a/modelscope/pipelines/nlp/text2text_generation_pipeline.py b/modelscope/pipelines/nlp/text2text_generation_pipeline.py index 9ccd00f4..a739df69 100644 --- a/modelscope/pipelines/nlp/text2text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text2text_generation_pipeline.py @@ -1,20 +1,35 @@ -from typing import Any, Dict, Optional, Union +# Copyright (c) Alibaba, Inc. and its affiliates. 
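After this change TableQuestionAnsweringPipeline nests its result under OutputKeys.OUTPUT; a hedged helper sketch for unpacking it, where the `result` argument stands for a hypothetical pipeline return value:

from modelscope.outputs import OutputKeys


def read_table_qa_output(result):
    out = result[OutputKeys.OUTPUT]
    return (
        out[OutputKeys.SQL_STRING],    # human-readable SQL over header names
        out[OutputKeys.SQL_QUERY],     # executable SQL over `table_id`.`header_id`
        out[OutputKeys.QUERT_RESULT],  # executed rows when sqlite is enabled, else empty lists
        out[OutputKeys.HISTORY],       # raw sql dict, reusable as the next turn's history_sql
    )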
+from typing import Any, Dict, List, Optional, Union import torch +from numpy import isin from modelscope.metainfo import Pipelines from modelscope.models.base import Model from modelscope.outputs import OutputKeys -from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.base import Input, Pipeline, Tensor from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import Text2TextGenerationPreprocessor +from modelscope.utils.config import use_task_specific_params from modelscope.utils.constant import Tasks __all__ = ['Text2TextGenerationPipeline'] +TRANSLATE_PIPELINES = [ + Pipelines.translation_en_to_de, + Pipelines.translation_en_to_ro, + Pipelines.translation_en_to_fr, +] + @PIPELINES.register_module( Tasks.text2text_generation, module_name=Pipelines.text2text_generation) +@PIPELINES.register_module( + Tasks.text2text_generation, module_name=Pipelines.translation_en_to_de) +@PIPELINES.register_module( + Tasks.text2text_generation, module_name=Pipelines.translation_en_to_ro) +@PIPELINES.register_module( + Tasks.text2text_generation, module_name=Pipelines.translation_en_to_fr) class Text2TextGenerationPipeline(Pipeline): def __init__( @@ -38,13 +53,13 @@ class Text2TextGenerationPipeline(Pipeline): Example: >>> from modelscope.pipelines import pipeline - >>> pipeline_ins = pipeline(task='text-generation', - >>> model='damo/nlp_palm2.0_text-generation_chinese-base') - >>> sentence1 = '本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:' - >>> '1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代' + >>> pipeline_ins = pipeline(task='text2text-generation', + >>> model='damo/nlp_t5_text2text-generation_chinese-base') + >>> sentence1 = '中国的首都位于。' >>> print(pipeline_ins(sentence1)) >>> # Or use the dict input: >>> print(pipeline_ins({'sentence': sentence1})) + >>> # 北京 To view other examples plese check the tests/pipelines/test_text_generation.py. """ @@ -55,9 +70,22 @@ class Text2TextGenerationPipeline(Pipeline): model.model_dir, sequence_length=kwargs.pop('sequence_length', 128)) self.tokenizer = preprocessor.tokenizer + self.pipeline = model.pipeline.type model.eval() super().__init__(model=model, preprocessor=preprocessor, **kwargs) + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + """ Provide specific preprocess for text2text generation pipeline in order to handl multi tasks + """ + if not isinstance(inputs, str): + raise ValueError(f'Not supported input type: {type(inputs)}') + + if self.pipeline in TRANSLATE_PIPELINES: + use_task_specific_params(self.model, self.pipeline) + inputs = self.model.config.prefix + inputs + + return super().preprocess(inputs, **preprocess_params) + def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: diff --git a/modelscope/pipelines/nlp/text_classification_pipeline.py b/modelscope/pipelines/nlp/text_classification_pipeline.py index 13d9964d..9e00ad7f 100644 --- a/modelscope/pipelines/nlp/text_classification_pipeline.py +++ b/modelscope/pipelines/nlp/text_classification_pipeline.py @@ -1,43 +1,124 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
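The `preprocess` override above prepends the model's task prefix for the translation pipelines before handing the text to the shared preprocessor. A rough illustration of the same idea with a plain Hugging Face T5 checkpoint; `t5-small`, its `translation_en_to_de` entry, and the `config.update` call are stand-ins that only approximate what `use_task_specific_params` does.

from transformers import T5ForConditionalGeneration, T5Tokenizer

# 't5-small' and its 'translation_en_to_de' entry are illustrative stand-ins;
# config.update(...) only approximates what use_task_specific_params applies.
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

task_params = model.config.task_specific_params['translation_en_to_de']
model.config.update(task_params)
text = model.config.prefix + 'The weather is nice today.'   # prefix goes in front

input_ids = tokenizer(text, return_tensors='pt').input_ids
output_ids = model.generate(input_ids, max_length=40)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))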
from typing import Any, Dict, Union +import numpy as np + from modelscope.metainfo import Pipelines +from modelscope.models.base import Model from modelscope.models.multi_modal import OfaForAllTasks -from modelscope.pipelines.base import Model, Pipeline +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import OfaPreprocessor, Preprocessor from modelscope.utils.constant import Tasks -from modelscope.utils.logger import get_logger - -logger = get_logger() +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.sentiment_analysis) +@PIPELINES.register_module(Tasks.nli, module_name=Pipelines.nli) +@PIPELINES.register_module( + Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity) @PIPELINES.register_module( Tasks.text_classification, module_name=Pipelines.text_classification) +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.sentiment_classification) +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.sentence_similarity) +@PIPELINES.register_module( + Tasks.sentiment_classification, + module_name=Pipelines.sentiment_classification) class TextClassificationPipeline(Pipeline): def __init__(self, model: Union[Model, str], - preprocessor: [Preprocessor] = None, + preprocessor: Preprocessor = None, **kwargs): + """The inference pipeline for all the text classification sub-tasks. + + Args: + model (`str` or `Model` or module instance): A model instance or a model local dir + or a model id in the model hub. + preprocessor (`Preprocessor`, `optional`): A Preprocessor instance. + first_sequence (`str`, `optional`): The key of the first sentence. + second_sequence (`str`, `optional`): The key of the second sentence. + sequence_length (`int`, `optional`): The sequence length. + id2label (`dict`, `optional`): The id-label mapping. + + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('text-classification', + model='damo/nlp_structbert_sentence-similarity_chinese-base') + >>> input = ('这是个测试', '这也是个测试') + >>> print(pipeline_ins(input)) + + NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' and 'second_sequence' + param will have no affection. 
""" - use `model` and `preprocessor` to create a kws pipeline for prediction + model = Model.from_pretrained(model) if isinstance(model, + str) else model + + if preprocessor is None: + if isinstance(model, OfaForAllTasks): + preprocessor = OfaPreprocessor(model_dir=model.model_dir) + else: + first_sequence = kwargs.pop('first_sequence', 'first_sequence') + second_sequence = kwargs.pop('second_sequence', None) + preprocessor = Preprocessor.from_pretrained( + model if isinstance(model, str) else model.model_dir, + first_sequence=first_sequence, + second_sequence=second_sequence, + sequence_length=kwargs.pop('sequence_length', 512)) + + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.id2label = kwargs.get('id2label') + if self.id2label is None and hasattr(self.preprocessor, 'id2label'): + self.id2label = self.preprocessor.id2label + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + if isinstance(self.model, OfaForAllTasks): + return super().forward(inputs, **forward_params) + return self.model(**inputs, **forward_params) + + def postprocess(self, + inputs: Dict[str, Any], + topk: int = 5) -> Dict[str, str]: + """process the prediction results + Args: - model: model id on modelscope hub. + inputs (`Dict[str, Any]` or `TextClassificationModelOutput`): The model output, please check + the `TextClassificationModelOutput` class for details. + topk (int): The topk probs to take + Returns: + Dict[str, str]: the prediction results. + scores: The probabilities of each label. + labels: The real labels. + Label at index 0 is the smallest probability. """ - super().__init__(model=model) - assert isinstance(model, str) or isinstance(model, Model), \ - 'model must be a single str or OfaForAllTasks' - if isinstance(model, str): - pipe_model = Model.from_pretrained(model) - elif isinstance(model, Model): - pipe_model = model + if isinstance(self.model, OfaForAllTasks): + return inputs else: - raise NotImplementedError - pipe_model.model.eval() - if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): - preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) - super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) - - def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - return inputs + assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \ + 'as a parameter or make sure the preprocessor has the attribute.' 
+ logits = inputs[OutputKeys.LOGITS].cpu().numpy() + if logits.shape[0] == 1: + logits = logits[0] + + def softmax(logits): + exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True)) + return exp / exp.sum(axis=-1, keepdims=True) + + probs = softmax(logits) + num_classes = probs.shape[-1] + topk = min(topk, num_classes) + top_indices = np.argpartition(probs, -topk)[-topk:] + probs = np.take_along_axis(probs, top_indices, axis=-1).tolist() + + def map_to_label(id): + return self.id2label[id] + + v_func = np.vectorize(map_to_label) + return { + OutputKeys.SCORES: probs, + OutputKeys.LABELS: v_func(top_indices).tolist() + } diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index ea35763f..28acebb4 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -6,10 +6,12 @@ import torch from modelscope.metainfo import Pipelines from modelscope.models.base import Model +from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline, Tensor from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import TextGenerationPreprocessor -from modelscope.utils.constant import Tasks +from modelscope.preprocessors import Preprocessor, build_preprocessor +from modelscope.utils.constant import Fields, Tasks +from modelscope.utils.hub import read_config __all__ = ['TextGenerationPipeline'] @@ -20,7 +22,7 @@ class TextGenerationPipeline(Pipeline): def __init__(self, model: Union[Model, str], - preprocessor: Optional[TextGenerationPreprocessor] = None, + preprocessor: Optional[Preprocessor] = None, first_sequence='sentence', **kwargs): """Use `model` and `preprocessor` to create a generation pipeline for prediction. 
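The reworked text-classification `postprocess` above turns logits into probabilities and keeps the top-k labels. A self-contained numpy sketch of that logic, using a made-up logits vector and id2label mapping, and a plain list comprehension in place of `np.vectorize`.

import numpy as np

def softmax(logits):
    exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)

logits = np.array([2.1, 0.3, -1.0])            # made-up model output
id2label = {0: 'positive', 1: 'negative', 2: 'neutral'}

probs = softmax(logits)
topk = min(2, probs.shape[-1])
top_indices = np.argpartition(probs, -topk)[-topk:]   # top-k ids, unsorted
scores = np.take_along_axis(probs, top_indices, axis=-1).tolist()
labels = [id2label[i] for i in top_indices.tolist()]
print({'scores': scores, 'labels': labels})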
@@ -50,19 +52,34 @@ class TextGenerationPipeline(Pipeline): """ model = model if isinstance(model, Model) else Model.from_pretrained(model) + cfg = read_config(model.model_dir) + self.postprocessor = cfg.pop('postprocessor', None) if preprocessor is None: - preprocessor = TextGenerationPreprocessor( + preprocessor_cfg = cfg.preprocessor + preprocessor_cfg.update({ + 'model_dir': model.model_dir, - first_sequence=first_sequence, - second_sequence=None, - sequence_length=kwargs.pop('sequence_length', 128)) + 'first_sequence': + first_sequence, + 'second_sequence': + None, + 'sequence_length': + kwargs.pop('sequence_length', 128) + }) + preprocessor = build_preprocessor(preprocessor_cfg, Fields.nlp) model.eval() super().__init__(model=model, preprocessor=preprocessor, **kwargs) + def _sanitize_parameters(self, **pipeline_parameters): + return {}, pipeline_parameters, {} + def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: with torch.no_grad(): - return self.model.generate(inputs) + return self.model.generate(inputs, **forward_params) + + def sentence_piece(self, inputs) -> Dict[str, Tensor]: + return self.preprocessor.tokenizer.decode(inputs.tolist()[0]) def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params) -> Dict[str, str]: @@ -74,4 +91,7 @@ class TextGenerationPipeline(Pipeline): Returns: Dict[str, str]: the prediction results """ - return inputs + return inputs if self.postprocessor is None else { + OutputKeys.TEXT: + getattr(self, self.postprocessor.replace('-', '_'))(inputs) + } diff --git a/modelscope/pipelines/nlp/passage_ranking_pipeline.py b/modelscope/pipelines/nlp/text_ranking_pipeline.py similarity index 70% rename from modelscope/pipelines/nlp/passage_ranking_pipeline.py rename to modelscope/pipelines/nlp/text_ranking_pipeline.py index 1d818ac0..9cee327b 100644 --- a/modelscope/pipelines/nlp/passage_ranking_pipeline.py +++ b/modelscope/pipelines/nlp/text_ranking_pipeline.py @@ -2,22 +2,22 @@ from typing import Any, Dict, Optional, Union -import torch +import numpy as np from modelscope.metainfo import Pipelines from modelscope.models import Model from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import PassageRankingPreprocessor, Preprocessor +from modelscope.preprocessors import Preprocessor, TextRankingPreprocessor from modelscope.utils.constant import Tasks -__all__ = ['PassageRankingPipeline'] +__all__ = ['TextRankingPipeline'] @PIPELINES.register_module( - Tasks.passage_ranking, module_name=Pipelines.passage_ranking) -class PassageRankingPipeline(Pipeline): + Tasks.text_ranking, module_name=Pipelines.text_ranking) +class TextRankingPipeline(Pipeline): def __init__(self, model: Union[Model, str], @@ -32,20 +32,18 @@ class PassageRankingPipeline(Pipeline): the model if supplied. sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. 
""" - model = model if isinstance(model, - Model) else Model.from_pretrained(model) + model = Model.from_pretrained(model) if isinstance(model, + str) else model if preprocessor is None: - preprocessor = PassageRankingPreprocessor( - model.model_dir if isinstance(model, Model) else model, + preprocessor = Preprocessor.from_pretrained( + model.model_dir, sequence_length=kwargs.pop('sequence_length', 128)) - model.eval() super().__init__(model=model, preprocessor=preprocessor, **kwargs) def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: - with torch.no_grad(): - return {**self.model(inputs, **forward_params)} + return self.model(**inputs, **forward_params) def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """process the prediction results @@ -55,6 +53,10 @@ class PassageRankingPipeline(Pipeline): Returns: Dict[str, Any]: the predicted text representation """ - pred_list = inputs[OutputKeys.SCORES] + def sigmoid(logits): + return np.exp(logits) / (1 + np.exp(logits)) + + logits = inputs[OutputKeys.LOGITS].squeeze(-1).detach().cpu().numpy() + pred_list = sigmoid(logits).tolist() return {OutputKeys.SCORES: pred_list} diff --git a/modelscope/pipelines/nlp/token_classification_pipeline.py b/modelscope/pipelines/nlp/token_classification_pipeline.py index 5367c1a8..c36f0dfc 100644 --- a/modelscope/pipelines/nlp/token_classification_pipeline.py +++ b/modelscope/pipelines/nlp/token_classification_pipeline.py @@ -7,17 +7,24 @@ import torch from modelscope.metainfo import Pipelines from modelscope.models import Model from modelscope.outputs import OutputKeys -from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import (Preprocessor, - TokenClassificationPreprocessor) +from modelscope.preprocessors import Preprocessor from modelscope.utils.constant import Tasks +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) __all__ = ['TokenClassificationPipeline'] @PIPELINES.register_module( Tasks.token_classification, module_name=Pipelines.part_of_speech) +@PIPELINES.register_module( + Tasks.token_classification, module_name=Pipelines.word_segmentation) +@PIPELINES.register_module( + Tasks.token_classification, module_name=Pipelines.named_entity_recognition) +@PIPELINES.register_module( + Tasks.part_of_speech, module_name=Pipelines.part_of_speech) class TokenClassificationPipeline(Pipeline): def __init__(self, @@ -30,19 +37,18 @@ class TokenClassificationPipeline(Pipeline): model (str or Model): A model instance or a model local dir or a model id in the model hub. preprocessor (Preprocessor): a preprocessor instance, must not be None. """ - assert isinstance(model, str) or isinstance(model, Model), \ - 'model must be a single str or Model' - model = model if isinstance(model, - Model) else Model.from_pretrained(model) + model = Model.from_pretrained(model) if isinstance(model, + str) else model + if preprocessor is None: - preprocessor = TokenClassificationPreprocessor( + preprocessor = Model.from_pretrained( model.model_dir, sequence_length=kwargs.pop('sequence_length', 128)) model.eval() super().__init__(model=model, preprocessor=preprocessor, **kwargs) - self.id2label = getattr(model, 'id2label') - assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \ - 'as a parameter or make sure the preprocessor has the attribute.' 
+ self.id2label = kwargs.get('id2label') + if self.id2label is None and hasattr(self.preprocessor, 'id2label'): + self.id2label = self.preprocessor.id2label def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: @@ -57,38 +63,59 @@ class TokenClassificationPipeline(Pipeline): """process the prediction results Args: - inputs (Dict[str, Any]): _description_ + inputs (Dict[str, Any]): should be tensors from model Returns: Dict[str, str]: the prediction results """ + text = inputs['text'] + if not hasattr(inputs, 'predictions'): + logits = inputs[OutputKeys.LOGITS] + predictions = torch.argmax(logits[0], dim=-1) + else: + predictions = inputs[OutputKeys.PREDICTIONS].squeeze( + 0).cpu().numpy() + predictions = torch_nested_numpify(torch_nested_detach(predictions)) + offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] - pred_list = inputs['predictions'] - labels = [] - for pre in pred_list: - labels.append(self.id2label[pre]) - labels = labels[1:-1] + labels = [self.id2label[x] for x in predictions] + if len(labels) > len(offset_mapping): + labels = labels[1:-1] chunks = [] - tags = [] - chunk = '' - assert len(inputs['text']) == len(labels) - for token, label in zip(inputs['text'], labels): - if label[0] == 'B' or label[0] == 'I': - chunk += token - else: - chunk += token - chunks.append(chunk) - chunk = '' - tags.append(label.split('-')[-1]) + chunk = {} + for label, offsets in zip(labels, offset_mapping): + if label[0] in 'BS': + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = { + 'type': label[2:], + 'start': offsets[0], + 'end': offsets[1] + } + if label[0] in 'IES': + if chunk: + chunk['end'] = offsets[1] + + if label[0] in 'ES': + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = {} + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] chunks.append(chunk) - tags.append(label.split('-')[-1]) - pos_result = [] - seg_result = ' '.join(chunks) - for chunk, tag in zip(chunks, tags): - pos_result.append({OutputKeys.WORD: chunk, OutputKeys.LABEL: tag}) - outputs = { - OutputKeys.OUTPUT: seg_result, - OutputKeys.LABELS: pos_result - } + + # for cws output + if len(chunks) > 0 and chunks[0]['type'] == 'cws': + spans = [ + chunk['span'] for chunk in chunks if chunk['span'].strip() + ] + seg_result = ' '.join(spans) + outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} + + # for ner outputs + else: + outputs = {OutputKeys.OUTPUT: chunks} return outputs diff --git a/modelscope/pipelines/nlp/translation_pipeline.py b/modelscope/pipelines/nlp/translation_pipeline.py index eb7f7f74..68a03631 100644 --- a/modelscope/pipelines/nlp/translation_pipeline.py +++ b/modelscope/pipelines/nlp/translation_pipeline.py @@ -34,7 +34,8 @@ class TranslationPipeline(Pipeline): def __init__(self, model: Model, **kwargs): """Build a translation pipeline with a model dir or a model id in the model hub. - @param model: A Model instance. + Args: + model: A Model instance. """ super().__init__(model=model, **kwargs) model = self.model.model_dir diff --git a/modelscope/pipelines/nlp/translation_quality_estimation_pipeline.py b/modelscope/pipelines/nlp/translation_quality_estimation_pipeline.py new file mode 100644 index 00000000..6ef203b9 --- /dev/null +++ b/modelscope/pipelines/nlp/translation_quality_estimation_pipeline.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
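The token-classification postprocess above (and the word-segmentation postprocess later in this diff) decodes BIOES-style labels plus offset mappings into character spans. A standalone sketch of that decoding loop; the sentence, labels, and offsets are fabricated for illustration.

text = '小明在杭州工作'
# Fabricated per-token labels and character offsets (start, end).
labels = ['B-PER', 'E-PER', 'O', 'B-LOC', 'E-LOC', 'O', 'O']
offset_mapping = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]

chunks, chunk = [], {}
for label, (start, end) in zip(labels, offset_mapping):
    if label[0] in 'BS':               # B/S opens a new chunk
        if chunk:
            chunk['span'] = text[chunk['start']:chunk['end']]
            chunks.append(chunk)
        chunk = {'type': label[2:], 'start': start, 'end': end}
    if label[0] in 'IES' and chunk:    # I/E/S extends the open chunk
        chunk['end'] = end
    if label[0] in 'ES' and chunk:     # E/S closes it
        chunk['span'] = text[chunk['start']:chunk['end']]
        chunks.append(chunk)
        chunk = {}
if chunk:
    chunk['span'] = text[chunk['start']:chunk['end']]
    chunks.append(chunk)

print(chunks)   # [{'type': 'PER', ..., 'span': '小明'}, {'type': 'LOC', ..., 'span': '杭州'}]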
+ +import io +import os +from typing import Any, Dict, Union + +import numpy as np +import torch +from transformers import XLMRobertaTokenizer + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.nlp import BertForSequenceClassification +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['TranslationQualityEstimationPipeline'] + + +@PIPELINES.register_module( + Tasks.sentence_similarity, + module_name=Pipelines.translation_quality_estimation) +class TranslationQualityEstimationPipeline(Pipeline): + + def __init__(self, model: str, device: str = 'gpu', **kwargs): + super().__init__(model=model, device=device) + model_file = os.path.join(model, ModelFile.TORCH_MODEL_FILE) + with open(model_file, 'rb') as f: + buffer = io.BytesIO(f.read()) + self.tokenizer = XLMRobertaTokenizer.from_pretrained(model) + self.model = torch.jit.load( + buffer, map_location=self.device).to(self.device) + + def preprocess(self, inputs: Dict[str, Any]): + src_text = inputs['source_text'].strip() + tgt_text = inputs['target_text'].strip() + encoded_inputs = self.tokenizer.batch_encode_plus( + [[src_text, tgt_text]], + return_tensors='pt', + padding=True, + truncation=True) + input_ids = encoded_inputs['input_ids'].to(self.device) + attention_mask = encoded_inputs['attention_mask'].to(self.device) + inputs.update({ + 'input_ids': input_ids, + 'attention_mask': attention_mask + }) + return inputs + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + if 'input_ids' not in inputs: + inputs = self.preprocess(inputs) + res = self.model(inputs['input_ids'], inputs['attention_mask']) + result = { + OutputKeys.LABELS: '-1', + OutputKeys.SCORES: res[0].detach().squeeze().tolist() + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): input data dict + + Returns: + Dict[str, str]: the prediction results + """ + return inputs diff --git a/modelscope/pipelines/nlp/word_segmentation_pipeline.py b/modelscope/pipelines/nlp/word_segmentation_pipeline.py index 9d4bb67f..0df8f1ad 100644 --- a/modelscope/pipelines/nlp/word_segmentation_pipeline.py +++ b/modelscope/pipelines/nlp/word_segmentation_pipeline.py @@ -12,6 +12,8 @@ from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import (Preprocessor, TokenClassificationPreprocessor) from modelscope.utils.constant import Tasks +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) __all__ = ['WordSegmentationPipeline'] @@ -72,28 +74,56 @@ class WordSegmentationPipeline(Pipeline): """process the prediction results Args: - inputs (Dict[str, Any]): _description_ + inputs (Dict[str, Any]): should be tensors from model Returns: Dict[str, str]: the prediction results """ - - pred_list = inputs['predictions'] - labels = [] - for pre in pred_list: - labels.append(self.id2label[pre]) - labels = labels[1:-1] + text = inputs['text'] + logits = inputs[OutputKeys.LOGITS] + predictions = torch.argmax(logits[0], dim=-1) + logits = torch_nested_numpify(torch_nested_detach(logits)) + predictions = torch_nested_numpify(torch_nested_detach(predictions)) + offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] + + labels = 
[self.id2label[x] for x in predictions] + if len(labels) > len(offset_mapping): + labels = labels[1:-1] chunks = [] - chunk = '' - assert len(inputs['text']) == len(labels) - for token, label in zip(inputs['text'], labels): - if label[0] == 'B' or label[0] == 'I': - chunk += token - else: - chunk += token - chunks.append(chunk) - chunk = '' + chunk = {} + for label, offsets in zip(labels, offset_mapping): + if label[0] in 'BS': + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = { + 'type': label[2:], + 'start': offsets[0], + 'end': offsets[1] + } + if label[0] in 'IES': + if chunk: + chunk['end'] = offsets[1] + + if label[0] in 'ES': + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = {} + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] chunks.append(chunk) - seg_result = ' '.join(chunks) - return {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} + + # for cws output + if len(chunks) > 0 and chunks[0]['type'] == 'cws': + spans = [ + chunk['span'] for chunk in chunks if chunk['span'].strip() + ] + seg_result = ' '.join(spans) + outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} + + # for ner outpus + else: + outputs = {OutputKeys.OUTPUT: chunks} + return outputs diff --git a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py index fc7051c7..ecd538b9 100644 --- a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py +++ b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py @@ -74,7 +74,8 @@ class ZeroShotClassificationPipeline(Pipeline): preprocess_params = {} postprocess_params = {} if 'candidate_labels' in kwargs: - candidate_labels = kwargs.pop('candidate_labels') + candidate_labels = self._parse_labels( + kwargs.pop('candidate_labels')) preprocess_params['candidate_labels'] = candidate_labels postprocess_params['candidate_labels'] = candidate_labels else: @@ -84,10 +85,17 @@ class ZeroShotClassificationPipeline(Pipeline): postprocess_params['multi_label'] = kwargs.pop('multi_label', False) return preprocess_params, {}, postprocess_params + def _parse_labels(self, labels): + if isinstance(labels, str): + labels = labels.replace(',', ',') # replace cn comma to en comma + labels = [ + label.strip() for label in labels.split(',') if label.strip() + ] + return labels + def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: - with torch.no_grad(): - return self.model(**inputs, **forward_params) + return self.model(**inputs, **forward_params) def postprocess(self, inputs: Dict[str, Any], @@ -99,7 +107,7 @@ class ZeroShotClassificationPipeline(Pipeline): Returns: Dict[str, Any]: the prediction results """ - logits = inputs[OutputKeys.LOGITS] + logits = inputs[OutputKeys.LOGITS].cpu().numpy() if multi_label or len(candidate_labels) == 1: logits = logits[..., [self.contradiction_id, self.entailment_id]] scores = softmax(logits, axis=-1)[..., 1] diff --git a/modelscope/pipelines/science/__init__.py b/modelscope/pipelines/science/__init__.py new file mode 100644 index 00000000..1f81809b --- /dev/null +++ b/modelscope/pipelines/science/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
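The `_parse_labels` helper added to the zero-shot pipeline above normalizes `candidate_labels` given as a comma-separated string (including the Chinese full-width comma) into a clean list. A tiny standalone version with example inputs.

def parse_labels(labels):
    """Accept either a list of labels or a comma-separated string."""
    if isinstance(labels, str):
        labels = labels.replace('，', ',')   # normalize the Chinese comma
        labels = [label.strip() for label in labels.split(',') if label.strip()]
    return labels

print(parse_labels('体育, 财经，科技,  '))    # ['体育', '财经', '科技']
print(parse_labels(['sports', 'finance']))    # lists pass through unchanged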
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .protein_structure_pipeline import ProteinStructurePipeline + +else: + _import_structure = { + 'protein_structure_pipeline': ['ProteinStructurePipeline'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/science/protein_structure_pipeline.py b/modelscope/pipelines/science/protein_structure_pipeline.py new file mode 100644 index 00000000..3dc51c72 --- /dev/null +++ b/modelscope/pipelines/science/protein_structure_pipeline.py @@ -0,0 +1,215 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import time +from typing import Any, Dict, List, Optional, Union + +import json +import numpy as np +import torch +from unicore.utils import tensor_tree_map + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.models.science.unifold.config import model_config +from modelscope.models.science.unifold.data import protein, residue_constants +from modelscope.models.science.unifold.dataset import (UnifoldDataset, + load_and_process) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor, build_preprocessor +from modelscope.utils.constant import Fields, Frameworks, Tasks +from modelscope.utils.device import device_placement +from modelscope.utils.hub import read_config +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ProteinStructurePipeline'] + + +def automatic_chunk_size(seq_len): + if seq_len < 512: + chunk_size = 256 + elif seq_len < 1024: + chunk_size = 128 + elif seq_len < 2048: + chunk_size = 32 + elif seq_len < 3072: + chunk_size = 16 + else: + chunk_size = 1 + return chunk_size + + +def load_feature_for_one_target( + config, + data_folder, + seed=0, + is_multimer=False, + use_uniprot=False, + symmetry_group=None, +): + if not is_multimer: + uniprot_msa_dir = None + sequence_ids = ['A'] + if use_uniprot: + uniprot_msa_dir = data_folder + + else: + uniprot_msa_dir = data_folder + sequence_ids = open(os.path.join(data_folder, + 'chains.txt')).readline().split() + + if symmetry_group is None: + batch, _ = load_and_process( + config=config.data, + mode='predict', + seed=seed, + batch_idx=None, + data_idx=0, + is_distillation=False, + sequence_ids=sequence_ids, + monomer_feature_dir=data_folder, + uniprot_msa_dir=uniprot_msa_dir, + ) + else: + raise NotImplementedError + batch = UnifoldDataset.collater([batch]) + return batch + + +@PIPELINES.register_module( + Tasks.protein_structure, module_name=Pipelines.protein_structure) +class ProteinStructurePipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """Use `model` and `preprocessor` to create a protein structure pipeline for prediction. + + Args: + model (str or Model): Supply either a local model dir which supported the protein structure task, + or a model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. 
+ + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline(task='protein-structure', + >>> model='DPTech/uni-fold-monomer') + >>> protein = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + >>> print(pipeline_ins(protein)) + + """ + import copy + model_path = copy.deepcopy(model) if isinstance(model, str) else None + cfg = read_config(model_path) # only model is str + self.cfg = cfg + self.config = model_config( + cfg['pipeline']['model_name']) # alphafold config + model = model if isinstance( + model, Model) else Model.from_pretrained(model_path) + self.postprocessor = cfg.pop('postprocessor', None) + if preprocessor is None: + preprocessor_cfg = cfg.preprocessor + preprocessor = build_preprocessor(preprocessor_cfg, Fields.science) + model.eval() + model.model.inference_mode() + model.model_dir = model_path + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def _sanitize_parameters(self, **pipeline_parameters): + return pipeline_parameters, pipeline_parameters, pipeline_parameters + + def _process_single(self, input, *args, **kwargs) -> Dict[str, Any]: + preprocess_params = kwargs.get('preprocess_params', {}) + forward_params = kwargs.get('forward_params', {}) + postprocess_params = kwargs.get('postprocess_params', {}) + out = self.preprocess(input, **preprocess_params) + with device_placement(self.framework, self.device_name): + with torch.no_grad(): + out = self.forward(out, **forward_params) + + out = self.postprocess(out, **postprocess_params) + return out + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + plddts = {} + ptms = {} + + output_dir = os.path.join(self.preprocessor.output_dir_base, + inputs['target_id']) + + pdbs = [] + for seed in range(self.cfg['pipeline']['times']): + cur_seed = hash((42, seed)) % 100000 + batch = load_feature_for_one_target( + self.config, + output_dir, + cur_seed, + is_multimer=inputs['is_multimer'], + use_uniprot=inputs['is_multimer'], + symmetry_group=self.preprocessor.symmetry_group, + ) + seq_len = batch['aatype'].shape[-1] + self.model.model.globals.chunk_size = automatic_chunk_size(seq_len) + + with torch.no_grad(): + batch = { + k: torch.as_tensor(v, device='cuda:0') + for k, v in batch.items() + } + out = self.model(batch) + + def to_float(x): + if x.dtype == torch.bfloat16 or x.dtype == torch.half: + return x.float() + else: + return x + + # Toss out the recycling dimensions --- we don't need them anymore + batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch) + batch = tensor_tree_map(to_float, batch) + out = tensor_tree_map(lambda t: t[0, ...], out[0]) + out = tensor_tree_map(to_float, out) + batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch) + out = tensor_tree_map(lambda x: np.array(x.cpu()), out) + + plddt = out['plddt'] + mean_plddt = np.mean(plddt) + plddt_b_factors = np.repeat( + plddt[..., None], residue_constants.atom_type_num, axis=-1) + # TODO: , may need to reorder chains, based on entity_ids + cur_protein = protein.from_prediction( + features=batch, result=out, b_factors=plddt_b_factors) + cur_save_name = (f'{cur_seed}') + plddts[cur_save_name] = str(mean_plddt) + if inputs[ + 'is_multimer'] and self.preprocessor.symmetry_group is None: + ptms[cur_save_name] = str(np.mean(out['iptm+ptm'])) + with open(os.path.join(output_dir, cur_save_name + '.pdb'), + 'w') as f: + f.write(protein.to_pdb(cur_protein)) + pdbs.append(protein.to_pdb(cur_protein)) + + logger.info('plddts:' + str(plddts)) + model_name = 
self.cfg['pipeline']['model_name'] + score_name = f'{model_name}' + plddt_fname = score_name + '_plddt.json' + + with open(os.path.join(output_dir, plddt_fname), 'w') as f: + json.dump(plddts, f, indent=4) + if ptms: + logger.info('ptms' + str(ptms)) + ptm_fname = score_name + '_ptm.json' + with open(os.path.join(output_dir, ptm_fname), 'w') as f: + json.dump(ptms, f, indent=4) + + return pdbs + + def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params): + return inputs diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 90303b65..e568098f 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -16,29 +16,21 @@ if TYPE_CHECKING: from .kws import WavToLists from .multi_modal import (OfaPreprocessor, MPlugPreprocessor) from .nlp import ( - DocumentSegmentationPreprocessor, - FaqQuestionAnsweringPreprocessor, - FillMaskPoNetPreprocessor, - NLPPreprocessor, - NLPTokenizerPreprocessorBase, - PassageRankingPreprocessor, - RelationExtractionPreprocessor, - SentenceEmbeddingPreprocessor, - SequenceClassificationPreprocessor, - TokenClassificationPreprocessor, - TextErrorCorrectionPreprocessor, - TextGenerationPreprocessor, - Text2TextGenerationPreprocessor, - Tokenize, + DocumentSegmentationPreprocessor, FaqQuestionAnsweringPreprocessor, + FillMaskPoNetPreprocessor, NLPPreprocessor, + NLPTokenizerPreprocessorBase, TextRankingPreprocessor, + RelationExtractionPreprocessor, SentenceEmbeddingPreprocessor, + SequenceClassificationPreprocessor, TokenClassificationPreprocessor, + TextErrorCorrectionPreprocessor, TextGenerationPreprocessor, + Text2TextGenerationPreprocessor, Tokenize, WordSegmentationBlankSetToLabelPreprocessor, - ZeroShotClassificationPreprocessor, - ) - from .space import (DialogIntentPredictionPreprocessor, - DialogModelingPreprocessor, - DialogStateTrackingPreprocessor) + ZeroShotClassificationPreprocessor, TextGenerationJiebaPreprocessor, + SentencePiecePreprocessor, DialogIntentPredictionPreprocessor, + DialogModelingPreprocessor, DialogStateTrackingPreprocessor, + ConversationalTextToSqlPreprocessor, + TableQuestionAnsweringPreprocessor, NERPreprocessorViet, + NERPreprocessorThai, WordSegmentationPreprocessorThai) from .video import ReadVideoData, MovieSceneSegmentationPreprocessor - from .star import ConversationalTextToSqlPreprocessor - from .star3 import TableQuestionAnsweringPreprocessor else: _import_structure = { @@ -56,28 +48,24 @@ else: 'multi_modal': ['OfaPreprocessor', 'MPlugPreprocessor'], 'nlp': [ 'DocumentSegmentationPreprocessor', - 'FaqQuestionAnsweringPreprocessor', - 'FillMaskPoNetPreprocessor', - 'NLPPreprocessor', - 'NLPTokenizerPreprocessorBase', - 'PassageRankingPreprocessor', - 'RelationExtractionPreprocessor', + 'FaqQuestionAnsweringPreprocessor', 'FillMaskPoNetPreprocessor', + 'NLPPreprocessor', 'NLPTokenizerPreprocessorBase', + 'TextRankingPreprocessor', 'RelationExtractionPreprocessor', 'SentenceEmbeddingPreprocessor', 'SequenceClassificationPreprocessor', 'TokenClassificationPreprocessor', - 'TextErrorCorrectionPreprocessor', - 'TextGenerationPreprocessor', - 'Tokenize', - 'Text2TextGenerationPreprocessor', + 'TextErrorCorrectionPreprocessor', 'TextGenerationPreprocessor', + 'Tokenize', 'Text2TextGenerationPreprocessor', 'WordSegmentationBlankSetToLabelPreprocessor', 'ZeroShotClassificationPreprocessor', - ], - 'space': [ + 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', + 'NERPreprocessorViet', 'NERPreprocessorThai', + 
'WordSegmentationPreprocessorThai', 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', - 'DialogStateTrackingPreprocessor', 'InputFeatures' + 'DialogStateTrackingPreprocessor', + 'ConversationalTextToSqlPreprocessor', + 'TableQuestionAnsweringPreprocessor' ], - 'star': ['ConversationalTextToSqlPreprocessor'], - 'star3': ['TableQuestionAnsweringPreprocessor'], } import sys diff --git a/modelscope/preprocessors/asr.py b/modelscope/preprocessors/asr.py index d58383d7..91bf5860 100644 --- a/modelscope/preprocessors/asr.py +++ b/modelscope/preprocessors/asr.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict, List, Union @@ -131,6 +133,12 @@ class WavToScp(Preprocessor): else: inputs['asr_model_config'] = asr_model_config + if inputs['model_config'].__contains__('mvn_file'): + mvn_file = os.path.join(inputs['model_workspace'], + inputs['model_config']['mvn_file']) + assert os.path.exists(mvn_file), 'mvn_file does not exist' + inputs['mvn_file'] = mvn_file + elif inputs['model_type'] == Frameworks.tf: assert inputs['model_config'].__contains__( 'vocab_file'), 'vocab_file does not exist' diff --git a/modelscope/preprocessors/base.py b/modelscope/preprocessors/base.py index 6360a907..c2716a13 100644 --- a/modelscope/preprocessors/base.py +++ b/modelscope/preprocessors/base.py @@ -1,15 +1,22 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os from abc import ABC, abstractmethod -from typing import Any, Dict +from copy import deepcopy +from typing import Any, Dict, Optional, Sequence -from modelscope.utils.constant import ModeKeys +from modelscope.utils.config import Config +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModeKeys, Tasks +from modelscope.utils.hub import read_config, snapshot_download +from modelscope.utils.logger import get_logger +from .builder import build_preprocessor + +logger = get_logger(__name__) class Preprocessor(ABC): - def __init__(self, *args, **kwargs): - self._mode = ModeKeys.INFERENCE + def __init__(self, mode=ModeKeys.INFERENCE, *args, **kwargs): + self._mode = mode self.device = int( os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else None pass @@ -25,3 +32,61 @@ class Preprocessor(ABC): @mode.setter def mode(self, value): self._mode = value + + @classmethod + def from_pretrained(cls, + model_name_or_path: str, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + cfg_dict: Config = None, + preprocessor_mode=ModeKeys.INFERENCE, + **kwargs): + """ Instantiate a model from local directory or remote model repo. Note + that when loading from remote, the model revision can be specified. 
+ """ + if not os.path.exists(model_name_or_path): + model_dir = snapshot_download( + model_name_or_path, revision=revision) + else: + model_dir = model_name_or_path + if cfg_dict is None: + cfg = read_config(model_dir) + else: + cfg = cfg_dict + task = cfg.task + if 'task' in kwargs: + task = kwargs.pop('task') + field_name = Tasks.find_field_by_task(task) + if not hasattr(cfg, 'preprocessor'): + logger.error('No preprocessor field found in cfg.') + return None + + sub_key = 'train' if preprocessor_mode == ModeKeys.TRAIN else 'val' + + if 'type' not in cfg.preprocessor: + if sub_key in cfg.preprocessor: + sub_cfg = getattr(cfg.preprocessor, sub_key) + else: + logger.error( + f'No {sub_key} key and type key found in ' + f'preprocessor domain of configuration.json file.') + return None + else: + sub_cfg = cfg.preprocessor + + if len(sub_cfg): + if isinstance(sub_cfg, Sequence): + # TODO: for Sequence, need adapt to `mode` and `mode_dir` args, + # and add mode for Compose or other plans + raise NotImplementedError('Not supported yet!') + sub_cfg = deepcopy(sub_cfg) + sub_cfg.update({'model_dir': model_dir}) + sub_cfg.update(kwargs) + preprocessor = build_preprocessor(sub_cfg, field_name) + else: + logger.error( + f'Cannot find available config to build preprocessor at mode {preprocessor_mode}, ' + f'please check the preprocessor field in the configuration.json file.' + ) + return None + preprocessor.mode = preprocessor_mode + return preprocessor diff --git a/modelscope/preprocessors/kws.py b/modelscope/preprocessors/kws.py index 9c370ed5..6f09d545 100644 --- a/modelscope/preprocessors/kws.py +++ b/modelscope/preprocessors/kws.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict, List, Union diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index f38ff8ae..256c5243 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os.path as osp +from io import BytesIO from typing import Any, Dict, List, Tuple, Union import torch @@ -15,6 +16,7 @@ from .base import Preprocessor from .builder import PREPROCESSORS from .ofa import * # noqa from .ofa.utils.collate import collate_fn +from .ofa.utils.constant import OFA_TASK_KEY_MAPPING __all__ = [ 'OfaPreprocessor', @@ -26,14 +28,20 @@ __all__ = [ Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor) class OfaPreprocessor(Preprocessor): - def __init__(self, model_dir: str, *args, **kwargs): + def __init__(self, + model_dir: str, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: model_dir (str): model path + mode: preprocessor mode (model mode) """ super().__init__(*args, **kwargs) preprocess_mapping = { + Tasks.ocr_recognition: OfaOcrRecognitionPreprocessor, Tasks.image_captioning: OfaImageCaptioningPreprocessor, Tasks.visual_grounding: OfaVisualGroundingPreprocessor, Tasks.visual_question_answering: @@ -41,27 +49,21 @@ class OfaPreprocessor(Preprocessor): Tasks.visual_entailment: OfaVisualEntailmentPreprocessor, Tasks.image_classification: OfaImageClassificationPreprocessor, Tasks.text_classification: OfaTextClassificationPreprocessor, - Tasks.summarization: OfaSummarizationPreprocessor, + Tasks.text_summarization: OfaSummarizationPreprocessor, Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor } - input_key_mapping = { - Tasks.image_captioning: ['image'], - Tasks.image_classification: ['image'], - Tasks.summarization: ['text'], - Tasks.text_classification: ['text', 'text2'], - Tasks.visual_grounding: ['image', 'text'], - Tasks.visual_question_answering: ['image', 'text'], - Tasks.visual_entailment: ['image', 'text', 'text2'], - Tasks.text_to_image_synthesis: ['text'] - } model_dir = model_dir if osp.exists(model_dir) else snapshot_download( model_dir) self.cfg = Config.from_file( osp.join(model_dir, ModelFile.CONFIGURATION)) - self.preprocess = preprocess_mapping[self.cfg.task](self.cfg, - model_dir) - self.keys = input_key_mapping[self.cfg.task] + self.preprocess = preprocess_mapping[self.cfg.task]( + cfg=self.cfg, model_dir=model_dir, mode=mode) + self.keys = OFA_TASK_KEY_MAPPING[self.cfg.task] self.tokenizer = self.preprocess.tokenizer + if kwargs.get('no_collate', None): + self.no_collate = True + else: + self.no_collate = False # just for modelscope demo def _build_dict(self, input: Union[Input, List[Input]]) -> Dict[str, Any]: @@ -72,20 +74,37 @@ class OfaPreprocessor(Preprocessor): data[key] = item return data + def _ofa_input_compatibility_conversion(self, data): + if 'image' in data and self.cfg.model.get('type', None) == 'ofa': + if isinstance(data['image'], str): + image = load_image(data['image']) + else: + image = data['image'] + if image.mode != 'RGB': + image = image.convert('RGB') + img_buffer = BytesIO() + image.save(img_buffer, format='JPEG') + data['image'] = Image.open(img_buffer) + return data + def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args, **kwargs) -> Dict[str, Any]: if isinstance(input, dict): data = input else: data = self._build_dict(input) + data = self._ofa_input_compatibility_conversion(data) sample = self.preprocess(data) str_data = dict() for k, v in data.items(): str_data[k] = str(v) sample['sample'] = str_data - return collate_fn([sample], - pad_idx=self.tokenizer.pad_token_id, - eos_idx=self.tokenizer.eos_token_id) + if self.no_collate: + return sample + else: + return collate_fn([sample], + pad_idx=self.tokenizer.pad_token_id, + 
eos_idx=self.tokenizer.eos_token_id) @PREPROCESSORS.register_module( @@ -138,7 +157,7 @@ class MPlugPreprocessor(Preprocessor): def image_open(self, path: str) -> Tuple[Image.Image, int]: if path not in self._image_map: index = len(self._image_map) - self._image_map[path] = (load_image(path), index) + self._image_map[path] = (Image.open(path), index) return self._image_map[path] def __call__( diff --git a/modelscope/preprocessors/nlp/__init__.py b/modelscope/preprocessors/nlp/__init__.py index dfbb5c81..d9c55fe1 100644 --- a/modelscope/preprocessors/nlp/__init__.py +++ b/modelscope/preprocessors/nlp/__init__.py @@ -5,46 +5,80 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .text_error_correction import TextErrorCorrectionPreprocessor - from .nlp_base import ( - DocumentSegmentationPreprocessor, - FaqQuestionAnsweringPreprocessor, - FillMaskPoNetPreprocessor, - NLPPreprocessor, - NLPTokenizerPreprocessorBase, - PassageRankingPreprocessor, - RelationExtractionPreprocessor, - SentenceEmbeddingPreprocessor, - SequenceClassificationPreprocessor, - TokenClassificationPreprocessor, - TextGenerationPreprocessor, - Text2TextGenerationPreprocessor, - Tokenize, - WordSegmentationBlankSetToLabelPreprocessor, - ZeroShotClassificationPreprocessor, - ) - + from .nlp_base import (NLPTokenizerPreprocessorBase, NLPBasePreprocessor) + from .text_generation_jieba_preprocessor import TextGenerationJiebaPreprocessor + from .sentence_piece_preprocessor import SentencePiecePreprocessor + from .bert_seq_cls_tokenizer import Tokenize + from .document_segmentation_preprocessor import DocumentSegmentationPreprocessor + from .faq_question_answering_preprocessor import FaqQuestionAnsweringPreprocessor + from .fill_mask_preprocessor import FillMaskPoNetPreprocessor, NLPPreprocessor + from .text_ranking_preprocessor import TextRankingPreprocessor + from .relation_extraction_preprocessor import RelationExtractionPreprocessor + from .sentence_classification_preprocessor import SequenceClassificationPreprocessor + from .sentence_embedding_preprocessor import SentenceEmbeddingPreprocessor + from .text_generation_preprocessor import TextGenerationPreprocessor + from .text2text_generation_preprocessor import Text2TextGenerationPreprocessor + from .token_classification_preprocessor import TokenClassificationPreprocessor, \ + WordSegmentationBlankSetToLabelPreprocessor + from .token_classification_thai_preprocessor import WordSegmentationPreprocessorThai, NERPreprocessorThai + from .token_classification_viet_preprocessor import NERPreprocessorViet + from .zero_shot_classification_reprocessor import ZeroShotClassificationPreprocessor + from .space import (DialogIntentPredictionPreprocessor, + DialogModelingPreprocessor, + DialogStateTrackingPreprocessor, InputFeatures, + MultiWOZBPETextField, IntentBPETextField) + from .space_T_en import ConversationalTextToSqlPreprocessor + from .space_T_cn import TableQuestionAnsweringPreprocessor else: _import_structure = { 'nlp_base': [ - 'DocumentSegmentationPreprocessor', - 'FaqQuestionAnsweringPreprocessor', - 'FillMaskPoNetPreprocessor', - 'NLPPreprocessor', 'NLPTokenizerPreprocessorBase', - 'PassageRankingPreprocessor', - 'RelationExtractionPreprocessor', - 'SentenceEmbeddingPreprocessor', - 'SequenceClassificationPreprocessor', + 'NLPBasePreprocessor', + ], + 'text_generation_jieba_preprocessor': + ['TextGenerationJiebaPreprocessor'], + 'sentence_piece_preprocessor': ['SentencePiecePreprocessor'], + 'bert_seq_cls_tokenizer': ['Tokenize'], + 
'document_segmentation_preprocessor': + ['DocumentSegmentationPreprocessor'], + 'faq_question_answering_preprocessor': + ['FaqQuestionAnsweringPreprocessor'], + 'fill_mask_preprocessor': + ['FillMaskPoNetPreprocessor', 'NLPPreprocessor'], + 'text_ranking_preprocessor': ['TextRankingPreprocessor'], + 'relation_extraction_preprocessor': ['RelationExtractionPreprocessor'], + 'sentence_classification_preprocessor': + ['SequenceClassificationPreprocessor'], + 'sentence_embedding_preprocessor': ['SentenceEmbeddingPreprocessor'], + 'text_generation_preprocessor': ['TextGenerationPreprocessor'], + 'text2text_generation_preprocessor': + ['Text2TextGenerationPreprocessor'], + 'token_classification_preprocessor': [ 'TokenClassificationPreprocessor', - 'TextGenerationPreprocessor', - 'Tokenize', - 'Text2TextGenerationPreprocessor', - 'WordSegmentationBlankSetToLabelPreprocessor', - 'ZeroShotClassificationPreprocessor', + 'WordSegmentationBlankSetToLabelPreprocessor' ], + 'zero_shot_classification_reprocessor': + ['ZeroShotClassificationPreprocessor'], 'text_error_correction': [ 'TextErrorCorrectionPreprocessor', ], + 'token_classification_thai_preprocessor': [ + 'NERPreprocessorThai', + 'WordSegmentationPreprocessorThai', + ], + 'token_classification_viet_preprocessor': [ + 'NERPreprocessorViet', + ], + 'space': [ + 'DialogIntentPredictionPreprocessor', + 'DialogModelingPreprocessor', + 'DialogStateTrackingPreprocessor', + 'InputFeatures', + 'MultiWOZBPETextField', + 'IntentBPETextField', + ], + 'space_T_en': ['ConversationalTextToSqlPreprocessor'], + 'space_T_cn': ['TableQuestionAnsweringPreprocessor'], } import sys diff --git a/modelscope/preprocessors/nlp/bert_seq_cls_tokenizer.py b/modelscope/preprocessors/nlp/bert_seq_cls_tokenizer.py new file mode 100644 index 00000000..576687ce --- /dev/null +++ b/modelscope/preprocessors/nlp/bert_seq_cls_tokenizer.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from transformers import AutoTokenizer + +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, InputFields + + +@PREPROCESSORS.register_module(Fields.nlp) +class Tokenize(Preprocessor): + + def __init__(self, tokenizer_name) -> None: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(data, str): + data = {InputFields.text: data} + token_dict = self.tokenizer(data[InputFields.text]) + data.update(token_dict) + return data diff --git a/modelscope/preprocessors/nlp/document_segmentation_preprocessor.py b/modelscope/preprocessors/nlp/document_segmentation_preprocessor.py new file mode 100644 index 00000000..5ab0a0c6 --- /dev/null +++ b/modelscope/preprocessors/nlp/document_segmentation_preprocessor.py @@ -0,0 +1,220 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
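`Tokenize` above is a thin wrapper that runs `AutoTokenizer` and merges its output back into the input dict. A short usage sketch; the checkpoint name is an arbitrary public one, not something this diff prescribes.

from modelscope.preprocessors import Tokenize

tokenize = Tokenize(tokenizer_name='bert-base-chinese')   # assumed public checkpoint
result = tokenize('今天天气不错')
# A plain string is wrapped into a dict under the text key, and the tokenizer
# output (input_ids, token_type_ids, attention_mask) is merged into it.
print(result.keys())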
+ +from typing import Any, Dict + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields +from modelscope.utils.logger import get_logger +from .nlp_base import NLPBasePreprocessor + +logger = get_logger() + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.document_segmentation) +class DocumentSegmentationPreprocessor(NLPBasePreprocessor): + + def __init__(self, model_dir: str, config, *args, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + + super().__init__(model_dir, *args, **kwargs) + from transformers import BertTokenizerFast + self.tokenizer = BertTokenizerFast.from_pretrained( + model_dir, + use_fast=True, + ) + self.question_column_name = 'labels' + self.context_column_name = 'sentences' + self.example_id_column_name = 'example_id' + self.label_to_id = {'B-EOP': 0, 'O': 1} + self.target_specical_ids = set() + self.target_specical_ids.add(self.tokenizer.eos_token_id) + self.max_seq_length = config.max_position_embeddings + self.label_list = ['B-EOP', 'O'] + + def __call__(self, examples) -> Dict[str, Any]: + questions = examples[self.question_column_name] + contexts = examples[self.context_column_name] + example_ids = examples[self.example_id_column_name] + num_examples = len(questions) + + sentences = [] + for sentence_list in contexts: + sentence_list = [_ + '[EOS]' for _ in sentence_list] + sentences.append(sentence_list) + + try: + tokenized_examples = self.tokenizer( + sentences, + is_split_into_words=True, + add_special_tokens=False, + return_token_type_ids=True, + return_attention_mask=True, + ) + except Exception as e: + logger.error(e) + return {} + + segment_ids = [] + token_seq_labels = [] + for example_index in range(num_examples): + example_input_ids = tokenized_examples['input_ids'][example_index] + example_labels = questions[example_index] + example_labels = [ + self.label_to_id[_] if _ in self.label_to_id else -100 + for _ in example_labels + ] + example_token_labels = [] + segment_id = [] + cur_seg_id = 1 + for token_index in range(len(example_input_ids)): + if example_input_ids[token_index] in self.target_specical_ids: + example_token_labels.append(example_labels[cur_seg_id - 1]) + segment_id.append(cur_seg_id) + cur_seg_id += 1 + else: + example_token_labels.append(-100) + segment_id.append(cur_seg_id) + + segment_ids.append(segment_id) + token_seq_labels.append(example_token_labels) + + tokenized_examples['segment_ids'] = segment_ids + tokenized_examples['token_seq_labels'] = token_seq_labels + + new_segment_ids = [] + new_token_seq_labels = [] + new_input_ids = [] + new_token_type_ids = [] + new_attention_mask = [] + new_example_ids = [] + new_sentences = [] + + for example_index in range(num_examples): + example_input_ids = tokenized_examples['input_ids'][example_index] + example_token_type_ids = tokenized_examples['token_type_ids'][ + example_index] + example_attention_mask = tokenized_examples['attention_mask'][ + example_index] + example_segment_ids = tokenized_examples['segment_ids'][ + example_index] + example_token_seq_labels = tokenized_examples['token_seq_labels'][ + example_index] + example_sentences = contexts[example_index] + example_id = example_ids[example_index] + example_total_num_sentences = len(questions[example_index]) + example_total_num_tokens = len( + tokenized_examples['input_ids'][example_index]) + accumulate_length = [ + i for i, x in enumerate(tokenized_examples['input_ids'] + 
[example_index]) + if x == self.tokenizer.eos_token_id + ] + samples_boundary = [] + left_index = 0 + sent_left_index = 0 + sent_i = 0 + + # for sent_i, length in enumerate(accumulate_length): + while sent_i < len(accumulate_length): + length = accumulate_length[sent_i] + right_index = length + 1 + sent_right_index = sent_i + 1 + if right_index - left_index >= self.max_seq_length - 1 or right_index == example_total_num_tokens: + samples_boundary.append([left_index, right_index]) + + sample_input_ids = [ + self.tokenizer.cls_token_id + ] + example_input_ids[left_index:right_index] + sample_input_ids = sample_input_ids[:self.max_seq_length] + + sample_token_type_ids = [ + 0 + ] + example_token_type_ids[left_index:right_index] + sample_token_type_ids = sample_token_type_ids[:self. + max_seq_length] + + sample_attention_mask = [ + 1 + ] + example_attention_mask[left_index:right_index] + sample_attention_mask = sample_attention_mask[:self. + max_seq_length] + + sample_segment_ids = [ + 0 + ] + example_segment_ids[left_index:right_index] + sample_segment_ids = sample_segment_ids[:self. + max_seq_length] + + sample_token_seq_labels = [ + -100 + ] + example_token_seq_labels[left_index:right_index] + sample_token_seq_labels = sample_token_seq_labels[:self. + max_seq_length] + + if sent_right_index - 1 == sent_left_index: + left_index = right_index + sample_input_ids[-1] = self.tokenizer.eos_token_id + sample_token_seq_labels[-1] = -100 + else: + left_index = accumulate_length[sent_i - 1] + 1 + if sample_token_seq_labels[-1] != -100: + sample_token_seq_labels[-1] = -100 + + if sent_right_index - 1 == sent_left_index or right_index == example_total_num_tokens: + sample_sentences = example_sentences[ + sent_left_index:sent_right_index] + sent_left_index = sent_right_index + sent_i += 1 + else: + sample_sentences = example_sentences[ + sent_left_index:sent_right_index - 1] + sent_left_index = sent_right_index - 1 + + if (len([_ for _ in sample_token_seq_labels if _ != -100 + ])) != len(sample_sentences) - 1 and (len([ + _ + for _ in sample_token_seq_labels if _ != -100 + ])) != len(sample_sentences): + tmp = [] + for w_i, w, l in zip( + sample_input_ids, + self.tokenizer.decode(sample_input_ids).split( + ' '), sample_token_seq_labels): + tmp.append((w_i, w, l)) + while len(sample_input_ids) < self.max_seq_length: + sample_input_ids.append(self.tokenizer.pad_token_id) + sample_token_type_ids.append(0) + sample_attention_mask.append(0) + sample_segment_ids.append(example_total_num_sentences + + 1) + sample_token_seq_labels.append(-100) + + new_input_ids.append(sample_input_ids) + new_token_type_ids.append(sample_token_type_ids) + new_attention_mask.append(sample_attention_mask) + new_segment_ids.append(sample_segment_ids) + new_token_seq_labels.append(sample_token_seq_labels) + new_example_ids.append(example_id) + new_sentences.append(sample_sentences) + else: + sent_i += 1 + continue + + output_samples = {} + + output_samples['input_ids'] = new_input_ids + output_samples['token_type_ids'] = new_token_type_ids + output_samples['attention_mask'] = new_attention_mask + + output_samples['segment_ids'] = new_segment_ids + output_samples['example_id'] = new_example_ids + output_samples['labels'] = new_token_seq_labels + output_samples['sentences'] = new_sentences + + return output_samples diff --git a/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py b/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py new file mode 100644 index 00000000..873a8448 --- /dev/null +++ 
b/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py @@ -0,0 +1,98 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.config import Config, ConfigFields +from modelscope.utils.constant import Fields, ModeKeys, ModelFile +from modelscope.utils.type_assert import type_assert +from .nlp_base import NLPBasePreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.faq_question_answering_preprocessor) +class FaqQuestionAnsweringPreprocessor(NLPBasePreprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + super(FaqQuestionAnsweringPreprocessor, self).__init__( + model_dir, mode=ModeKeys.INFERENCE, **kwargs) + + from transformers import BertTokenizer + + preprocessor_config = Config.from_file( + os.path.join(model_dir, ModelFile.CONFIGURATION)).get( + ConfigFields.preprocessor, {}) + if preprocessor_config.get('tokenizer', + 'BertTokenizer') == 'XLMRoberta': + from transformers import XLMRobertaTokenizer + self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_dir) + else: + self.tokenizer = BertTokenizer.from_pretrained(model_dir) + + self.MAX_LEN = preprocessor_config.get('max_seq_length', 50) + self.label_dict = None + + def pad(self, samples, max_len): + result = [] + for sample in samples: + pad_len = max_len - len(sample[:max_len]) + result.append(sample[:max_len] + + [self.tokenizer.pad_token_id] * pad_len) + return result + + def set_label_dict(self, label_dict): + self.label_dict = label_dict + + def get_label(self, label_id): + assert self.label_dict is not None and label_id < len(self.label_dict) + return self.label_dict[label_id] + + def encode_plus(self, text): + return [ + self.tokenizer.cls_token_id + ] + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text)) + [self.tokenizer.sep_token_id] + + @type_assert(object, Dict) + def __call__(self, data: Dict[str, Any], + **preprocessor_param) -> Dict[str, Any]: + TMP_MAX_LEN = preprocessor_param.get('max_seq_length', self.MAX_LEN) + queryset = data['query_set'] + if not isinstance(queryset, list): + queryset = [queryset] + supportset = data['support_set'] + supportset = sorted(supportset, key=lambda d: d['label']) + + queryset_tokenized = [self.encode_plus(text) for text in queryset] + supportset_tokenized = [ + self.encode_plus(item['text']) for item in supportset + ] + + max_len = max( + [len(seq) for seq in queryset_tokenized + supportset_tokenized]) + max_len = min(TMP_MAX_LEN, max_len) + queryset_padded = self.pad(queryset_tokenized, max_len) + supportset_padded = self.pad(supportset_tokenized, max_len) + + supportset_labels_ori = [item['label'] for item in supportset] + label_dict = [] + for label in supportset_labels_ori: + if label not in label_dict: + label_dict.append(label) + self.set_label_dict(label_dict) + supportset_labels_ids = [ + label_dict.index(label) for label in supportset_labels_ori + ] + return { + 'query': queryset_padded, + 'support': supportset_padded, + 'support_labels': supportset_labels_ids + } + + def batch_encode(self, sentence_list: list, max_length=None): + if not max_length: + max_length = self.MAX_LEN + return self.tokenizer.batch_encode_plus( + sentence_list, padding=True, max_length=max_length) diff --git a/modelscope/preprocessors/nlp/fill_mask_preprocessor.py b/modelscope/preprocessors/nlp/fill_mask_preprocessor.py new file mode 100644 index 
00000000..b0638dbc --- /dev/null +++ b/modelscope/preprocessors/nlp/fill_mask_preprocessor.py @@ -0,0 +1,142 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +import re +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModeKeys, ModelFile +from modelscope.utils.nlp import import_external_nltk_data +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module(Fields.nlp, module_name=Preprocessors.fill_mask) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.feature_extraction) +class NLPPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in MLM task. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get('padding', 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', + True) + super().__init__(model_dir, mode=mode, **kwargs) + + @property + def mask_id(self): + return self.tokenizer.mask_token_id + + def decode(self, + token_ids, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + **kwargs): + return self.tokenizer.decode(token_ids, skip_special_tokens, + clean_up_tokenization_spaces, **kwargs) + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.fill_mask_ponet) +class FillMaskPoNetPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in PoNet model's MLM task. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get('padding', 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 512) + kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', + True) + super().__init__(model_dir, mode=mode, **kwargs) + + self.cfg = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + self.language = self.cfg.model.get('language', 'en') + if self.language == 'en': + from nltk.tokenize import sent_tokenize + import_external_nltk_data( + osp.join(model_dir, 'nltk_data'), 'tokenizers/punkt') + elif self.language in ['zh', 'cn']: + + def sent_tokenize(para): + para = re.sub(r'([。!!?\?])([^”’])', r'\1\n\2', para) # noqa * + para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para) # noqa * + para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para) # noqa * + para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', + para) # noqa * + para = para.rstrip() + return [_ for _ in para.split('\n') if _] + else: + raise NotImplementedError + + self.sent_tokenize = sent_tokenize + self.max_length = kwargs['max_length'] + + def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (tuple): [sentence1, sentence2] + sentence1 (str): a sentence + Example: + 'you are so handsome.' + sentence2 (str): a sentence + Example: + 'you are so beautiful.' 
+ Returns: + Dict[str, Any]: the preprocessed data + """ + + text_a, text_b, labels = self.parse_text_and_label(data) + output = self.tokenizer( + text_a, + text_b, + return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, + **self.tokenize_kwargs) + max_seq_length = self.max_length + + if text_b is None: + segment_ids = [] + seg_lens = list( + map( + len, + self.tokenizer( + self.sent_tokenize(text_a), + add_special_tokens=False, + truncation=True)['input_ids'])) + segment_id = [0] + sum( + [[i] * sl for i, sl in enumerate(seg_lens, start=1)], []) + segment_id = segment_id[:max_seq_length - 1] + segment_ids.append(segment_id + [segment_id[-1] + 1] + * (max_seq_length - len(segment_id))) + if self.mode == ModeKeys.INFERENCE: + segment_ids = torch.tensor(segment_ids) + output['segment_ids'] = segment_ids + + output = { + k: np.array(v) if isinstance(v, list) else v + for k, v in output.items() + } + + self.labels_to_id(labels, output) + return output + + @property + def mask_id(self): + return self.tokenizer.mask_token_id + + def decode(self, + token_ids, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + **kwargs): + return self.tokenizer.decode(token_ids, skip_special_tokens, + clean_up_tokenization_spaces, **kwargs) diff --git a/modelscope/preprocessors/nlp/nlp_base.py b/modelscope/preprocessors/nlp/nlp_base.py index 6b559de9..48a04d7a 100644 --- a/modelscope/preprocessors/nlp/nlp_base.py +++ b/modelscope/preprocessors/nlp/nlp_base.py @@ -1,62 +1,112 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import os.path as osp -import re -from typing import Any, Dict, Iterable, Optional, Tuple, Union +import os +from abc import ABC +from collections.abc import Mapping +from typing import Any, Dict, List, Tuple, Union +import json import numpy as np import torch from transformers import AutoTokenizer -from modelscope.metainfo import Models, Preprocessors +from modelscope.metainfo import Models from modelscope.outputs import OutputKeys from modelscope.preprocessors.base import Preprocessor -from modelscope.preprocessors.builder import PREPROCESSORS -from modelscope.utils.config import Config, ConfigFields -from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile +from modelscope.utils.constant import ModeKeys from modelscope.utils.hub import get_model_type, parse_label_mapping from modelscope.utils.logger import get_logger -from modelscope.utils.nlp import import_external_nltk_data -from modelscope.utils.type_assert import type_assert logger = get_logger() __all__ = [ - 'DocumentSegmentationPreprocessor', - 'FaqQuestionAnsweringPreprocessor', - 'NLPPreprocessor', - 'FillMaskPoNetPreprocessor', + 'NLPBasePreprocessor', 'NLPTokenizerPreprocessorBase', - 'PassageRankingPreprocessor', - 'RelationExtractionPreprocessor', - 'SentenceEmbeddingPreprocessor', - 'SequenceClassificationPreprocessor', - 'TokenClassificationPreprocessor', - 'Text2TextGenerationPreprocessor', - 'TextGenerationPreprocessor', - 'Tokenize', - 'WordSegmentationBlankSetToLabelPreprocessor', - 'ZeroShotClassificationPreprocessor', ] -@PREPROCESSORS.register_module(Fields.nlp) -class Tokenize(Preprocessor): +class NLPBasePreprocessor(Preprocessor, ABC): - def __init__(self, tokenizer_name) -> None: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + def __init__(self, + model_dir: str, + first_sequence=None, + second_sequence=None, + label=None, + label2id=None, + mode=ModeKeys.INFERENCE, + **kwargs): + """The NLP preprocessor base class. 
- def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]: - if isinstance(data, str): - data = {InputFields.text: data} - token_dict = self.tokenizer(data[InputFields.text]) - data.update(token_dict) - return data + Args: + model_dir (str): The local model path + first_sequence: The key for the first sequence + second_sequence: The key for the second sequence + label: The label key + label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping + if this mapping is not supplied. + mode: Run this preprocessor in either 'train'/'eval'/'inference' mode + """ + self.model_dir = model_dir + self.first_sequence = first_sequence + self.second_sequence = second_sequence + self.label = label + + self.use_fast = kwargs.pop('use_fast', None) + if self.use_fast is None and os.path.isfile( + os.path.join(model_dir, 'tokenizer_config.json')): + with open(os.path.join(model_dir, 'tokenizer_config.json'), + 'r') as f: + json_config = json.load(f) + self.use_fast = json_config.get('use_fast') + self.use_fast = False if self.use_fast is None else self.use_fast + + self.label2id = label2id + if self.label2id is None: + self.label2id = parse_label_mapping(self.model_dir) + super().__init__(mode, **kwargs) + + @property + def mask_id(self): + """Child preprocessor can override this property to return the id of mask token. + + Returns: + The id of mask token, default None. + """ + return None + def decode(self, + token_ids: Union[int, List[int], 'np.ndarray', 'torch.Tensor', + 'tf.Tensor'], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + **kwargs): + """Turn the token_ids to real sentence. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to clean up the tokenization spaces. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + Returns: + The real sentence decoded by the preprocessor. + """ + raise NotImplementedError() -class NLPTokenizerPreprocessorBase(Preprocessor): - def __init__(self, model_dir: str, mode: str, **kwargs): +class NLPTokenizerPreprocessorBase(NLPBasePreprocessor): + + def __init__(self, + model_dir: str, + first_sequence: str = None, + second_sequence: str = None, + label: str = 'label', + label2id: dict = None, + mode: str = ModeKeys.INFERENCE, + **kwargs): """The NLP tokenizer preprocessor base class. Any nlp preprocessor which uses the hf tokenizer can inherit from this class. @@ -65,31 +115,27 @@ class NLPTokenizerPreprocessorBase(Preprocessor): model_dir (str): The local model path first_sequence: The key for the first sequence second_sequence: The key for the second sequence - label: The label key - label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping - if this mapping is not supplied. - mode: Run this preprocessor in either 'train'/'eval'/'inference' mode + label: The key for the label + label2id: An optional label2id dict. 
+ If label2id is None, the preprocessor will try to parse label-id mapping from: + - configuration.json model.label2id/model.id2label + - config.json label2id/id2label + - label_mapping.json + mode: Run this preprocessor in either 'train'/'eval'/'inference' mode, the behavior may be different. kwargs: These kwargs will be directly fed into the tokenizer. """ - super().__init__(**kwargs) - self.model_dir: str = model_dir - self.first_sequence: str = kwargs.pop('first_sequence', - 'first_sequence') - self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') - self.sequence_length = kwargs.pop('sequence_length', 128) - - self._mode = mode - self.label = kwargs.pop('label', OutputKeys.LABEL) - self.label2id = None - if 'label2id' in kwargs: - self.label2id = kwargs.pop('label2id') - if self.label2id is None: - self.label2id = parse_label_mapping(self.model_dir) - + super().__init__(model_dir, first_sequence, second_sequence, label, + label2id, mode) + self.model_dir = model_dir self.tokenize_kwargs = kwargs - self.tokenizer = self.build_tokenizer(model_dir) + logger.info(f'The key of sentence1: {self.first_sequence}, ' + f'The key of sentence2: {self.second_sequence}, ' + f'The key of label: {self.label}') + if self.first_sequence is None: + logger.warning('[Important] first_sequence attribute is not set, ' + 'this will cause an error if your input is a dict.') @property def id2label(self): @@ -107,8 +153,11 @@ class NLPTokenizerPreprocessorBase(Preprocessor): NOTE: This default implementation only returns slow tokenizer, because the fast tokenizers have a multi-thread problem. - @param model_dir: The local model dir. - @return: The initialized tokenizer. + Args: + model_dir: The local model dir. + + Returns: + The initialized tokenizer. """ self.is_transformer_based_model = 'lstm' not in model_dir # fast version lead to parallel inference failed @@ -116,32 +165,23 @@ class NLPTokenizerPreprocessorBase(Preprocessor): if model_type in (Models.structbert, Models.gpt3, Models.palm, Models.plug): from modelscope.models.nlp.structbert import SbertTokenizer, SbertTokenizerFast - return SbertTokenizer.from_pretrained( - model_dir - ) if self._mode == ModeKeys.INFERENCE else SbertTokenizerFast.from_pretrained( - model_dir) + tokenizer = SbertTokenizerFast if self.use_fast else SbertTokenizer + return tokenizer.from_pretrained(model_dir) elif model_type == Models.veco: from modelscope.models.nlp.veco import VecoTokenizer, VecoTokenizerFast - return VecoTokenizer.from_pretrained( - model_dir - ) if self._mode == ModeKeys.INFERENCE else VecoTokenizerFast.from_pretrained( - model_dir) + tokenizer = VecoTokenizerFast if self.use_fast else VecoTokenizer + return tokenizer.from_pretrained(model_dir) elif model_type == Models.deberta_v2: from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer, DebertaV2TokenizerFast - return DebertaV2Tokenizer.from_pretrained( - model_dir - ) if self._mode == ModeKeys.INFERENCE else DebertaV2TokenizerFast.from_pretrained( - model_dir) + tokenizer = DebertaV2TokenizerFast if self.use_fast else DebertaV2Tokenizer + return tokenizer.from_pretrained(model_dir) elif not self.is_transformer_based_model: from transformers import BertTokenizer, BertTokenizerFast - return BertTokenizer.from_pretrained( - model_dir - ) if self._mode == ModeKeys.INFERENCE else BertTokenizerFast.from_pretrained( - model_dir) + tokenizer = BertTokenizerFast if self.use_fast else BertTokenizer + return tokenizer.from_pretrained(model_dir) else: return AutoTokenizer.from_pretrained( - 
model_dir, - use_fast=False if self._mode == ModeKeys.INFERENCE else True) + model_dir, use_fast=self.use_fast) def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]: """process the raw input data @@ -178,8 +218,11 @@ class NLPTokenizerPreprocessorBase(Preprocessor): If the pair param is False, data will be parsed as the first_sentence and the label, else it will be parsed as the first_sentence and the second_sentence. - @param data: The input data. - @return: The sentences and labels tuple. + Args: + data: The input data. + + Returns: + The sentences and labels tuple. """ text_a, text_b, labels = None, None, None if isinstance(data, str): @@ -192,7 +235,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor): text_a, text_b = data else: text_a, labels = data - elif isinstance(data, dict): + elif isinstance(data, Mapping): text_a = data.get(self.first_sequence) text_b = data.get(self.second_sequence) labels = data.get(self.label) @@ -206,958 +249,34 @@ class NLPTokenizerPreprocessorBase(Preprocessor): If the original label's type is float, or the label2id mapping does not exist, the original label will be returned. - @param labels: The input labels. - @param output: The label id. - @return: The final labels. + Args: + labels: The input labels. + output: The label id. + + Returns: + The final labels. """ def label_can_be_mapped(label): return isinstance(label, str) or isinstance(label, int) - if labels is not None: - if isinstance(labels, Iterable) and all([label_can_be_mapped(label) for label in labels]) \ + try: + if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \ and self.label2id is not None: output[OutputKeys.LABELS] = [ - self.label2id[str(label)] for label in labels + self.label2id[label] + if label in self.label2id else self.label2id[str(label)] + for label in labels ] elif label_can_be_mapped(labels) and self.label2id is not None: - output[OutputKeys.LABELS] = self.label2id[str(labels)] - else: + output[OutputKeys.LABELS] = self.label2id[ + labels] if labels in self.label2id else self.label2id[str( + labels)] + elif labels is not None: output[OutputKeys.LABELS] = labels - - -@PREPROCESSORS.register_module(Fields.nlp, module_name=Preprocessors.fill_mask) -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.feature_extraction) -class NLPPreprocessor(NLPTokenizerPreprocessorBase): - """The tokenizer preprocessor used in MLM task. - """ - - def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): - kwargs['truncation'] = kwargs.get('truncation', True) - kwargs['padding'] = kwargs.get('padding', 'max_length') - kwargs['max_length'] = kwargs.pop('sequence_length', 128) - kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', - True) - super().__init__(model_dir, mode=mode, **kwargs) - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.passage_ranking) -class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase): - """The tokenizer preprocessor used in passage ranking model. 
- """ - - def __init__(self, - model_dir: str, - mode=ModeKeys.INFERENCE, - *args, - **kwargs): - """preprocess the data - - Args: - model_dir (str): model path - """ - super().__init__(model_dir, pair=True, mode=mode, *args, **kwargs) - self.model_dir: str = model_dir - self.first_sequence: str = kwargs.pop('first_sequence', - 'source_sentence') - self.second_sequence = kwargs.pop('second_sequence', - 'sentences_to_compare') - self.sequence_length = kwargs.pop('sequence_length', 128) - - self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) - - @type_assert(object, (str, tuple, Dict)) - def __call__(self, data: Union[tuple, Dict]) -> Dict[str, Any]: - if isinstance(data, tuple): - sentence1, sentence2 = data - elif isinstance(data, dict): - sentence1 = data.get(self.first_sequence) - sentence2 = data.get(self.second_sequence) - if isinstance(sentence2, str): - sentence2 = [sentence2] - if isinstance(sentence1, str): - sentence1 = [sentence1] - sentence1 = sentence1 * len(sentence2) - - max_seq_length = self.sequence_length - feature = self.tokenizer( - sentence1, - sentence2, - padding='max_length', - truncation=True, - max_length=max_seq_length, - return_tensors='pt') - if 'labels' in data: - labels = data['labels'] - feature['labels'] = labels - if 'qid' in data: - qid = data['qid'] - feature['qid'] = qid - return feature - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.nli_tokenizer) -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.sen_sim_tokenizer) -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.sen_cls_tokenizer) -class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): - """The tokenizer preprocessor used in sequence classification. - """ - - def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): - kwargs['truncation'] = kwargs.get('truncation', True) - kwargs['padding'] = kwargs.get( - 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') - kwargs['max_length'] = kwargs.pop('sequence_length', 128) - super().__init__(model_dir, mode=mode, **kwargs) - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.sentence_embedding) -class SentenceEmbeddingPreprocessor(NLPTokenizerPreprocessorBase): - """The tokenizer preprocessor used in sentence embedding. 
- """ - - def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): - kwargs['truncation'] = kwargs.get('truncation', True) - kwargs['padding'] = kwargs.get( - 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') - kwargs['max_length'] = kwargs.pop('sequence_length', 128) - super().__init__(model_dir, pair=False, mode=mode, **kwargs) - - def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]: - """process the raw input data - - Args: - data Dict: - keys: "source_sentence" && "sentences_to_compare" - values: list of sentences - Example: - {"source_sentence": ["how long it take to get a master's degree"], - "sentences_to_compare": ["On average, students take about 18 to 24 months - to complete a master's degree.", - "On the other hand, some students prefer to go at a slower pace - and choose to take several years to complete their studies.", - "It can take anywhere from two semesters"]} - Returns: - Dict[str, Any]: the preprocessed data - """ - source_sentence = data['source_sentence'] - compare_sentences = data['sentences_to_compare'] - sentences = [] - sentences.append(source_sentence[0]) - for sent in compare_sentences: - sentences.append(sent) - - tokenized_inputs = self.tokenizer( - sentences, - return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, - padding=True, - truncation=True) - return tokenized_inputs - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer) -class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase): - """The tokenizer preprocessor used in zero shot classification. - """ - - def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): - """preprocess the data - - Args: - model_dir (str): model path - """ - self.sequence_length = kwargs.pop('sequence_length', 512) - super().__init__(model_dir, mode=mode, **kwargs) - - def __call__(self, data: Union[str, Dict], hypothesis_template: str, - candidate_labels: list) -> Dict[str, Any]: - """process the raw input data - - Args: - data (str or dict): a sentence - Example: - 'you are so handsome.' - - Returns: - Dict[str, Any]: the preprocessed data - """ - if isinstance(data, dict): - data = data.get(self.first_sequence) - - pairs = [[data, hypothesis_template.format(label)] - for label in candidate_labels] - - features = self.tokenizer( - pairs, - padding=True, - truncation=True, - max_length=self.sequence_length, - truncation_strategy='only_first', - return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None) - return features - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.text2text_gen_preprocessor) -class Text2TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): - """The tokenizer preprocessor used in text generation. 
- """ - - def __init__(self, - model_dir: str, - tokenizer=None, - mode=ModeKeys.INFERENCE, - **kwargs): - self.tokenizer = self.build_tokenizer( - model_dir) if tokenizer is None else tokenizer - kwargs['truncation'] = kwargs.get('truncation', 'do_not_truncate') - kwargs['padding'] = kwargs.get('padding', False) - kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', - False) - kwargs['max_length'] = kwargs.pop('sequence_length', 128) - super().__init__(model_dir, pair=False, mode=mode, **kwargs) - - def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]: - text_a, _, _ = self.parse_text_and_label(data) - - inputs = self.tokenizer( - text_a, - return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, - **self.tokenize_kwargs) - - # This is produced by tokenizers but is an invalid generate kwargs - if 'token_type_ids' in inputs: - del inputs['token_type_ids'] - return inputs - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.text_gen_tokenizer) -class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): - """The tokenizer preprocessor used in text generation. - """ - - def __init__(self, - model_dir: str, - tokenizer=None, - mode=ModeKeys.INFERENCE, - **kwargs): - kwargs['truncation'] = kwargs.get('truncation', True) - kwargs['padding'] = kwargs.get('padding', 'max_length') - kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', - False) - kwargs['max_length'] = kwargs.pop('sequence_length', 128) - super().__init__(model_dir, mode=mode, **kwargs) - - @staticmethod - def get_roberta_tokenizer_dir(model_dir: str) -> Optional[str]: - import os - for name in os.listdir(model_dir): - full_name = os.path.join(model_dir, name) - if 'roberta' in name and os.path.isdir(full_name): - return full_name - - def build_tokenizer(self, model_dir: str): - roberta_tokenizer_dir = self.get_roberta_tokenizer_dir(model_dir) - if roberta_tokenizer_dir: - from transformers import RobertaTokenizer - return RobertaTokenizer.from_pretrained( - roberta_tokenizer_dir, do_lower_case=False) - return super().build_tokenizer(model_dir) - - def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]: - if self._mode == ModeKeys.INFERENCE: - return super().__call__(data) - src_rst = super().__call__(data['src_txt']) - src_input_ids = src_rst['input_ids'] - src_attention_mask = src_rst['attention_mask'] - if 'tgt_txt' in data: - labels = super().__call__(data['tgt_txt'])['input_ids'] - else: - labels = src_input_ids[1:] - src_input_ids = src_input_ids[:-1] - src_attention_mask = src_attention_mask[:-1] - - return { - 'input_ids': src_input_ids, - 'attention_mask': src_attention_mask, - 'labels': labels, - } - - -@PREPROCESSORS.register_module( - Fields.nlp, - module_name=Preprocessors.word_segment_text_to_label_preprocessor) -class WordSegmentationBlankSetToLabelPreprocessor(Preprocessor): - """The preprocessor used to turn a single sentence to a labeled token-classification dict. 
- """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.first_sequence: str = kwargs.pop('first_sequence', - 'first_sequence') - self.label = kwargs.pop('label', OutputKeys.LABELS) - - def __call__(self, data: str) -> Union[Dict[str, Any], Tuple]: - data = data.split(' ') - data = list(filter(lambda x: len(x) > 0, data)) - - def produce_train_sample(words): - chars = [] - labels = [] - for word in words: - chars.extend(list(word)) - if len(word) == 1: - labels.append('S-CWS') - else: - labels.extend(['B-CWS'] + ['I-CWS'] * (len(word) - 2) - + ['E-CWS']) - assert len(chars) == len(labels) - return chars, labels - - chars, labels = produce_train_sample(data) - return { - self.first_sequence: chars, - self.label: labels, - } - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.ner_tokenizer) -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.token_cls_tokenizer) -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer) -class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): - """The tokenizer preprocessor used in normal NER task. - """ - - def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): - """preprocess the data - - Args: - model_dir (str): model path - """ - kwargs['truncation'] = kwargs.get('truncation', True) - kwargs['padding'] = kwargs.get( - 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') - kwargs['max_length'] = kwargs.pop('sequence_length', 128) - self.label_all_tokens = kwargs.pop('label_all_tokens', False) - super().__init__(model_dir, mode=mode, **kwargs) - - if 'is_split_into_words' in kwargs: - self.is_split_into_words = kwargs.pop('is_split_into_words') - else: - self.is_split_into_words = self.tokenizer.init_kwargs.get( - 'is_split_into_words', False) - if 'label2id' in kwargs: - kwargs.pop('label2id') - self.tokenize_kwargs = kwargs - - @type_assert(object, str) - def __call__(self, data: str) -> Dict[str, Any]: - """process the raw input data - - Args: - data (str): a sentence - Example: - 'you are so handsome.' 
- - Returns: - Dict[str, Any]: the preprocessed data - """ - - # preprocess the data for the model input - text = None - labels_list = None - if isinstance(data, str): - text = data - elif isinstance(data, dict): - text = data.get(self.first_sequence) - labels_list = data.get(self.label) - - input_ids = [] - label_mask = [] - offset_mapping = [] - if self.is_split_into_words: - for offset, token in enumerate(list(data)): - subtoken_ids = self.tokenizer.encode( - token, add_special_tokens=False) - if len(subtoken_ids) == 0: - subtoken_ids = [self.tokenizer.unk_token_id] - input_ids.extend(subtoken_ids) - label_mask.extend([1] + [0] * (len(subtoken_ids) - 1)) - offset_mapping.extend([(offset, offset + 1)]) - else: - if self.tokenizer.is_fast: - encodings = self.tokenizer( - text, - add_special_tokens=False, - return_offsets_mapping=True, - **self.tokenize_kwargs) - input_ids = encodings['input_ids'] - word_ids = encodings.word_ids() - for i in range(len(word_ids)): - if word_ids[i] is None: - label_mask.append(0) - elif word_ids[i] == word_ids[i - 1]: - label_mask.append(0) - offset_mapping[-1] = ( - offset_mapping[-1][0], - encodings['offset_mapping'][i][1]) - else: - label_mask.append(1) - offset_mapping.append(encodings['offset_mapping'][i]) - else: - encodings = self.tokenizer( - text, add_special_tokens=False, **self.tokenize_kwargs) - input_ids = encodings['input_ids'] - label_mask, offset_mapping = self.get_label_mask_and_offset_mapping( - text) - - if len(input_ids) >= self.sequence_length - 2: - input_ids = input_ids[:self.sequence_length - 2] - label_mask = label_mask[:self.sequence_length - 2] - input_ids = [self.tokenizer.cls_token_id - ] + input_ids + [self.tokenizer.sep_token_id] - label_mask = [0] + label_mask + [0] - attention_mask = [1] * len(input_ids) - offset_mapping = offset_mapping[:sum(label_mask)] - - if not self.is_transformer_based_model: - input_ids = input_ids[1:-1] - attention_mask = attention_mask[1:-1] - label_mask = label_mask[1:-1] - - if self._mode == ModeKeys.INFERENCE: - input_ids = torch.tensor(input_ids).unsqueeze(0) - attention_mask = torch.tensor(attention_mask).unsqueeze(0) - label_mask = torch.tensor( - label_mask, dtype=torch.bool).unsqueeze(0) - - # the token classification - output = { - 'text': text, - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'label_mask': label_mask, - 'offset_mapping': offset_mapping - } - - # align the labels with tokenized text - if labels_list is not None: - assert self.label2id is not None - # Map that sends B-Xxx label to its I-Xxx counterpart - b_to_i_label = [] - label_enumerate_values = [ - k for k, v in sorted( - self.label2id.items(), key=lambda item: item[1]) - ] - for idx, label in enumerate(label_enumerate_values): - if label.startswith('B-') and label.replace( - 'B-', 'I-') in label_enumerate_values: - b_to_i_label.append( - label_enumerate_values.index( - label.replace('B-', 'I-'))) - else: - b_to_i_label.append(idx) - - label_row = [self.label2id[lb] for lb in labels_list] - previous_word_idx = None - label_ids = [] - for word_idx in word_ids: - if word_idx is None: - label_ids.append(-100) - elif word_idx != previous_word_idx: - label_ids.append(label_row[word_idx]) - else: - if self.label_all_tokens: - label_ids.append(b_to_i_label[label_row[word_idx]]) - else: - label_ids.append(-100) - previous_word_idx = word_idx - labels = label_ids - output['labels'] = labels - return output - - def get_tokenizer_class(self): - tokenizer_class = self.tokenizer.__class__.__name__ - if 
tokenizer_class.endswith( - 'Fast') and tokenizer_class != 'PreTrainedTokenizerFast': - tokenizer_class = tokenizer_class[:-4] - return tokenizer_class - - def get_label_mask_and_offset_mapping(self, text): - label_mask = [] - offset_mapping = [] - tokens = self.tokenizer.tokenize(text) - offset = 0 - if self.get_tokenizer_class() == 'BertTokenizer': - for token in tokens: - is_start = (token[:2] != '##') - if is_start: - label_mask.append(True) - else: - token = token[2:] - label_mask.append(False) - start = offset + text[offset:].index(token) - end = start + len(token) - if is_start: - offset_mapping.append((start, end)) - else: - offset_mapping[-1] = (offset_mapping[-1][0], end) - offset = end - elif self.get_tokenizer_class() == 'XLMRobertaTokenizer': - last_is_blank = False - for token in tokens: - is_start = (token[0] == '▁') - if is_start: - token = token[1:] - label_mask.append(True) - if len(token) == 0: - last_is_blank = True - continue - else: - label_mask.append(False) - start = offset + text[offset:].index(token) - end = start + len(token) - if last_is_blank or is_start: - offset_mapping.append((start, end)) - else: - offset_mapping[-1] = (offset_mapping[-1][0], end) - offset = end - last_is_blank = False - else: - raise NotImplementedError - - return label_mask, offset_mapping - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.re_tokenizer) -class RelationExtractionPreprocessor(Preprocessor): - """The relation extraction preprocessor used in normal RE task. - """ - - def __init__(self, model_dir: str, *args, **kwargs): - """preprocess the data - - Args: - model_dir (str): model path - """ - - super().__init__(*args, **kwargs) - - self.model_dir: str = model_dir - self.sequence_length = kwargs.pop('sequence_length', 512) - self.tokenizer = AutoTokenizer.from_pretrained( - model_dir, use_fast=True) - - @type_assert(object, str) - def __call__(self, data: str) -> Dict[str, Any]: - """process the raw input data - - Args: - data (str): a sentence - Example: - 'you are so handsome.' 
- - Returns: - Dict[str, Any]: the preprocessed data - """ - - # preprocess the data for the model input - text = data - output = self.tokenizer([text], return_tensors='pt') - return { - 'text': text, - 'input_ids': output['input_ids'], - 'attention_mask': output['attention_mask'], - 'offsets': output[0].offsets - } - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.faq_question_answering_preprocessor) -class FaqQuestionAnsweringPreprocessor(Preprocessor): - - def __init__(self, model_dir: str, *args, **kwargs): - super(FaqQuestionAnsweringPreprocessor, self).__init__( - model_dir, mode=ModeKeys.INFERENCE, **kwargs) - import os - from transformers import BertTokenizer - - from modelscope.utils.config import Config - from modelscope.utils.constant import ModelFile - self.tokenizer = BertTokenizer.from_pretrained(model_dir) - preprocessor_config = Config.from_file( - os.path.join(model_dir, ModelFile.CONFIGURATION)).get( - ConfigFields.preprocessor, {}) - self.MAX_LEN = preprocessor_config.get('max_seq_length', 50) - self.label_dict = None - - def pad(self, samples, max_len): - result = [] - for sample in samples: - pad_len = max_len - len(sample[:max_len]) - result.append(sample[:max_len] - + [self.tokenizer.pad_token_id] * pad_len) - return result - - def set_label_dict(self, label_dict): - self.label_dict = label_dict - - def get_label(self, label_id): - assert self.label_dict is not None and label_id < len(self.label_dict) - return self.label_dict[label_id] - - def encode_plus(self, text): - return [ - self.tokenizer.cls_token_id - ] + self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(text)) + [self.tokenizer.sep_token_id] - - @type_assert(object, Dict) - def __call__(self, data: Dict[str, Any], - **preprocessor_param) -> Dict[str, Any]: - TMP_MAX_LEN = preprocessor_param.get('max_seq_length', self.MAX_LEN) - queryset = data['query_set'] - if not isinstance(queryset, list): - queryset = [queryset] - supportset = data['support_set'] - supportset = sorted(supportset, key=lambda d: d['label']) - - queryset_tokenized = [self.encode_plus(text) for text in queryset] - supportset_tokenized = [ - self.encode_plus(item['text']) for item in supportset - ] - - max_len = max( - [len(seq) for seq in queryset_tokenized + supportset_tokenized]) - max_len = min(TMP_MAX_LEN, max_len) - queryset_padded = self.pad(queryset_tokenized, max_len) - supportset_padded = self.pad(supportset_tokenized, max_len) - - supportset_labels_ori = [item['label'] for item in supportset] - label_dict = [] - for label in supportset_labels_ori: - if label not in label_dict: - label_dict.append(label) - self.set_label_dict(label_dict) - supportset_labels_ids = [ - label_dict.index(label) for label in supportset_labels_ori - ] - return { - 'query': queryset_padded, - 'support': supportset_padded, - 'support_labels': supportset_labels_ids - } - - def batch_encode(self, sentence_list: list, max_length=None): - if not max_length: - max_length = self.MAX_LEN - return self.tokenizer.batch_encode_plus( - sentence_list, padding=True, max_length=max_length) - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.document_segmentation) -class DocumentSegmentationPreprocessor(Preprocessor): - - def __init__(self, model_dir: str, config, *args, **kwargs): - """preprocess the data - - Args: - model_dir (str): model path - """ - - super().__init__(*args, **kwargs) - from transformers import BertTokenizerFast - self.tokenizer = BertTokenizerFast.from_pretrained( - model_dir, - 
use_fast=True, - ) - self.question_column_name = 'labels' - self.context_column_name = 'sentences' - self.example_id_column_name = 'example_id' - self.label_to_id = {'B-EOP': 0, 'O': 1} - self.target_specical_ids = set() - self.target_specical_ids.add(self.tokenizer.eos_token_id) - self.max_seq_length = config.max_position_embeddings - self.label_list = ['B-EOP', 'O'] - - def __call__(self, examples) -> Dict[str, Any]: - questions = examples[self.question_column_name] - contexts = examples[self.context_column_name] - example_ids = examples[self.example_id_column_name] - num_examples = len(questions) - - sentences = [] - for sentence_list in contexts: - sentence_list = [_ + '[EOS]' for _ in sentence_list] - sentences.append(sentence_list) - - try: - tokenized_examples = self.tokenizer( - sentences, - is_split_into_words=True, - add_special_tokens=False, - return_token_type_ids=True, - return_attention_mask=True, - ) - except Exception as e: - logger.error(e) - return {} - - segment_ids = [] - token_seq_labels = [] - for example_index in range(num_examples): - example_input_ids = tokenized_examples['input_ids'][example_index] - example_labels = questions[example_index] - example_labels = [ - self.label_to_id[_] if _ in self.label_to_id else -100 - for _ in example_labels - ] - example_token_labels = [] - segment_id = [] - cur_seg_id = 1 - for token_index in range(len(example_input_ids)): - if example_input_ids[token_index] in self.target_specical_ids: - example_token_labels.append(example_labels[cur_seg_id - 1]) - segment_id.append(cur_seg_id) - cur_seg_id += 1 - else: - example_token_labels.append(-100) - segment_id.append(cur_seg_id) - - segment_ids.append(segment_id) - token_seq_labels.append(example_token_labels) - - tokenized_examples['segment_ids'] = segment_ids - tokenized_examples['token_seq_labels'] = token_seq_labels - - new_segment_ids = [] - new_token_seq_labels = [] - new_input_ids = [] - new_token_type_ids = [] - new_attention_mask = [] - new_example_ids = [] - new_sentences = [] - - for example_index in range(num_examples): - example_input_ids = tokenized_examples['input_ids'][example_index] - example_token_type_ids = tokenized_examples['token_type_ids'][ - example_index] - example_attention_mask = tokenized_examples['attention_mask'][ - example_index] - example_segment_ids = tokenized_examples['segment_ids'][ - example_index] - example_token_seq_labels = tokenized_examples['token_seq_labels'][ - example_index] - example_sentences = contexts[example_index] - example_id = example_ids[example_index] - example_total_num_sentences = len(questions[example_index]) - example_total_num_tokens = len( - tokenized_examples['input_ids'][example_index]) - accumulate_length = [ - i for i, x in enumerate(tokenized_examples['input_ids'] - [example_index]) - if x == self.tokenizer.eos_token_id - ] - samples_boundary = [] - left_index = 0 - sent_left_index = 0 - sent_i = 0 - - # for sent_i, length in enumerate(accumulate_length): - while sent_i < len(accumulate_length): - length = accumulate_length[sent_i] - right_index = length + 1 - sent_right_index = sent_i + 1 - if right_index - left_index >= self.max_seq_length - 1 or right_index == example_total_num_tokens: - samples_boundary.append([left_index, right_index]) - - sample_input_ids = [ - self.tokenizer.cls_token_id - ] + example_input_ids[left_index:right_index] - sample_input_ids = sample_input_ids[:self.max_seq_length] - - sample_token_type_ids = [ - 0 - ] + example_token_type_ids[left_index:right_index] - sample_token_type_ids = 
sample_token_type_ids[:self. - max_seq_length] - - sample_attention_mask = [ - 1 - ] + example_attention_mask[left_index:right_index] - sample_attention_mask = sample_attention_mask[:self. - max_seq_length] - - sample_segment_ids = [ - 0 - ] + example_segment_ids[left_index:right_index] - sample_segment_ids = sample_segment_ids[:self. - max_seq_length] - - sample_token_seq_labels = [ - -100 - ] + example_token_seq_labels[left_index:right_index] - sample_token_seq_labels = sample_token_seq_labels[:self. - max_seq_length] - - if sent_right_index - 1 == sent_left_index: - left_index = right_index - sample_input_ids[-1] = self.tokenizer.eos_token_id - sample_token_seq_labels[-1] = -100 - else: - left_index = accumulate_length[sent_i - 1] + 1 - if sample_token_seq_labels[-1] != -100: - sample_token_seq_labels[-1] = -100 - - if sent_right_index - 1 == sent_left_index or right_index == example_total_num_tokens: - sample_sentences = example_sentences[ - sent_left_index:sent_right_index] - sent_left_index = sent_right_index - sent_i += 1 - else: - sample_sentences = example_sentences[ - sent_left_index:sent_right_index - 1] - sent_left_index = sent_right_index - 1 - - if (len([_ for _ in sample_token_seq_labels if _ != -100 - ])) != len(sample_sentences) - 1 and (len([ - _ - for _ in sample_token_seq_labels if _ != -100 - ])) != len(sample_sentences): - tmp = [] - for w_i, w, l in zip( - sample_input_ids, - self.tokenizer.decode(sample_input_ids).split( - ' '), sample_token_seq_labels): - tmp.append((w_i, w, l)) - while len(sample_input_ids) < self.max_seq_length: - sample_input_ids.append(self.tokenizer.pad_token_id) - sample_token_type_ids.append(0) - sample_attention_mask.append(0) - sample_segment_ids.append(example_total_num_sentences - + 1) - sample_token_seq_labels.append(-100) - - new_input_ids.append(sample_input_ids) - new_token_type_ids.append(sample_token_type_ids) - new_attention_mask.append(sample_attention_mask) - new_segment_ids.append(sample_segment_ids) - new_token_seq_labels.append(sample_token_seq_labels) - new_example_ids.append(example_id) - new_sentences.append(sample_sentences) - else: - sent_i += 1 - continue - - output_samples = {} - - output_samples['input_ids'] = new_input_ids - output_samples['token_type_ids'] = new_token_type_ids - output_samples['attention_mask'] = new_attention_mask - - output_samples['segment_ids'] = new_segment_ids - output_samples['example_id'] = new_example_ids - output_samples['labels'] = new_token_seq_labels - output_samples['sentences'] = new_sentences - - return output_samples - - -@PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.fill_mask_ponet) -class FillMaskPoNetPreprocessor(NLPTokenizerPreprocessorBase): - """The tokenizer preprocessor used in MLM task. 
- """ - - def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): - kwargs['truncation'] = kwargs.get('truncation', True) - kwargs['padding'] = kwargs.get('padding', 'max_length') - kwargs['max_length'] = kwargs.pop('sequence_length', 512) - kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', - True) - super().__init__(model_dir, pair=False, mode=mode, **kwargs) - - self.cfg = Config.from_file( - osp.join(model_dir, ModelFile.CONFIGURATION)) - self.language = self.cfg.model.get('language', 'en') - if self.language == 'en': - from nltk.tokenize import sent_tokenize - import_external_nltk_data( - osp.join(model_dir, 'nltk_data'), 'tokenizers/punkt') - elif self.language in ['zh', 'cn']: - - def sent_tokenize(para): - para = re.sub(r'([。!!?\?])([^”’])', r'\1\n\2', para) # noqa * - para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para) # noqa * - para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para) # noqa * - para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', - para) # noqa * - para = para.rstrip() - return [_ for _ in para.split('\n') if _] - else: - raise NotImplementedError - - self.sent_tokenize = sent_tokenize - self.max_length = kwargs['max_length'] - - def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]: - """process the raw input data - - Args: - data (tuple): [sentence1, sentence2] - sentence1 (str): a sentence - Example: - 'you are so handsome.' - sentence2 (str): a sentence - Example: - 'you are so beautiful.' - Returns: - Dict[str, Any]: the preprocessed data - """ - - text_a, text_b, labels = self.parse_text_and_label(data) - output = self.tokenizer( - text_a, - text_b, - return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, - **self.tokenize_kwargs) - max_seq_length = self.max_length - - if text_b is None: - segment_ids = [] - seg_lens = list( - map( - len, - self.tokenizer( - self.sent_tokenize(text_a), - add_special_tokens=False, - truncation=True)['input_ids'])) - segment_id = [0] + sum( - [[i] * sl for i, sl in enumerate(seg_lens, start=1)], []) - segment_id = segment_id[:max_seq_length - 1] - segment_ids.append(segment_id + [segment_id[-1] + 1] - * (max_seq_length - len(segment_id))) - output['segment_ids'] = segment_ids - - output = { - k: np.array(v) if isinstance(v, list) else v - for k, v in output.items() - } - - self.labels_to_id(labels, output) - return output + except KeyError as e: + logger.error( + f'Label {labels} cannot be found in the label mapping {self.label2id},' + f'which comes from the user input or the configuration files. ' + f'Please consider matching your labels with this mapping.') + raise e diff --git a/modelscope/preprocessors/nlp/relation_extraction_preprocessor.py b/modelscope/preprocessors/nlp/relation_extraction_preprocessor.py new file mode 100644 index 00000000..9a426ab7 --- /dev/null +++ b/modelscope/preprocessors/nlp/relation_extraction_preprocessor.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +from transformers import AutoTokenizer + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields +from modelscope.utils.type_assert import type_assert +from .nlp_base import NLPBasePreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.re_tokenizer) +class RelationExtractionPreprocessor(NLPBasePreprocessor): + """The relation extraction preprocessor used in normal RE task. 
+ """ + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + + super().__init__(model_dir, *args, **kwargs) + + self.model_dir: str = model_dir + self.sequence_length = kwargs.pop('sequence_length', 512) + self.tokenizer = AutoTokenizer.from_pretrained( + model_dir, use_fast=True) + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + + # preprocess the data for the model input + text = data + output = self.tokenizer([text], return_tensors='pt') + return { + 'text': text, + 'input_ids': output['input_ids'], + 'attention_mask': output['attention_mask'], + 'offsets': output[0].offsets + } diff --git a/modelscope/preprocessors/nlp/sentence_classification_preprocessor.py b/modelscope/preprocessors/nlp/sentence_classification_preprocessor.py new file mode 100644 index 00000000..f1295c50 --- /dev/null +++ b/modelscope/preprocessors/nlp/sentence_classification_preprocessor.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.nli_tokenizer) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sen_sim_tokenizer) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sen_cls_tokenizer) +class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in sequence classification. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get('padding', 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + super().__init__(model_dir, mode=mode, **kwargs) diff --git a/modelscope/preprocessors/nlp/sentence_embedding_preprocessor.py b/modelscope/preprocessors/nlp/sentence_embedding_preprocessor.py new file mode 100644 index 00000000..519de60c --- /dev/null +++ b/modelscope/preprocessors/nlp/sentence_embedding_preprocessor.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sentence_embedding) +class SentenceEmbeddingPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in sentence embedding. 
+ """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get('padding', 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + super().__init__(model_dir, mode=mode, **kwargs) + + def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]: + """process the raw input data + + Args: + data Dict: + keys: "source_sentence" && "sentences_to_compare" + values: list of sentences + Example: + {"source_sentence": ["how long it take to get a master's degree"], + "sentences_to_compare": ["On average, students take about 18 to 24 months + to complete a master's degree.", + "On the other hand, some students prefer to go at a slower pace + and choose to take several years to complete their studies.", + "It can take anywhere from two semesters"]} + Returns: + Dict[str, Any]: the preprocessed data + """ + source_sentence = data['source_sentence'] + compare_sentences = data['sentences_to_compare'] + sentences = [] + sentences.append(source_sentence[0]) + for sent in compare_sentences: + sentences.append(sent) + + tokenized_inputs = self.tokenizer( + sentences, + return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, + padding=True, + truncation=True) + return tokenized_inputs diff --git a/modelscope/preprocessors/nlp/sentence_piece_preprocessor.py b/modelscope/preprocessors/nlp/sentence_piece_preprocessor.py new file mode 100644 index 00000000..1d1ef19d --- /dev/null +++ b/modelscope/preprocessors/nlp/sentence_piece_preprocessor.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict + +import sentencepiece as spm +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sentence_piece) +class SentencePiecePreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + import os + + super().__init__(*args, **kwargs) + self.tokenizer = None + for file_name in os.listdir(model_dir): + if file_name.endswith('.model'): + m_file = osp.join(model_dir, file_name) + self.tokenizer = spm.SentencePieceProcessor(model_file=m_file) + break + assert self.tokenizer is not None, 'Can not find .model file' + + def __call__(self, data: str) -> Dict[str, Any]: + return torch.tensor(self.tokenizer.encode([data]), dtype=torch.long) diff --git a/modelscope/preprocessors/space/__init__.py b/modelscope/preprocessors/nlp/space/__init__.py similarity index 100% rename from modelscope/preprocessors/space/__init__.py rename to modelscope/preprocessors/nlp/space/__init__.py diff --git a/modelscope/preprocessors/space/args.py b/modelscope/preprocessors/nlp/space/args.py similarity index 97% rename from modelscope/preprocessors/space/args.py rename to modelscope/preprocessors/nlp/space/args.py index d9e91e74..17c6828b 100644 --- a/modelscope/preprocessors/space/args.py +++ b/modelscope/preprocessors/nlp/space/args.py @@ -1,7 +1,4 @@ -""" -Parse argument. -""" - +# Copyright (c) Alibaba, Inc. and its affiliates. 
import argparse import json diff --git a/modelscope/preprocessors/space/batch.py b/modelscope/preprocessors/nlp/space/batch.py similarity index 96% rename from modelscope/preprocessors/space/batch.py rename to modelscope/preprocessors/nlp/space/batch.py index fe0ad0ec..d27776f5 100644 --- a/modelscope/preprocessors/space/batch.py +++ b/modelscope/preprocessors/nlp/space/batch.py @@ -1,3 +1,6 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + + def batch(reader, batch_size, drop_last=False): """ This operator creates a batched reader which combines the data from the diff --git a/modelscope/preprocessors/space/data_loader.py b/modelscope/preprocessors/nlp/space/data_loader.py similarity index 87% rename from modelscope/preprocessors/space/data_loader.py rename to modelscope/preprocessors/nlp/space/data_loader.py index bd04a79c..290b64f3 100644 --- a/modelscope/preprocessors/space/data_loader.py +++ b/modelscope/preprocessors/nlp/space/data_loader.py @@ -1,18 +1,16 @@ -""" -DataLoader class -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import math import os import numpy as np -from modelscope.preprocessors.space.args import str2bool -from modelscope.preprocessors.space.batch import batch -from modelscope.preprocessors.space.lazy_dataset import LazyDataset -from modelscope.preprocessors.space.sampler import (RandomSampler, - SequentialSampler, - SortedSampler) +from modelscope.preprocessors.nlp.space.args import str2bool +from modelscope.preprocessors.nlp.space.batch import batch +from modelscope.preprocessors.nlp.space.lazy_dataset import LazyDataset +from modelscope.preprocessors.nlp.space.sampler import (RandomSampler, + SequentialSampler, + SortedSampler) def get_data_loader(batch_size, reader, hparams, file, collate_fn, is_test): diff --git a/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py b/modelscope/preprocessors/nlp/space/dialog_intent_prediction_preprocessor.py similarity index 64% rename from modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py rename to modelscope/preprocessors/nlp/space/dialog_intent_prediction_preprocessor.py index e2602eaa..2923157e 100644 --- a/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py +++ b/modelscope/preprocessors/nlp/space/dialog_intent_prediction_preprocessor.py @@ -8,8 +8,7 @@ import json from modelscope.metainfo import Preprocessors from modelscope.preprocessors.base import Preprocessor from modelscope.preprocessors.builder import PREPROCESSORS -from modelscope.preprocessors.space.fields.intent_field import \ - IntentBPETextField +from modelscope.preprocessors.nlp import IntentBPETextField from modelscope.utils.config import Config from modelscope.utils.constant import Fields, ModelFile from modelscope.utils.type_assert import type_assert @@ -47,10 +46,25 @@ class DialogIntentPredictionPreprocessor(Preprocessor): Args: data (str): a sentence Example: - 'you are so handsome.' + 'What do I need to do for the card activation?' 
Returns: Dict[str, Any]: the preprocessed data + Example: + { + 'src_token': array([[13, 2054, 2079, 1045...]]), + 'src_pos': array([[ 0, 1, 2, 3...]]), + 'src_type': array([[1, 1, 1, 1...]]), + 'src_turn': array([[1, 1, 1, 1...]]), + 'src_mask': array([[1, 1, 1, 1...]]), + 'mlm_token': array([[13, 2054, 2079, 1045...]]), + 'mlm_label': array([[0, 0, 0, 0...]]), + 'mlm_mask': array([[0, 0, 0, 0...]]), + 'tgt_token': array([[29, 30, 31, 32...]]), + 'tgt_mask': array([[1, 1, 1, 1...]]), + 'ids': array([0]), + 'intent_label': array([-1]) + } """ samples = self.text_field.preprocessor([data]) samples, _ = self.text_field.collate_fn_multi_turn(samples) diff --git a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py b/modelscope/preprocessors/nlp/space/dialog_modeling_preprocessor.py similarity index 75% rename from modelscope/preprocessors/space/dialog_modeling_preprocessor.py rename to modelscope/preprocessors/nlp/space/dialog_modeling_preprocessor.py index c461ade1..ae3c214a 100644 --- a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py +++ b/modelscope/preprocessors/nlp/space/dialog_modeling_preprocessor.py @@ -6,8 +6,7 @@ from typing import Any, Dict from modelscope.metainfo import Preprocessors from modelscope.preprocessors.base import Preprocessor from modelscope.preprocessors.builder import PREPROCESSORS -from modelscope.preprocessors.space.fields.gen_field import \ - MultiWOZBPETextField +from modelscope.preprocessors.nlp import MultiWOZBPETextField from modelscope.utils.config import Config from modelscope.utils.constant import Fields, ModelFile from modelscope.utils.type_assert import type_assert @@ -42,9 +41,19 @@ class DialogModelingPreprocessor(Preprocessor): """process the raw input data Args: - data (str): a sentence + data (Dict[str, Any]): A sentence and dialogue history info. Example: - 'you are so handsome.' + { + 'user_input': 'i want to leave after 17:15 .', + 'history': { + 'labels': [[13, 1045, 2052, 2066...]], + 'resp': [14, 1045, 2064, 2393...], + 'bspn': [15, 43, 7688, 10733...], + 'db': [19, 24, 20], + 'aspn': [16, 43, 48, 2681, 7180, 10], + 'output': ['i', 'can', 'help', 'with'...] + } + } Returns: Dict[str, Any]: the preprocessed data diff --git a/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py b/modelscope/preprocessors/nlp/space/dialog_state_tracking_preprocessor.py similarity index 92% rename from modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py rename to modelscope/preprocessors/nlp/space/dialog_state_tracking_preprocessor.py index 6eb17288..cff39577 100644 --- a/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py +++ b/modelscope/preprocessors/nlp/space/dialog_state_tracking_preprocessor.py @@ -31,13 +31,17 @@ class DialogStateTrackingPreprocessor(Preprocessor): self.processor = multiwoz22Processor() @type_assert(object, dict) - def __call__(self, data: Dict) -> Dict[str, Any]: + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: """process the raw input data Args: - data (str): a sentence + data (Dict[str, Any]): a sentence Example: - 'you are so handsome.' 
+ { + 'utter': {'User-1': "Hi, I'm looking for a train that is going" + "to cambridge and arriving there by 20:45, is there anything like that?"}, + 'history_states': [{}] + } Returns: Dict[str, Any]: the preprocessed data diff --git a/modelscope/preprocessors/space/dst_processors.py b/modelscope/preprocessors/nlp/space/dst_processors.py similarity index 100% rename from modelscope/preprocessors/space/dst_processors.py rename to modelscope/preprocessors/nlp/space/dst_processors.py diff --git a/modelscope/preprocessors/nlp/space/fields/__init__.py b/modelscope/preprocessors/nlp/space/fields/__init__.py new file mode 100644 index 00000000..475a99dc --- /dev/null +++ b/modelscope/preprocessors/nlp/space/fields/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .gen_field import MultiWOZBPETextField + from .intent_field import IntentBPETextField +else: + _import_structure = { + 'gen_field': ['MultiWOZBPETextField'], + 'intent_field': ['IntentBPETextField'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/space/fields/gen_field.py b/modelscope/preprocessors/nlp/space/fields/gen_field.py similarity index 99% rename from modelscope/preprocessors/space/fields/gen_field.py rename to modelscope/preprocessors/nlp/space/fields/gen_field.py index 32346bd5..1d1879fe 100644 --- a/modelscope/preprocessors/space/fields/gen_field.py +++ b/modelscope/preprocessors/nlp/space/fields/gen_field.py @@ -9,7 +9,7 @@ from itertools import chain import json import numpy as np -from modelscope.preprocessors.space.tokenizer import Tokenizer +from modelscope.preprocessors.nlp.space.tokenizer import Tokenizer from modelscope.utils.constant import ModelFile from modelscope.utils.logger import get_logger from modelscope.utils.nlp.space import ontology, utils diff --git a/modelscope/preprocessors/space/fields/intent_field.py b/modelscope/preprocessors/nlp/space/fields/intent_field.py similarity index 99% rename from modelscope/preprocessors/space/fields/intent_field.py rename to modelscope/preprocessors/nlp/space/fields/intent_field.py index 6d3b5fff..29ea915e 100644 --- a/modelscope/preprocessors/space/fields/intent_field.py +++ b/modelscope/preprocessors/nlp/space/fields/intent_field.py @@ -13,7 +13,7 @@ import json import numpy as np from tqdm import tqdm -from modelscope.preprocessors.space.tokenizer import Tokenizer +from modelscope.preprocessors.nlp.space.tokenizer import Tokenizer from modelscope.utils.constant import ModelFile from modelscope.utils.nlp.space import ontology from modelscope.utils.nlp.space.scores import hierarchical_set_score diff --git a/modelscope/preprocessors/space/lazy_dataset.py b/modelscope/preprocessors/nlp/space/lazy_dataset.py similarity index 93% rename from modelscope/preprocessors/space/lazy_dataset.py rename to modelscope/preprocessors/nlp/space/lazy_dataset.py index 8da21db7..536d9341 100644 --- a/modelscope/preprocessors/space/lazy_dataset.py +++ b/modelscope/preprocessors/nlp/space/lazy_dataset.py @@ -1,11 +1,6 @@ -""" -Dataset class -""" - +# Copyright (c) Alibaba, Inc. and its affiliates. 
import json -from modelscope.preprocessors.space.args import str2bool - class LazyDataset(object): """ diff --git a/modelscope/preprocessors/space/preprocess.py b/modelscope/preprocessors/nlp/space/preprocess.py similarity index 92% rename from modelscope/preprocessors/space/preprocess.py rename to modelscope/preprocessors/nlp/space/preprocess.py index bd8d64d1..8aab4711 100644 --- a/modelscope/preprocessors/space/preprocess.py +++ b/modelscope/preprocessors/nlp/space/preprocess.py @@ -1,12 +1,9 @@ -""" -Preprocess script. -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import glob import os -from modelscope.preprocessors.space.args import parse_args -from modelscope.preprocessors.space.fields.intent_field import \ +from modelscope.preprocessors.nlp.space.fields.intent_field import \ IntentBPETextField FILE_NAME = 'train.json' diff --git a/modelscope/preprocessors/space/sampler.py b/modelscope/preprocessors/nlp/space/sampler.py similarity index 96% rename from modelscope/preprocessors/space/sampler.py rename to modelscope/preprocessors/nlp/space/sampler.py index 49a216d1..e549c343 100644 --- a/modelscope/preprocessors/space/sampler.py +++ b/modelscope/preprocessors/nlp/space/sampler.py @@ -1,6 +1,4 @@ -""" -Sampler class. -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import numpy as np diff --git a/modelscope/preprocessors/space/tensorlistdataset.py b/modelscope/preprocessors/nlp/space/tensorlistdataset.py similarity index 100% rename from modelscope/preprocessors/space/tensorlistdataset.py rename to modelscope/preprocessors/nlp/space/tensorlistdataset.py diff --git a/modelscope/preprocessors/space/tokenizer.py b/modelscope/preprocessors/nlp/space/tokenizer.py similarity index 99% rename from modelscope/preprocessors/space/tokenizer.py rename to modelscope/preprocessors/nlp/space/tokenizer.py index 87f7e8c3..1bd0ce11 100644 --- a/modelscope/preprocessors/space/tokenizer.py +++ b/modelscope/preprocessors/nlp/space/tokenizer.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from __future__ import (absolute_import, division, print_function, unicode_literals) import collections diff --git a/modelscope/preprocessors/star3/__init__.py b/modelscope/preprocessors/nlp/space_T_cn/__init__.py similarity index 100% rename from modelscope/preprocessors/star3/__init__.py rename to modelscope/preprocessors/nlp/space_T_cn/__init__.py diff --git a/modelscope/preprocessors/nlp/space_T_cn/fields/__init__.py b/modelscope/preprocessors/nlp/space_T_cn/fields/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/preprocessors/star3/fields/database.py b/modelscope/preprocessors/nlp/space_T_cn/fields/database.py similarity index 55% rename from modelscope/preprocessors/star3/fields/database.py rename to modelscope/preprocessors/nlp/space_T_cn/fields/database.py index a99800cf..2fef8d7e 100644 --- a/modelscope/preprocessors/star3/fields/database.py +++ b/modelscope/preprocessors/nlp/space_T_cn/fields/database.py @@ -1,24 +1,47 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+import sqlite3 + import json import tqdm -from modelscope.preprocessors.star3.fields.struct import Trie +from .struct import Trie class Database: - def __init__(self, tokenizer, table_file_path, syn_dict_file_path): + def __init__(self, + tokenizer, + table_file_path, + syn_dict_file_path, + is_use_sqlite=True): self.tokenizer = tokenizer + self.is_use_sqlite = is_use_sqlite + if self.is_use_sqlite: + self.connection_obj = sqlite3.connect( + ':memory:', check_same_thread=False) + self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} self.tables = self.init_tables(table_file_path=table_file_path) self.syn_dict = self.init_syn_dict( syn_dict_file_path=syn_dict_file_path) + def __del__(self): + if self.is_use_sqlite: + self.connection_obj.close() + def init_tables(self, table_file_path): tables = {} lines = [] - with open(table_file_path, 'r') as fo: - for line in fo: - lines.append(line) + if type(table_file_path) == str: + with open(table_file_path, 'r') as fo: + for line in fo: + lines.append(line) + elif type(table_file_path) == list: + for path in table_file_path: + with open(path, 'r') as fo: + for line in fo: + lines.append(line) + else: + raise ValueError() for line in tqdm.tqdm(lines, desc='Load Tables'): table = json.loads(line.strip()) @@ -34,6 +57,9 @@ class Database: headers_tokens.append(empty_column) table['tablelen'] = table_header_length table['header_tok'] = headers_tokens + table['headerid2name'] = {} + for hid, hname in zip(table['header_id'], table['header_name']): + table['headerid2name'][hid] = hname table['header_types'].append('null') table['header_units'] = [ @@ -51,6 +77,26 @@ class Database: trie_set[ii].insert(word, word) table['value_trie'] = trie_set + + # create sqlite + if self.is_use_sqlite: + cursor_obj = self.connection_obj.cursor() + cursor_obj.execute('DROP TABLE IF EXISTS %s' % + (table['table_id'])) + header_string = ', '.join([ + '%s %s' % + (name, self.type_dict[htype]) for name, htype in zip( + table['header_id'], table['header_types']) + ]) + create_table_string = 'CREATE TABLE %s (%s);' % ( + table['table_id'], header_string) + cursor_obj.execute(create_table_string) + for row in table['rows']: + value_string = ', '.join(['"%s"' % (val) for val in row]) + insert_row_string = 'INSERT INTO %s VALUES(%s)' % ( + table['table_id'], value_string) + cursor_obj.execute(insert_row_string) + tables[table['table_id']] = table return tables diff --git a/modelscope/preprocessors/star3/fields/schema_link.py b/modelscope/preprocessors/nlp/space_T_cn/fields/schema_link.py similarity index 94% rename from modelscope/preprocessors/star3/fields/schema_link.py rename to modelscope/preprocessors/nlp/space_T_cn/fields/schema_link.py index 40613f78..b62d03e4 100644 --- a/modelscope/preprocessors/star3/fields/schema_link.py +++ b/modelscope/preprocessors/nlp/space_T_cn/fields/schema_link.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
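To make the new in-memory SQLite path of `Database.init_tables` easier to follow, here is a self-contained sketch of the same DDL/DML construction. The `fund_table` record is a made-up miniature in the shape the table loader produces (`table_id`, `header_id`, `header_types`, `rows`); the extra bookkeeping columns and error handling are omitted.

import sqlite3

type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'}
table = {
    'table_id': 'fund_table',                       # hypothetical table
    'header_id': ['fund_name', 'fund_type', 'nav'],
    'header_types': ['text', 'text', 'number'],
    'rows': [['华夏成长', '混合型', '1.08'], ['易方达消费', '股票型', '2.35']],
}

conn = sqlite3.connect(':memory:', check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DROP TABLE IF EXISTS %s' % table['table_id'])
header_string = ', '.join(
    '%s %s' % (name, type_dict[htype])
    for name, htype in zip(table['header_id'], table['header_types']))
cursor.execute('CREATE TABLE %s (%s);' % (table['table_id'], header_string))
for row in table['rows']:
    value_string = ', '.join('"%s"' % val for val in row)
    cursor.execute('INSERT INTO %s VALUES(%s)' % (table['table_id'], value_string))

# the per-table data is now queryable by the downstream SQL executor
print(cursor.execute('SELECT fund_name FROM fund_table WHERE nav > 2').fetchall())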
import re -from modelscope.preprocessors.star3.fields.struct import TypeInfo +from .struct import TypeInfo class SchemaLinker: @@ -287,13 +287,23 @@ class SchemaLinker: return match_len / (len(nlu_t) + 0.1) - def get_entity_linking(self, tokenizer, nlu, nlu_t, tables, col_syn_dict): + def get_entity_linking(self, + tokenizer, + nlu, + nlu_t, + tables, + col_syn_dict, + table_id=None, + history_sql=None): """ get linking between question and schema column """ typeinfos = [] numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu) + if table_id is not None and table_id in tables: + tables = {table_id: tables[table_id]} + # search schema link in every table search_result_list = [] for tablename in tables: @@ -305,8 +315,7 @@ class SchemaLinker: typeinfos = [] for ii, column in enumerate(table['header_name']): column = column.lower() - column_new = re.sub('(.*?)', '', column) - column_new = re.sub('(.*?)', '', column_new) + column_new = column cphrase, cscore = self.get_match_phrase( nlu.lower(), column_new) if cscore > 0.3 and cphrase.strip() != '': @@ -330,7 +339,6 @@ class SchemaLinker: for cell in ans.keys(): vphrase = cell vscore = 1.0 - # print("trie_set find:", cell, ans[cell]) phrase_tok = tokenizer.tokenize(vphrase) if len(phrase_tok) == 0 or len(vphrase) < 2: continue @@ -407,17 +415,25 @@ class SchemaLinker: # get the match score of each table match_score = self.get_table_match_score(nlu_t, schema_link) + # cal table_score + if history_sql is not None and 'from' in history_sql: + table_score = int(table['table_id'] == history_sql['from'][0]) + else: + table_score = 0 + search_result = { 'table_id': table['table_id'], 'question_knowledge': final_question, 'header_knowledge': final_header, 'schema_link': schema_link, - 'match_score': match_score + 'match_score': match_score, + 'table_score': table_score } search_result_list.append(search_result) search_result_list = sorted( - search_result_list, key=lambda x: x['match_score'], - reverse=True)[0:4] + search_result_list, + key=lambda x: (x['match_score'], x['table_score']), + reverse=True)[0:1] return search_result_list diff --git a/modelscope/preprocessors/star3/fields/struct.py b/modelscope/preprocessors/nlp/space_T_cn/fields/struct.py similarity index 90% rename from modelscope/preprocessors/star3/fields/struct.py rename to modelscope/preprocessors/nlp/space_T_cn/fields/struct.py index 3c2e664b..917e1aaa 100644 --- a/modelscope/preprocessors/star3/fields/struct.py +++ b/modelscope/preprocessors/nlp/space_T_cn/fields/struct.py @@ -179,3 +179,25 @@ class Constant: self.max_select_num = 4 self.max_where_num = 6 + + self.limit_dict = { + '最': 1, + '1': 1, + '一': 1, + '2': 2, + '二': 2, + '3': 3, + '三': 3, + '4': 4, + '四': 4, + '5': 5, + '五': 5, + '6': 6, + '六': 6, + '7': 7, + '七': 7, + '8': 8, + '八': 8, + '9': 9, + '九': 9 + } diff --git a/modelscope/preprocessors/star3/table_question_answering_preprocessor.py b/modelscope/preprocessors/nlp/space_T_cn/table_question_answering_preprocessor.py similarity index 90% rename from modelscope/preprocessors/star3/table_question_answering_preprocessor.py rename to modelscope/preprocessors/nlp/space_T_cn/table_question_answering_preprocessor.py index 163759a1..3aabc6a9 100644 --- a/modelscope/preprocessors/star3/table_question_answering_preprocessor.py +++ b/modelscope/preprocessors/nlp/space_T_cn/table_question_answering_preprocessor.py @@ -8,8 +8,9 @@ from transformers import BertTokenizer from modelscope.metainfo import Preprocessors from modelscope.preprocessors.base import Preprocessor from 
modelscope.preprocessors.builder import PREPROCESSORS -from modelscope.preprocessors.star3.fields.database import Database -from modelscope.preprocessors.star3.fields.schema_link import SchemaLinker +from modelscope.preprocessors.nlp.space_T_cn.fields.database import Database +from modelscope.preprocessors.nlp.space_T_cn.fields.schema_link import \ + SchemaLinker from modelscope.utils.config import Config from modelscope.utils.constant import Fields, ModelFile from modelscope.utils.type_assert import type_assert @@ -95,7 +96,8 @@ class TableQuestionAnsweringPreprocessor(Preprocessor): # tokenize question question = data['question'] - history_sql = data['history_sql'] + table_id = data.get('table_id', None) + history_sql = data.get('history_sql', None) nlu = question.lower() nlu_t = self.tokenizer.tokenize(nlu) @@ -105,7 +107,9 @@ class TableQuestionAnsweringPreprocessor(Preprocessor): nlu=nlu, nlu_t=nlu_t, tables=self.db.tables, - col_syn_dict=self.db.syn_dict) + col_syn_dict=self.db.syn_dict, + table_id=table_id, + history_sql=history_sql) # collect data datas = self.construct_data( diff --git a/modelscope/preprocessors/star/__init__.py b/modelscope/preprocessors/nlp/space_T_en/__init__.py similarity index 100% rename from modelscope/preprocessors/star/__init__.py rename to modelscope/preprocessors/nlp/space_T_en/__init__.py diff --git a/modelscope/preprocessors/star/conversational_text_to_sql_preprocessor.py b/modelscope/preprocessors/nlp/space_T_en/conversational_text_to_sql_preprocessor.py similarity index 84% rename from modelscope/preprocessors/star/conversational_text_to_sql_preprocessor.py rename to modelscope/preprocessors/nlp/space_T_en/conversational_text_to_sql_preprocessor.py index b5dd73a9..00c7bcd7 100644 --- a/modelscope/preprocessors/star/conversational_text_to_sql_preprocessor.py +++ b/modelscope/preprocessors/nlp/space_T_en/conversational_text_to_sql_preprocessor.py @@ -12,9 +12,10 @@ from text2sql_lgesql.utils.example import Example from modelscope.metainfo import Preprocessors from modelscope.preprocessors.base import Preprocessor from modelscope.preprocessors.builder import PREPROCESSORS -from modelscope.preprocessors.star.fields.preprocess_dataset import \ +from modelscope.preprocessors.nlp.space_T_en.fields import SubPreprocessor +from modelscope.preprocessors.nlp.space_T_en.fields.preprocess_dataset import \ preprocess_dataset -from modelscope.preprocessors.star.fields.process_dataset import ( +from modelscope.preprocessors.nlp.space_T_en.fields.process_dataset import ( process_dataset, process_tables) from modelscope.utils.config import Config from modelscope.utils.constant import Fields, ModelFile @@ -56,6 +57,18 @@ class ConversationalTextToSqlPreprocessor(Preprocessor): model_dir=self.model_dir, db_dir=os.path.join(model_dir, 'db')) + self.device = 'cuda' if \ + ('device' not in kwargs or kwargs['device'] == 'gpu') \ + and torch.cuda.is_available() else 'cpu' + use_device = True if self.device == 'cuda' else False + self.processor = \ + SubPreprocessor(model_dir=model_dir, + db_content=True, + use_gpu=use_device) + self.output_tables = \ + process_tables(self.processor, + self.tables) + @type_assert(object, dict) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: """process the raw input data diff --git a/modelscope/preprocessors/star/fields/__init__.py b/modelscope/preprocessors/nlp/space_T_en/fields/__init__.py similarity index 100% rename from modelscope/preprocessors/star/fields/__init__.py rename to 
modelscope/preprocessors/nlp/space_T_en/fields/__init__.py diff --git a/modelscope/preprocessors/star/fields/common_utils.py b/modelscope/preprocessors/nlp/space_T_en/fields/common_utils.py similarity index 100% rename from modelscope/preprocessors/star/fields/common_utils.py rename to modelscope/preprocessors/nlp/space_T_en/fields/common_utils.py diff --git a/modelscope/preprocessors/star/fields/parse.py b/modelscope/preprocessors/nlp/space_T_en/fields/parse.py similarity index 100% rename from modelscope/preprocessors/star/fields/parse.py rename to modelscope/preprocessors/nlp/space_T_en/fields/parse.py diff --git a/modelscope/preprocessors/star/fields/preprocess_dataset.py b/modelscope/preprocessors/nlp/space_T_en/fields/preprocess_dataset.py similarity index 95% rename from modelscope/preprocessors/star/fields/preprocess_dataset.py rename to modelscope/preprocessors/nlp/space_T_en/fields/preprocess_dataset.py index 6c84c0e7..a0fd13d1 100644 --- a/modelscope/preprocessors/star/fields/preprocess_dataset.py +++ b/modelscope/preprocessors/nlp/space_T_en/fields/preprocess_dataset.py @@ -3,7 +3,7 @@ from text2sql_lgesql.preprocess.parse_raw_json import Schema, get_schemas from text2sql_lgesql.process_sql import get_sql -from modelscope.preprocessors.star.fields.parse import get_label +from .parse import get_label def preprocess_dataset(processor, dataset, output_tables, database_id, tables): diff --git a/modelscope/preprocessors/star/fields/process_dataset.py b/modelscope/preprocessors/nlp/space_T_en/fields/process_dataset.py similarity index 94% rename from modelscope/preprocessors/star/fields/process_dataset.py rename to modelscope/preprocessors/nlp/space_T_en/fields/process_dataset.py index d8ac094a..88059351 100644 --- a/modelscope/preprocessors/star/fields/process_dataset.py +++ b/modelscope/preprocessors/nlp/space_T_en/fields/process_dataset.py @@ -1,17 +1,12 @@ # Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql. -import argparse import os import pickle import sys -import time -import json from text2sql_lgesql.asdl.asdl import ASDLGrammar from text2sql_lgesql.asdl.transition_system import TransitionSystem -from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor - sys.path.append(os.path.dirname(os.path.dirname(__file__))) diff --git a/modelscope/preprocessors/nlp/text2text_generation_preprocessor.py b/modelscope/preprocessors/nlp/text2text_generation_preprocessor.py new file mode 100644 index 00000000..5693d36e --- /dev/null +++ b/modelscope/preprocessors/nlp/text2text_generation_preprocessor.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.text2text_gen_preprocessor) +class Text2TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in text generation. 
+ """ + + def __init__(self, + model_dir: str, + tokenizer=None, + mode=ModeKeys.INFERENCE, + **kwargs): + kwargs['truncation'] = kwargs.get('truncation', 'do_not_truncate') + kwargs['padding'] = kwargs.get('padding', False) + kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', + False) + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + super().__init__(model_dir, mode=mode, **kwargs) + + def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]: + text_a, _, _ = self.parse_text_and_label(data) + + inputs = self.tokenizer( + text_a, + return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, + **self.tokenize_kwargs) + + # This is produced by tokenizers but is an invalid generate kwargs + if 'token_type_ids' in inputs: + del inputs['token_type_ids'] + return inputs diff --git a/modelscope/preprocessors/nlp/text_error_correction.py b/modelscope/preprocessors/nlp/text_error_correction.py index 357a946f..4e5ba3bd 100644 --- a/modelscope/preprocessors/nlp/text_error_correction.py +++ b/modelscope/preprocessors/nlp/text_error_correction.py @@ -7,11 +7,12 @@ from modelscope.metainfo import Preprocessors from modelscope.preprocessors.base import Preprocessor from modelscope.preprocessors.builder import PREPROCESSORS from modelscope.utils.constant import Fields +from .nlp_base import NLPBasePreprocessor @PREPROCESSORS.register_module( Fields.nlp, module_name=Preprocessors.text_error_correction) -class TextErrorCorrectionPreprocessor(Preprocessor): +class TextErrorCorrectionPreprocessor(NLPBasePreprocessor): """The preprocessor used in text correction task. """ @@ -22,7 +23,7 @@ class TextErrorCorrectionPreprocessor(Preprocessor): Args: model_dir (str): model path """ - super().__init__(*args, **kwargs) + super().__init__(model_dir, *args, **kwargs) self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt')) def __call__(self, data: str) -> Dict[str, Any]: diff --git a/modelscope/preprocessors/nlp/text_generation_jieba_preprocessor.py b/modelscope/preprocessors/nlp/text_generation_jieba_preprocessor.py new file mode 100644 index 00000000..1e972d64 --- /dev/null +++ b/modelscope/preprocessors/nlp/text_generation_jieba_preprocessor.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.text_gen_jieba_tokenizer) +class TextGenerationJiebaPreprocessor(Preprocessor): + """The jieba tokenizer preprocessor used in text generation. 
+ """ + + def __init__(self, model_dir: str, *args, **kwargs): + from modelscope.models.nlp.gpt3 import JiebaBPETokenizer + super().__init__(*args, **kwargs) + self.tokenizer = JiebaBPETokenizer( + osp.join(model_dir, 'tokenizer.json')) + + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + '深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地' + Returns: + Dict[str, Any]: the preprocessed data + Example: + {'net_input': + {'src_tokens':tensor([1,2,3,4]), + 'src_lengths': tensor([4])} + } + """ + import torch + + return { + 'input_ids': + torch.tensor(self.tokenizer.tokenize(data)).unsqueeze_(0) + } diff --git a/modelscope/preprocessors/nlp/text_generation_preprocessor.py b/modelscope/preprocessors/nlp/text_generation_preprocessor.py new file mode 100644 index 00000000..238e2972 --- /dev/null +++ b/modelscope/preprocessors/nlp/text_generation_preprocessor.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.text_gen_tokenizer) +class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in text generation. + """ + + def __init__(self, + model_dir: str, + tokenizer=None, + mode=ModeKeys.INFERENCE, + **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get('padding', 'max_length') + kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', + False) + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + super().__init__(model_dir, mode=mode, **kwargs) + + @staticmethod + def get_roberta_tokenizer_dir(model_dir: str) -> Optional[str]: + import os + for name in os.listdir(model_dir): + full_name = os.path.join(model_dir, name) + if 'roberta' in name and os.path.isdir(full_name): + return full_name + + def build_tokenizer(self, model_dir: str): + roberta_tokenizer_dir = self.get_roberta_tokenizer_dir(model_dir) + if roberta_tokenizer_dir: + from transformers import RobertaTokenizer + return RobertaTokenizer.from_pretrained( + roberta_tokenizer_dir, do_lower_case=False) + return super().build_tokenizer(model_dir) + + def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]: + if self._mode == ModeKeys.INFERENCE: + return super().__call__(data) + src_rst = super().__call__(data['src_txt']) + src_input_ids = src_rst['input_ids'] + src_attention_mask = src_rst['attention_mask'] + if 'tgt_txt' in data: + labels = super().__call__(data['tgt_txt'])['input_ids'] + else: + labels = src_input_ids[1:] + src_input_ids = src_input_ids[:-1] + src_attention_mask = src_attention_mask[:-1] + + return { + 'input_ids': src_input_ids, + 'attention_mask': src_attention_mask, + 'labels': labels, + } diff --git a/modelscope/preprocessors/nlp/text_ranking_preprocessor.py b/modelscope/preprocessors/nlp/text_ranking_preprocessor.py new file mode 100644 index 00000000..2ada6892 --- /dev/null +++ b/modelscope/preprocessors/nlp/text_ranking_preprocessor.py @@ -0,0 +1,67 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +from typing import Any, Dict, Union + +from transformers import AutoTokenizer + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.text_ranking) +class TextRankingPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in passage ranking model. + """ + + def __init__(self, + model_dir: str, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + super().__init__(model_dir, mode=mode, *args, **kwargs) + self.model_dir: str = model_dir + self.first_sequence: str = kwargs.pop('first_sequence', + 'source_sentence') + self.second_sequence = kwargs.pop('second_sequence', + 'sentences_to_compare') + self.sequence_length = kwargs.pop('sequence_length', 128) + + self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) + + @type_assert(object, (str, tuple, Dict)) + def __call__(self, data: Union[tuple, Dict]) -> Dict[str, Any]: + if isinstance(data, tuple): + sentence1, sentence2 = data + elif isinstance(data, dict): + sentence1 = data.get(self.first_sequence) + sentence2 = data.get(self.second_sequence) + if isinstance(sentence2, str): + sentence2 = [sentence2] + if isinstance(sentence1, str): + sentence1 = [sentence1] + sentence1 = sentence1 * len(sentence2) + + max_seq_length = self.sequence_length + feature = self.tokenizer( + sentence1, + sentence2, + padding='max_length', + truncation=True, + max_length=max_seq_length, + return_tensors='pt') + if 'labels' in data: + labels = data['labels'] + feature['labels'] = labels + if 'qid' in data: + qid = data['qid'] + feature['qid'] = qid + return feature diff --git a/modelscope/preprocessors/nlp/token_classification_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_preprocessor.py new file mode 100644 index 00000000..2de0c806 --- /dev/null +++ b/modelscope/preprocessors/nlp/token_classification_preprocessor.py @@ -0,0 +1,261 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Tuple, Union + +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .nlp_base import NLPBasePreprocessor, NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, + module_name=Preprocessors.word_segment_text_to_label_preprocessor) +class WordSegmentationBlankSetToLabelPreprocessor(NLPBasePreprocessor): + """The preprocessor used to turn a single sentence to a labeled token-classification dict. 
+ """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.first_sequence: str = kwargs.pop('first_sequence', + 'first_sequence') + self.label = kwargs.pop('label', OutputKeys.LABELS) + + def __call__(self, data: str) -> Union[Dict[str, Any], Tuple]: + data = data.split(' ') + data = list(filter(lambda x: len(x) > 0, data)) + + def produce_train_sample(words): + chars = [] + labels = [] + for word in words: + chars.extend(list(word)) + if len(word) == 1: + labels.append('S-CWS') + else: + labels.extend(['B-CWS'] + ['I-CWS'] * (len(word) - 2) + + ['E-CWS']) + assert len(chars) == len(labels) + return chars, labels + + chars, labels = produce_train_sample(data) + return { + self.first_sequence: chars, + self.label: labels, + } + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.ner_tokenizer) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.token_cls_tokenizer) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer) +class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in normal NER task. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get( + 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + self.sequence_length = kwargs['max_length'] + self.label_all_tokens = kwargs.pop('label_all_tokens', False) + super().__init__(model_dir, mode=mode, **kwargs) + + if 'is_split_into_words' in kwargs: + self.is_split_into_words = kwargs.pop('is_split_into_words') + else: + self.is_split_into_words = self.tokenizer.init_kwargs.get( + 'is_split_into_words', False) + if 'label2id' in kwargs: + kwargs.pop('label2id') + self.tokenize_kwargs = kwargs + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'you are so handsome.' 
+ + Returns: + Dict[str, Any]: the preprocessed data + """ + + # preprocess the data for the model input + text = None + labels_list = None + if isinstance(data, str): + text = data + elif isinstance(data, dict): + text = data.get(self.first_sequence) + labels_list = data.get(self.label) + + input_ids = [] + label_mask = [] + offset_mapping = [] + if self.is_split_into_words: + for offset, token in enumerate(list(data)): + subtoken_ids = self.tokenizer.encode( + token, add_special_tokens=False) + if len(subtoken_ids) == 0: + subtoken_ids = [self.tokenizer.unk_token_id] + input_ids.extend(subtoken_ids) + label_mask.extend([1] + [0] * (len(subtoken_ids) - 1)) + offset_mapping.extend([(offset, offset + 1)]) + else: + if self.tokenizer.is_fast: + encodings = self.tokenizer( + text, + add_special_tokens=False, + return_offsets_mapping=True, + **self.tokenize_kwargs) + input_ids = encodings['input_ids'] + word_ids = encodings.word_ids() + for i in range(len(word_ids)): + if word_ids[i] is None: + label_mask.append(0) + elif word_ids[i] == word_ids[i - 1]: + label_mask.append(0) + offset_mapping[-1] = ( + offset_mapping[-1][0], + encodings['offset_mapping'][i][1]) + else: + label_mask.append(1) + offset_mapping.append(encodings['offset_mapping'][i]) + else: + encodings = self.tokenizer( + text, add_special_tokens=False, **self.tokenize_kwargs) + input_ids = encodings['input_ids'] + label_mask, offset_mapping = self.get_label_mask_and_offset_mapping( + text) + + if len(input_ids) >= self.sequence_length - 2: + input_ids = input_ids[:self.sequence_length - 2] + label_mask = label_mask[:self.sequence_length - 2] + input_ids = [self.tokenizer.cls_token_id + ] + input_ids + [self.tokenizer.sep_token_id] + label_mask = [0] + label_mask + [0] + attention_mask = [1] * len(input_ids) + offset_mapping = offset_mapping[:sum(label_mask)] + + if not self.is_transformer_based_model: + input_ids = input_ids[1:-1] + attention_mask = attention_mask[1:-1] + label_mask = label_mask[1:-1] + + if self._mode == ModeKeys.INFERENCE: + input_ids = torch.tensor(input_ids).unsqueeze(0) + attention_mask = torch.tensor(attention_mask).unsqueeze(0) + label_mask = torch.tensor( + label_mask, dtype=torch.bool).unsqueeze(0) + + # the token classification + output = { + 'text': text, + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'label_mask': label_mask, + 'offset_mapping': offset_mapping + } + + # align the labels with tokenized text + if labels_list is not None: + assert self.label2id is not None + # Map that sends B-Xxx label to its I-Xxx counterpart + b_to_i_label = [] + label_enumerate_values = [ + k for k, v in sorted( + self.label2id.items(), key=lambda item: item[1]) + ] + for idx, label in enumerate(label_enumerate_values): + if label.startswith('B-') and label.replace( + 'B-', 'I-') in label_enumerate_values: + b_to_i_label.append( + label_enumerate_values.index( + label.replace('B-', 'I-'))) + else: + b_to_i_label.append(idx) + + label_row = [self.label2id[lb] for lb in labels_list] + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + if word_idx is None: + label_ids.append(-100) + elif word_idx != previous_word_idx: + label_ids.append(label_row[word_idx]) + else: + if self.label_all_tokens: + label_ids.append(b_to_i_label[label_row[word_idx]]) + else: + label_ids.append(-100) + previous_word_idx = word_idx + labels = label_ids + output['labels'] = labels + return output + + def get_tokenizer_class(self): + tokenizer_class = self.tokenizer.__class__.__name__ + if 
tokenizer_class.endswith( + 'Fast') and tokenizer_class != 'PreTrainedTokenizerFast': + tokenizer_class = tokenizer_class[:-4] + return tokenizer_class + + def get_label_mask_and_offset_mapping(self, text): + label_mask = [] + offset_mapping = [] + tokens = self.tokenizer.tokenize(text) + offset = 0 + if self.get_tokenizer_class() == 'BertTokenizer': + for token in tokens: + is_start = (token[:2] != '##') + if is_start: + label_mask.append(True) + else: + token = token[2:] + label_mask.append(False) + start = offset + text[offset:].index(token) + end = start + len(token) + if is_start: + offset_mapping.append((start, end)) + else: + offset_mapping[-1] = (offset_mapping[-1][0], end) + offset = end + elif self.get_tokenizer_class() == 'XLMRobertaTokenizer': + last_is_blank = False + for token in tokens: + is_start = (token[0] == '▁') + if is_start: + token = token[1:] + label_mask.append(True) + if len(token) == 0: + last_is_blank = True + continue + else: + label_mask.append(False) + start = offset + text[offset:].index(token) + end = start + len(token) + if last_is_blank or is_start: + offset_mapping.append((start, end)) + else: + offset_mapping[-1] = (offset_mapping[-1][0], end) + offset = end + last_is_blank = False + else: + raise NotImplementedError + + return label_mask, offset_mapping diff --git a/modelscope/preprocessors/nlp/token_classification_thai_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_thai_preprocessor.py new file mode 100644 index 00000000..a356cea7 --- /dev/null +++ b/modelscope/preprocessors/nlp/token_classification_thai_preprocessor.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Tuple, Union + +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .token_classification_preprocessor import TokenClassificationPreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.thai_ner_tokenizer) +class NERPreprocessorThai(TokenClassificationPreprocessor): + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + from pythainlp import word_tokenize + + segmented_data = ' '.join([ + w.strip(' ') for w in word_tokenize(text=data, engine='newmm') + if w.strip(' ') != '' + ]) + output = super().__call__(segmented_data) + + return output + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.thai_wseg_tokenizer) +class WordSegmentationPreprocessorThai(TokenClassificationPreprocessor): + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + import regex + data = regex.findall(r'\X', data) + data = ' '.join([char for char in data]) + + output = super().__call__(data) + + return output diff --git a/modelscope/preprocessors/nlp/token_classification_viet_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_viet_preprocessor.py new file mode 100644 index 00000000..f8970d1a --- /dev/null +++ b/modelscope/preprocessors/nlp/token_classification_viet_preprocessor.py @@ -0,0 +1,33 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
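The slow-tokenizer branch of `get_label_mask_and_offset_mapping` above is easiest to see on a concrete WordPiece-style split. The standalone sketch below mirrors only the BertTokenizer case; the token list is hand-written rather than produced by a real tokenizer.

def label_mask_and_offsets(text, tokens):
    # '##' marks a subword continuation; its span is merged into the offsets
    # of the word it continues, and only word starts get mask value True
    label_mask, offset_mapping, offset = [], [], 0
    for token in tokens:
        is_start = (token[:2] != '##')
        if not is_start:
            token = token[2:]
        label_mask.append(is_start)
        start = offset + text[offset:].index(token)
        end = start + len(token)
        if is_start:
            offset_mapping.append((start, end))
        else:
            offset_mapping[-1] = (offset_mapping[-1][0], end)
        offset = end
    return label_mask, offset_mapping


text = 'unaffordable housing'
tokens = ['una', '##fford', '##able', 'housing']  # hand-written WordPiece-style split
print(label_mask_and_offsets(text, tokens))
# ([True, False, False, True], [(0, 12), (13, 20)])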
+ +from typing import Any, Dict, Tuple, Union + +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .token_classification_preprocessor import TokenClassificationPreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.viet_ner_tokenizer) +class NERPreprocessorViet(TokenClassificationPreprocessor): + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + from pyvi import ViTokenizer + + seg_words = [ + t.strip(' ') for t in ViTokenizer.tokenize(data).split(' ') + if t.strip(' ') != '' + ] + raw_words = [] + for w in seg_words: + raw_words.extend(w.split('_')) + segmented_data = ' '.join(raw_words) + output = super().__call__(segmented_data) + + return output diff --git a/modelscope/preprocessors/nlp/zero_shot_classification_reprocessor.py b/modelscope/preprocessors/nlp/zero_shot_classification_reprocessor.py new file mode 100644 index 00000000..eb3c4b37 --- /dev/null +++ b/modelscope/preprocessors/nlp/zero_shot_classification_reprocessor.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer) +class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in zero shot classification. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + self.sequence_length = kwargs.pop('sequence_length', 512) + super().__init__(model_dir, mode=mode, **kwargs) + + def __call__(self, data: Union[str, Dict], hypothesis_template: str, + candidate_labels: list) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str or dict): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + if isinstance(data, dict): + data = data.get(self.first_sequence) + + pairs = [[data, hypothesis_template.format(label)] + for label in candidate_labels] + + features = self.tokenizer( + pairs, + padding=True, + truncation=True, + max_length=self.sequence_length, + truncation_strategy='only_first', + return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None) + return features diff --git a/modelscope/preprocessors/ofa/__init__.py b/modelscope/preprocessors/ofa/__init__.py index 95d72fe1..59b94b2b 100644 --- a/modelscope/preprocessors/ofa/__init__.py +++ b/modelscope/preprocessors/ofa/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
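As a reference point for the zero-shot classification preprocessor above, the sketch below builds the same (premise, hypothesis) pairs with a plain Hugging Face tokenizer. The checkpoint name, template, and labels are placeholders; the real preprocessor uses the tokenizer shipped in the model directory and passes the legacy `truncation_strategy` kwarg rather than `truncation`.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')  # placeholder checkpoint

premise = '全新突破,解放军运-20加油机曝光'
hypothesis_template = '这篇文章的主题是{}'
candidate_labels = ['军事', '体育', '娱乐']

# one (premise, hypothesis) pair per candidate label, tokenized as one batch,
# which is how the preprocessor feeds the NLI-style classifier
pairs = [(premise, hypothesis_template.format(label)) for label in candidate_labels]
features = tokenizer(
    pairs,
    padding=True,
    truncation='only_first',
    max_length=512,
    return_tensors='pt')
print(features['input_ids'].shape)  # (len(candidate_labels), padded_seq_len)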
from .image_captioning import OfaImageCaptioningPreprocessor from .image_classification import OfaImageClassificationPreprocessor +from .ocr_recognition import OfaOcrRecognitionPreprocessor from .summarization import OfaSummarizationPreprocessor from .text_classification import OfaTextClassificationPreprocessor from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor diff --git a/modelscope/preprocessors/ofa/base.py b/modelscope/preprocessors/ofa/base.py index 691f8b36..55b3895d 100644 --- a/modelscope/preprocessors/ofa/base.py +++ b/modelscope/preprocessors/ofa/base.py @@ -1,26 +1,31 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import re +import string from os import path as osp import json import numpy as np import torch +from PIL import Image from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH +from modelscope.preprocessors.image import load_image from modelscope.utils.trie import Trie +from .utils.constant import OFA_TASK_KEY_MAPPING from .utils.random_help import set_torch_seed class OfaBasePreprocessor: - def __init__(self, cfg, model_dir): - """preprocess the data + def __init__(self, cfg, model_dir, mode, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path Args: cfg(modelscope.utils.config.ConfigDict) : model config model_dir (str): model path """ self.cfg = cfg + self.mode = mode self.language = self.cfg.model.get('language', 'en') if self.language == 'en': tokenizer = OFATokenizer.from_pretrained(model_dir) @@ -41,6 +46,7 @@ class OfaBasePreprocessor: for key, value in tokenizer.get_vocab().items() } self.max_src_length = cfg.model.get('max_src_length', 256) + self.max_tgt_length = cfg.model.get('max_tgt_length', 256) self.max_image_size = cfg.model.get('max_image_size', 512) self.language = self.cfg.model.get('language', 'en') self.prompt_type = self.cfg.model.get('prompt_type', 'none') @@ -56,26 +62,40 @@ class OfaBasePreprocessor: self.mean = [0.5, 0.5, 0.5] self.std = [0.5, 0.5, 0.5] self.patch_image_size = self.cfg.model.get('patch_image_size', 480) + self.column_map = { + key: key + for key in OFA_TASK_KEY_MAPPING[self.cfg.task] + } + if hasattr(self.cfg, + 'dataset') and self.cfg.dataset.column_map is not None: + for k, v in self.cfg.dataset.column_map.items(): + self.column_map[k] = v + self.transtab = str.maketrans( + {key: None + for key in string.punctuation}) self.constraint_trie = None - self.index2ans = {} - if self.cfg.model.get('answer2label', False): + if self.cfg.model.get('answer2label', None): ans2label_file = osp.join(model_dir, self.cfg.model.answer2label) - ans2label_dict = json.load(open(ans2label_file, 'r')) + with open(ans2label_file, 'r') as reader: + ans2label_dict = json.load(reader) + self.ans2label = ans2label_dict + self.label2ans = {v: k for k, v in self.ans2label.items()} self.constraint_trie = Trie(tokenizer.eos_token_id) for i, answer in enumerate(ans2label_dict.keys()): - answer_item = tokenizer( - ' ' + answer, - return_tensors='pt', - add_special_tokens=False).input_ids.squeeze(0) + answer_item = self.tokenize_text( + ' ' + answer, add_bos=False, add_eos=False) self.constraint_trie.insert([tokenizer.bos_token_id] + answer_item.tolist() + [tokenizer.eos_token_id]) - def get_inputs(self, text, add_bos=True, add_eos=True): + def tokenize_text(self, text, add_bos=True, add_eos=True): + if text is None: + return None inputs = self.tokenizer( text, max_length=self.max_src_length, add_special_tokens=False, + truncation=True, 
return_tensors='pt')['input_ids'].squeeze(0) if add_bos: inputs = torch.cat([self.bos_item, inputs]) @@ -85,7 +105,7 @@ class OfaBasePreprocessor: @staticmethod def pre_caption(caption, max_words=None): - caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ')\ + caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ') \ .replace('/', ' ').replace('', 'person') caption = re.sub( @@ -123,3 +143,23 @@ class OfaBasePreprocessor: question = ' '.join(question_words[:max_ques_words]) return question + + def add_constraint_mask(self, sample): + target_itm = sample['target'] + len_label_itm = target_itm.ne(self.pad_item).sum(dim=0).item() + if self.constraint_trie: + constraint_mask = torch.zeros( + (len(target_itm), len(self.tgt_dict))).bool() + start_idx = len(target_itm) - len_label_itm + for i in range(start_idx, len(target_itm)): + constraint_prefix_token = self.bos_item.tolist( + ) + target_itm[start_idx:i].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + sample['constraint_mask'] = constraint_mask + + def get_img_pil(self, path_or_url_or_pil): + image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \ + else load_image(path_or_url_or_pil) + return image diff --git a/modelscope/preprocessors/ofa/image_captioning.py b/modelscope/preprocessors/ofa/image_captioning.py index 318a8a6d..af623297 100644 --- a/modelscope/preprocessors/ofa/image_captioning.py +++ b/modelscope/preprocessors/ofa/image_captioning.py @@ -1,42 +1,67 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Any, Dict, Union +from typing import Any, Dict import torch -from PIL import Image from torchvision import transforms -from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaImageCaptioningPreprocessor, self).__init__(cfg, model_dir) + super(OfaImageCaptioningPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform self.patch_resize_transform = transforms.Compose([ lambda image: image.convert('RGB'), - transforms.Resize((self.patch_image_size, self.patch_image_size), - interpolation=Image.BICUBIC), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: - image = data['image'] if isinstance( - data['image'], Image.Image) else load_image(data['image']) + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + target = data[self.column_map['text']] + target = target.translate(self.transtab).strip() + target_token_list = target.strip().split() + target = ' '.join(target_token_list[:self.max_tgt_length]) + sample['target'] = self.tokenize_text(target, add_bos=False) + sample['prev_output_tokens'] = 
torch.cat( + [self.bos_item, sample['target'][:-1]]) + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) patch_image = self.patch_resize_transform(image) prompt = self.cfg.model.get('prompt', ' what does the image describe?') - inputs = self.get_inputs(prompt) + inputs = self.tokenize_text(prompt) sample = { 'source': inputs, 'patch_image': patch_image, 'patch_mask': torch.tensor([True]) } + if 'text' in self.column_map and self.column_map['text'] in data: + sample['label'] = data[self.column_map['text']] return sample diff --git a/modelscope/preprocessors/ofa/image_classification.py b/modelscope/preprocessors/ofa/image_classification.py index dd2de634..49968823 100644 --- a/modelscope/preprocessors/ofa/image_classification.py +++ b/modelscope/preprocessors/ofa/image_classification.py @@ -6,25 +6,33 @@ from PIL import Image from torchvision import transforms from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaImageClassificationPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ super(OfaImageClassificationPreprocessor, - self).__init__(cfg, model_dir) + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform self.patch_resize_transform = transforms.Compose([ lambda image: image.convert('RGB'), - transforms.Resize((self.patch_image_size, self.patch_image_size), - interpolation=Image.BICUBIC), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) @@ -34,7 +42,7 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor): data['image'], Image.Image) else load_image(data['image']) patch_image = self.patch_resize_transform(image) prompt = self.cfg.model.get('prompt', ' what does the image describe?') - inputs = self.get_inputs(prompt) + inputs = self.tokenize_text(prompt) sample = { 'source': inputs, 'patch_image': patch_image, diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py new file mode 100644 index 00000000..1761dbd4 --- /dev/null +++ b/modelscope/preprocessors/ofa/ocr_recognition.py @@ -0,0 +1,105 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
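The `prev_output_tokens` built in `_build_train_sample` above is the standard teacher-forcing shift; a tiny illustration with placeholder token ids:

import torch

bos_item = torch.tensor([0])          # placeholder BOS id
target = torch.tensor([7, 8, 9, 2])   # placeholder caption ids, EOS-terminated

# decoder inputs are the target shifted right by one position and prefixed
# with BOS, matching sample['prev_output_tokens'] above
prev_output_tokens = torch.cat([bos_item, target[:-1]])
print(prev_output_tokens)  # tensor([0, 7, 8, 9])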
+from typing import Any, Dict + +import torch +from PIL import Image +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from torchvision.transforms import functional as F + +from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + + +def ocr_resize(img, patch_image_size, is_document=False): + img = img.convert('RGB') + width, height = img.size + + if is_document: + new_height, new_width = 64, 1920 + else: + if width >= height: + new_width = max(64, patch_image_size) + new_height = max(64, int(patch_image_size * (height / width))) + top = (patch_image_size - new_height) // 2 + bottom = patch_image_size - new_height - top + left, right = 0, 0 + else: + new_height = max(64, patch_image_size) + new_width = max(64, int(patch_image_size * (width / height))) + left = (patch_image_size - new_width) // 2 + right = patch_image_size - new_width - left + top, bottom = 0, 0 + + img_new = F.resize( + img, + (new_height, new_width), + interpolation=InterpolationMode.BICUBIC, + ) + + if is_document: + img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1) + img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2)) + new_width, new_height = img_new.size + top = (patch_image_size - new_height) // 2 + bottom = patch_image_size - new_height - top + left, right = 0, 0 + + img_new = F.pad( + img_new, padding=[left, top, right, bottom], padding_mode='edge') + assert img_new.size == (patch_image_size, patch_image_size) + + return img_new + + +class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaOcrRecognitionPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) + # Initialize transform + if self.cfg.model.imagenet_default_mean_and_std: + mean = IMAGENET_DEFAULT_MEAN + std = IMAGENET_DEFAULT_STD + else: + mean = [0.5, 0.5, 0.5] + std = [0.5, 0.5, 0.5] + + self.patch_resize_transform = transforms.Compose([ + lambda image: ocr_resize( + image, + self.cfg.model.patch_image_size, + is_document=self.cfg.model.is_document), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = data['image'] if isinstance( + data['image'], Image.Image) else load_image(data['image']) + patch_image = self.patch_resize_transform(image) + prompt = self.cfg.model.get('prompt', '图片上的文字是什么?') + inputs = self.tokenize_text(prompt) + + sample = { + 'source': inputs, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]) + } + return sample diff --git a/modelscope/preprocessors/ofa/summarization.py b/modelscope/preprocessors/ofa/summarization.py index 99028e61..cfd3c23d 100644 --- a/modelscope/preprocessors/ofa/summarization.py +++ b/modelscope/preprocessors/ofa/summarization.py @@ -1,19 +1,27 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
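For intuition, a quick numeric walk-through of the non-document branch of `ocr_resize` above, using a hypothetical 300x60 text-line crop and a 480 patch size; only the geometry is computed here, the actual resize and edge-padding calls are as in the diff.

patch_image_size = 480
width, height = 300, 60  # hypothetical wide OCR crop

# width >= height: scale the long side to the patch size, keep aspect ratio,
# then pad top/bottom with edge pixels until the image is square
new_width = max(64, patch_image_size)
new_height = max(64, int(patch_image_size * (height / width)))
top = (patch_image_size - new_height) // 2
bottom = patch_image_size - new_height - top
left, right = 0, 0
print(new_width, new_height, (left, top, right, bottom))  # 480 96 (0, 192, 0, 192)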
from typing import Any, Dict +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaSummarizationPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaSummarizationPreprocessor, self).__init__(cfg, model_dir) + super(OfaSummarizationPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: source = super().pre_caption( @@ -23,7 +31,7 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor): prompt = self.cfg.model.get( 'prompt', ' " {} " Summarize the article with a title: ') text = prompt.format(source) - inputs = self.get_inputs(text) + inputs = self.tokenize_text(text) if self.prompt_type == 'none': decoder_prompt = self.bos_item elif self.prompt_type == 'prev_output': diff --git a/modelscope/preprocessors/ofa/text_classification.py b/modelscope/preprocessors/ofa/text_classification.py index 5673a07f..24c4f67e 100644 --- a/modelscope/preprocessors/ofa/text_classification.py +++ b/modelscope/preprocessors/ofa/text_classification.py @@ -1,38 +1,81 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict +import torch + +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaTextClassificationPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaTextClassificationPreprocessor, self).__init__(cfg, model_dir) + super(OfaTextClassificationPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_instruction(self, data): text1 = ' '.join( data['text'].lower().strip().split()[:self.max_src_length]) text2 = ' '.join( data['text2'].lower().strip().split()[:self.max_src_length]) prompt = ' can text1 " {} " imply text2 " {} "?' 
text = prompt.format(text1, text2) - inputs = self.get_inputs(text) + instruction_itm = self.tokenize_text(text) + return instruction_itm + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + instruction_itm = self._build_instruction(data) + assert 'label' in data, 'there must has `label` column in train phase ' + label = data['label'] + if self.label2ans: + label = self.label2ans[label] # ans + label_itm = self.tokenize_text(f' {label}', add_bos=False) + if self.prompt_type == 'none': + target_itm = label_itm + elif self.prompt_type == 'prev_output': + target_itm = torch.cat([instruction_itm[1:-1], label_itm]) + else: + raise NotImplementedError + prev_output_itm = torch.cat([self.bos_item, target_itm[:-1]]) + target_itm[:-len(label_itm)] = self.pad_item + sample = { + 'source': instruction_itm, + 'target': target_itm, + 'prev_output_tokens': prev_output_itm, + } + self.add_constraint_mask(sample) + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + instruction_itm = self._build_instruction(data) if self.prompt_type == 'none': - decoder_prompt = self.bos_item - elif self.prompt_type == 'src': - decoder_prompt = inputs + prefix_token = [] elif self.prompt_type == 'prev_output': - decoder_prompt = inputs[:-1] + prefix_token = instruction_itm[:-1] # remove eos else: raise NotImplementedError sample = { - 'source': inputs, - 'decoder_prompt': decoder_prompt, + 'source': instruction_itm, + 'prefix_token': prefix_token, } + if 'label' in data: + sample['label'] = self.label2ans[data['label']] return sample diff --git a/modelscope/preprocessors/ofa/text_to_image_synthesis.py b/modelscope/preprocessors/ofa/text_to_image_synthesis.py index e10de82c..2f6000eb 100644 --- a/modelscope/preprocessors/ofa/text_to_image_synthesis.py +++ b/modelscope/preprocessors/ofa/text_to_image_synthesis.py @@ -3,26 +3,34 @@ from typing import Any, Dict import torch +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: - model_dir (str): model path + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) """ super(OfaTextToImageSynthesisPreprocessor, - self).__init__(cfg, model_dir) + self).__init__(cfg, model_dir, mode, *args, **kwargs) self.max_src_length = 64 def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: source = ' '.join( data['text'].lower().strip().split()[:self.max_src_length]) source = 'what is the complete image? caption: {}'.format(source) - inputs = self.get_inputs(source) + inputs = self.tokenize_text(source) sample = { 'source': inputs, 'patch_images': None, diff --git a/modelscope/preprocessors/ofa/utils/collate.py b/modelscope/preprocessors/ofa/utils/collate.py index a473335b..f7775680 100644 --- a/modelscope/preprocessors/ofa/utils/collate.py +++ b/modelscope/preprocessors/ofa/utils/collate.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
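The target construction in _build_train_sample above is easiest to follow with concrete tensors: for prompt_type == 'prev_output' the instruction tokens are prepended to the label for teacher forcing, prev_output_tokens is the bos-shifted target, and the instruction part of the target is then overwritten with the pad item so the loss only covers the answer. A worked toy example with made-up token ids:

import torch

bos, eos, pad = torch.tensor([0]), torch.tensor([2]), torch.tensor([1])
instruction_itm = torch.tensor([0, 11, 12, 13, 2])  # <bos> prompt tokens <eos>
label_itm = torch.tensor([21, 22, 2])               # answer tokens + <eos>, no bos

target_itm = torch.cat([instruction_itm[1:-1], label_itm])
prev_output_itm = torch.cat([bos, target_itm[:-1]])
target_itm[:-len(label_itm)] = pad  # mask the instruction out of the loss

assert prev_output_itm.tolist() == [0, 11, 12, 13, 21, 22]
assert target_itm.tolist() == [1, 1, 1, 21, 22, 2]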
+ import numpy as np import torch @@ -47,11 +49,15 @@ def collate_fn(samples, pad_idx, eos_idx): batch['conf'] = torch.cat([s['conf'] for s in samples], dim=0) if samples[0].get('ref_dict', None) is not None: batch['ref_dict'] = np.array([s['ref_dict'] for s in samples]) + if samples[0].get('label', None) is not None: + batch['labels'] = np.array([s['label'] for s in samples]).tolist() if samples[0].get('constraint_mask', None) is not None: batch['constraint_masks'] = merge('constraint_mask') if samples[0].get('decoder_prompt', None) is not None: batch['decoder_prompts'] = np.array( [s['decoder_prompt'].tolist() for s in samples]) + if samples[0].get('prefix_token', None) is not None: + batch['prefix_tokens'] = merge('prefix_token') # For detection and visual grounding if samples[0].get('w_resize_ratio', None) is not None: batch['w_resize_ratios'] = torch.stack( diff --git a/modelscope/preprocessors/ofa/utils/constant.py b/modelscope/preprocessors/ofa/utils/constant.py new file mode 100644 index 00000000..102d27c0 --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/constant.py @@ -0,0 +1,13 @@ +from modelscope.utils.constant import Tasks + +OFA_TASK_KEY_MAPPING = { + Tasks.ocr_recognition: ['image'], + Tasks.image_captioning: ['image'], + Tasks.image_classification: ['image'], + Tasks.text_summarization: ['text'], + Tasks.text_classification: ['text', 'text2'], + Tasks.visual_grounding: ['image', 'text'], + Tasks.visual_question_answering: ['image', 'text'], + Tasks.visual_entailment: ['image', 'text', 'text2'], + Tasks.text_to_image_synthesis: ['text'] +} diff --git a/modelscope/preprocessors/ofa/utils/random_help.py b/modelscope/preprocessors/ofa/utils/random_help.py index 77f4df3f..e0dca54e 100644 --- a/modelscope/preprocessors/ofa/utils/random_help.py +++ b/modelscope/preprocessors/ofa/utils/random_help.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
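OFA_TASK_KEY_MAPPING above lists, per task, the input keys an OFA preprocessor expects; the code that consumes it is not part of this diff, so the column_map construction below is only a plausible sketch and the dataset column names are made up:

from modelscope.preprocessors.ofa.utils.constant import OFA_TASK_KEY_MAPPING
from modelscope.utils.constant import Tasks

dataset_columns = {'image': 'image:FILE', 'text': 'caption'}  # hypothetical
required_keys = OFA_TASK_KEY_MAPPING[Tasks.image_captioning]  # ['image']
column_map = {key: dataset_columns.get(key, key) for key in required_keys}
# -> {'image': 'image:FILE'}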
+ import torch try: diff --git a/modelscope/preprocessors/ofa/visual_entailment.py b/modelscope/preprocessors/ofa/visual_entailment.py index 6002c4a6..61c3cc6a 100644 --- a/modelscope/preprocessors/ofa/visual_entailment.py +++ b/modelscope/preprocessors/ofa/visual_entailment.py @@ -6,24 +6,33 @@ from PIL import Image from torchvision import transforms from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaVisualEntailmentPreprocessor, self).__init__(cfg, model_dir) + super(OfaVisualEntailmentPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform self.patch_resize_transform = transforms.Compose([ lambda image: image.convert('RGB'), - transforms.Resize((self.patch_image_size, self.patch_image_size), - interpolation=Image.BICUBIC), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) @@ -44,7 +53,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): prompt = self.cfg.model.get( 'prompt', ' can image and text1 " {} " imply text2 " {} "?') text = prompt.format(caption, hypothesis) - inputs = self.get_inputs(text) + inputs = self.tokenize_text(text) if self.prompt_type == 'none': decoder_prompt = self.bos_item elif self.prompt_type == 'src': diff --git a/modelscope/preprocessors/ofa/visual_grounding.py b/modelscope/preprocessors/ofa/visual_grounding.py index 022e5788..8b116463 100644 --- a/modelscope/preprocessors/ofa/visual_grounding.py +++ b/modelscope/preprocessors/ofa/visual_grounding.py @@ -6,24 +6,33 @@ from PIL import Image from torchvision import transforms from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaVisualGroundingPreprocessor, self).__init__(cfg, model_dir) + super(OfaVisualGroundingPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform self.patch_resize_transform = transforms.Compose([ lambda image: image.convert('RGB'), - transforms.Resize((self.patch_image_size, self.patch_image_size), - interpolation=Image.BICUBIC), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) @@ -39,7 +48,7 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): prompt = self.cfg.model.get( 'prompt', ' which region does the text " {} " describe?') text = prompt.format(src_caption) - src_item = self.get_inputs(text) + src_item = self.tokenize_text(text) sample = { 'source': src_item, 
'patch_image': patch_image, diff --git a/modelscope/preprocessors/ofa/visual_question_answering.py b/modelscope/preprocessors/ofa/visual_question_answering.py index d34d1db0..11104e7e 100644 --- a/modelscope/preprocessors/ofa/visual_question_answering.py +++ b/modelscope/preprocessors/ofa/visual_question_answering.py @@ -6,25 +6,33 @@ from PIL import Image from torchvision import transforms from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ super(OfaVisualQuestionAnsweringPreprocessor, - self).__init__(cfg, model_dir) + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform self.patch_resize_transform = transforms.Compose([ lambda image: image.convert('RGB'), - transforms.Resize((self.patch_image_size, self.patch_image_size), - interpolation=Image.BICUBIC), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) @@ -34,7 +42,7 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): data['image'], Image.Image) else load_image(data['image']) patch_image = self.patch_resize_transform(image) text = ' {}'.format(data['text']) - inputs = self.get_inputs(text) + inputs = self.tokenize_text(text) if self.prompt_type == 'none': decoder_prompt = self.bos_item elif self.prompt_type == 'src': diff --git a/modelscope/preprocessors/science/__init__.py b/modelscope/preprocessors/science/__init__.py new file mode 100644 index 00000000..54b24887 --- /dev/null +++ b/modelscope/preprocessors/science/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .unifold import (UniFoldPreprocessor) + +else: + _import_structure = {'unifold': ['UniFoldPreprocessor']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/science/uni_fold.py b/modelscope/preprocessors/science/uni_fold.py new file mode 100644 index 00000000..2a44c885 --- /dev/null +++ b/modelscope/preprocessors/science/uni_fold.py @@ -0,0 +1,569 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
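The UniFoldPreprocessor defined in the body of this new file reads its configuration from keyword arguments, so symmetry_group must always be supplied. A minimal construction sketch; note that the lazy-import table in science/__init__.py above refers to a module named unifold while this file is uni_fold.py, so the sketch imports the module directly:

from modelscope.preprocessors.science.uni_fold import UniFoldPreprocessor

# 'C1' (or an empty value) selects the plain, non-symmetric pipeline.
proc = UniFoldPreprocessor(symmetry_group='C1')

# __call__ takes a single protein sequence string (or several chains), runs
# remote MMseqs2 MSA/template searches (network access required; hhsearch and
# kalign on PATH for templates), and returns a dict with 'features',
# 'pair_features', 'target_id' and 'is_multimer':
# result = proc('LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC')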
+ +import gzip +import hashlib +import logging +import os +import pickle +import random +import re +import tarfile +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from unittest import result + +import json +import numpy as np +import requests +import torch +from tqdm import tqdm + +from modelscope.metainfo import Preprocessors +from modelscope.models.science.unifold.data import protein, residue_constants +from modelscope.models.science.unifold.data.protein import PDB_CHAIN_IDS +from modelscope.models.science.unifold.data.utils import compress_features +from modelscope.models.science.unifold.msa import parsers, pipeline, templates +from modelscope.models.science.unifold.msa.tools import hhsearch +from modelscope.models.science.unifold.msa.utils import divide_multi_chains +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields + +__all__ = [ + 'UniFoldPreprocessor', +] + +TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]' +DEFAULT_API_SERVER = 'https://api.colabfold.com' + + +def run_mmseqs2( + x, + prefix, + use_env=True, + use_templates=False, + use_pairing=False, + host_url='https://api.colabfold.com') -> Tuple[List[str], List[str]]: + submission_endpoint = 'ticket/pair' if use_pairing else 'ticket/msa' + + def submit(seqs, mode, N=101): + n, query = N, '' + for seq in seqs: + query += f'>{n}\n{seq}\n' + n += 1 + + res = requests.post( + f'{host_url}/{submission_endpoint}', + data={ + 'q': query, + 'mode': mode + }) + try: + out = res.json() + except ValueError: + out = {'status': 'ERROR'} + return out + + def status(ID): + res = requests.get(f'{host_url}/ticket/{ID}') + try: + out = res.json() + except ValueError: + out = {'status': 'ERROR'} + return out + + def download(ID, path): + res = requests.get(f'{host_url}/result/download/{ID}') + with open(path, 'wb') as out: + out.write(res.content) + + # process input x + seqs = [x] if isinstance(x, str) else x + + mode = 'env' + if use_pairing: + mode = '' + use_templates = False + use_env = False + + # define path + path = f'{prefix}' + if not os.path.isdir(path): + os.mkdir(path) + + # call mmseqs2 api + tar_gz_file = f'{path}/out_{mode}.tar.gz' + N, REDO = 101, True + + # deduplicate and keep track of order + seqs_unique = [] + # TODO this might be slow for large sets + [seqs_unique.append(x) for x in seqs if x not in seqs_unique] + Ms = [N + seqs_unique.index(seq) for seq in seqs] + # lets do it! + if not os.path.isfile(tar_gz_file): + TIME_ESTIMATE = 150 * len(seqs_unique) + with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar: + while REDO: + pbar.set_description('SUBMIT') + + # Resubmit job until it goes through + out = submit(seqs_unique, mode, N) + while out['status'] in ['UNKNOWN', 'RATELIMIT']: + sleep_time = 5 + random.randint(0, 5) + # logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}") + # resubmit + time.sleep(sleep_time) + out = submit(seqs_unique, mode, N) + + if out['status'] == 'ERROR': + error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.' + error = error + 'If error persists, please try again an hour later.' + raise Exception(error) + + if out['status'] == 'MAINTENANCE': + raise Exception( + 'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.' 
+ ) + + # wait for job to finish + ID, TIME = out['id'], 0 + pbar.set_description(out['status']) + while out['status'] in ['UNKNOWN', 'RUNNING', 'PENDING']: + t = 5 + random.randint(0, 5) + # logger.error(f"Sleeping for {t}s. Reason: {out['status']}") + time.sleep(t) + out = status(ID) + pbar.set_description(out['status']) + if out['status'] == 'RUNNING': + TIME += t + pbar.update(n=t) + + if out['status'] == 'COMPLETE': + if TIME < TIME_ESTIMATE: + pbar.update(n=(TIME_ESTIMATE - TIME)) + REDO = False + + if out['status'] == 'ERROR': + REDO = False + error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.' + error = error + 'If error persists, please try again an hour later.' + raise Exception(error) + + # Download results + download(ID, tar_gz_file) + + # prep list of a3m files + if use_pairing: + a3m_files = [f'{path}/pair.a3m'] + else: + a3m_files = [f'{path}/uniref.a3m'] + if use_env: + a3m_files.append(f'{path}/bfd.mgnify30.metaeuk30.smag30.a3m') + + # extract a3m files + if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files): + with tarfile.open(tar_gz_file) as tar_gz: + tar_gz.extractall(path) + + # templates + if use_templates: + templates = {} + + with open(f'{path}/pdb70.m8', 'r') as f: + lines = f.readlines() + for line in lines: + p = line.rstrip().split() + M, pdb, _, _ = p[0], p[1], p[2], p[10] # qid, e_value + M = int(M) + if M not in templates: + templates[M] = [] + templates[M].append(pdb) + + template_paths = {} + for k, TMPL in templates.items(): + TMPL_PATH = f'{prefix}/templates_{k}' + if not os.path.isdir(TMPL_PATH): + os.mkdir(TMPL_PATH) + TMPL_LINE = ','.join(TMPL[:20]) + os.system( + f'curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/' + ) + os.system( + f'cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex' + ) + os.system(f'touch {TMPL_PATH}/pdb70_cs219.ffdata') + template_paths[k] = TMPL_PATH + + # gather a3m lines + a3m_lines = {} + for a3m_file in a3m_files: + update_M, M = True, None + with open(a3m_file, 'r') as f: + lines = f.readlines() + for line in lines: + if len(line) > 0: + if '\x00' in line: + line = line.replace('\x00', '') + update_M = True + if line.startswith('>') and update_M: + M = int(line[1:].rstrip()) + update_M = False + if M not in a3m_lines: + a3m_lines[M] = [] + a3m_lines[M].append(line) + + # return results + + a3m_lines = [''.join(a3m_lines[n]) for n in Ms] + + if use_templates: + template_paths_ = [] + for n in Ms: + if n not in template_paths: + template_paths_.append(None) + # print(f"{n-N}\tno_templates_found") + else: + template_paths_.append(template_paths[n]) + template_paths = template_paths_ + + return (a3m_lines, template_paths) if use_templates else a3m_lines + + +def get_null_template(query_sequence: Union[List[str], str], + num_temp: int = 1) -> Dict[str, Any]: + ln = ( + len(query_sequence) if isinstance(query_sequence, str) else sum( + len(s) for s in query_sequence)) + output_templates_sequence = 'A' * ln + # output_confidence_scores = np.full(ln, 1.0) + + templates_all_atom_positions = np.zeros( + (ln, templates.residue_constants.atom_type_num, 3)) + templates_all_atom_masks = np.zeros( + (ln, templates.residue_constants.atom_type_num)) + templates_aatype = templates.residue_constants.sequence_to_onehot( + output_templates_sequence, + templates.residue_constants.HHBLITS_AA_TO_ID) + template_features = { + 'template_all_atom_positions': + np.tile(templates_all_atom_positions[None], [num_temp, 1, 1, 1]), + 'template_all_atom_masks': + 
np.tile(templates_all_atom_masks[None], [num_temp, 1, 1]), + 'template_sequence': ['none'.encode()] * num_temp, + 'template_aatype': + np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]), + 'template_domain_names': ['none'.encode()] * num_temp, + 'template_sum_probs': + np.zeros([num_temp], dtype=np.float32), + } + return template_features + + +def get_template(a3m_lines: str, template_path: str, + query_sequence: str) -> Dict[str, Any]: + template_featurizer = templates.HhsearchHitFeaturizer( + mmcif_dir=template_path, + max_template_date='2100-01-01', + max_hits=20, + kalign_binary_path='kalign', + release_dates_path=None, + obsolete_pdbs_path=None, + ) + + hhsearch_pdb70_runner = hhsearch.HHSearch( + binary_path='hhsearch', databases=[f'{template_path}/pdb70']) + + hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines) + hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result) + templates_result = template_featurizer.get_templates( + query_sequence=query_sequence, hits=hhsearch_hits) + return dict(templates_result.features) + + +@PREPROCESSORS.register_module( + Fields.science, module_name=Preprocessors.unifold_preprocessor) +class UniFoldPreprocessor(Preprocessor): + + def __init__(self, **cfg): + self.symmetry_group = cfg['symmetry_group'] # "C1" + if not self.symmetry_group: + self.symmetry_group = None + self.MIN_SINGLE_SEQUENCE_LENGTH = 16 # TODO: change to cfg + self.MAX_SINGLE_SEQUENCE_LENGTH = 1000 + self.MAX_MULTIMER_LENGTH = 1000 + self.jobname = 'unifold' + self.output_dir_base = './unifold-predictions' + os.makedirs(self.output_dir_base, exist_ok=True) + + def clean_and_validate_sequence(self, input_sequence: str, min_length: int, + max_length: int) -> str: + clean_sequence = input_sequence.translate( + str.maketrans('', '', ' \n\t')).upper() + aatypes = set(residue_constants.restypes) # 20 standard aatypes. + if not set(clean_sequence).issubset(aatypes): + raise ValueError( + f'Input sequence contains non-amino acid letters: ' + f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard ' + 'amino acids as inputs.') + if len(clean_sequence) < min_length: + raise ValueError( + f'Input sequence is too short: {len(clean_sequence)} amino acids, ' + f'while the minimum is {min_length}') + if len(clean_sequence) > max_length: + raise ValueError( + f'Input sequence is too long: {len(clean_sequence)} amino acids, while ' + f'the maximum is {max_length}. You may be able to run it with the full ' + f'Uni-Fold system depending on your resources (system memory, ' + f'GPU memory).') + return clean_sequence + + def validate_input(self, input_sequences: Sequence[str], + symmetry_group: str, min_length: int, max_length: int, + max_multimer_length: int) -> Tuple[Sequence[str], bool]: + """Validates and cleans input sequences and determines which model to use.""" + sequences = [] + + for input_sequence in input_sequences: + if input_sequence.strip(): + input_sequence = self.clean_and_validate_sequence( + input_sequence=input_sequence, + min_length=min_length, + max_length=max_length) + sequences.append(input_sequence) + + if symmetry_group is not None and symmetry_group != 'C1': + if symmetry_group.startswith( + 'C') and symmetry_group[1:].isnumeric(): + print( + f'Using UF-Symmetry with group {symmetry_group}. 
If you do not ' + f'want to use UF-Symmetry, please use `C1` and copy the AU ' + f'sequences to the count in the assembly.') + is_multimer = (len(sequences) > 1) + return sequences, is_multimer, symmetry_group + else: + raise ValueError( + f'UF-Symmetry does not support symmetry group ' + f'{symmetry_group} currently. Cyclic groups (Cx) are ' + f'supported only.') + + elif len(sequences) == 1: + print('Using the single-chain model.') + return sequences, False, None + + elif len(sequences) > 1: + total_multimer_length = sum([len(seq) for seq in sequences]) + if total_multimer_length > max_multimer_length: + raise ValueError( + f'The total length of multimer sequences is too long: ' + f'{total_multimer_length}, while the maximum is ' + f'{max_multimer_length}. Please use the full AlphaFold ' + f'system for long multimers.') + print(f'Using the multimer model with {len(sequences)} sequences.') + return sequences, True, None + + else: + raise ValueError( + 'No input amino acid sequence provided, please provide at ' + 'least one sequence.') + + def add_hash(self, x, y): + return x + '_' + hashlib.sha1(y.encode()).hexdigest()[:5] + + def get_msa_and_templates( + self, + jobname: str, + query_seqs_unique: Union[str, List[str]], + result_dir: Path, + msa_mode: str, + use_templates: bool, + homooligomers_num: int = 1, + host_url: str = DEFAULT_API_SERVER, + ) -> Tuple[Optional[List[str]], Optional[List[str]], List[str], List[int], + List[Dict[str, Any]]]: + + use_env = msa_mode == 'MMseqs2' + + template_features = [] + if use_templates: + a3m_lines_mmseqs2, template_paths = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_templates=True, + host_url=host_url, + ) + if template_paths is None: + for index in range(0, len(query_seqs_unique)): + template_feature = get_null_template( + query_seqs_unique[index]) + template_features.append(template_feature) + else: + for index in range(0, len(query_seqs_unique)): + if template_paths[index] is not None: + template_feature = get_template( + a3m_lines_mmseqs2[index], + template_paths[index], + query_seqs_unique[index], + ) + if len(template_feature['template_domain_names']) == 0: + template_feature = get_null_template( + query_seqs_unique[index]) + else: + template_feature = get_null_template( + query_seqs_unique[index]) + template_features.append(template_feature) + else: + for index in range(0, len(query_seqs_unique)): + template_feature = get_null_template(query_seqs_unique[index]) + template_features.append(template_feature) + + if msa_mode == 'single_sequence': + a3m_lines = [] + num = 101 + for i, seq in enumerate(query_seqs_unique): + a3m_lines.append('>' + str(num + i) + '\n' + seq) + else: + # find normal a3ms + a3m_lines = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_pairing=False, + host_url=host_url, + ) + if len(query_seqs_unique) > 1: + # find paired a3m if not a homooligomers + paired_a3m_lines = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_pairing=True, + host_url=host_url, + ) + else: + num = 101 + paired_a3m_lines = [] + for i in range(0, homooligomers_num): + paired_a3m_lines.append('>' + str(num + i) + '\n' + + query_seqs_unique[0] + '\n') + + return ( + a3m_lines, + paired_a3m_lines, + template_features, + ) + + def __call__(self, data: Union[str, Tuple]): + if isinstance(data, str): + data = [data, '', '', ''] + basejobname = ''.join(data) + basejobname = re.sub(r'\W+', '', basejobname) + target_id = 
self.add_hash(self.jobname, basejobname) + + sequences, is_multimer, _ = self.validate_input( + input_sequences=data, + symmetry_group=self.symmetry_group, + min_length=self.MIN_SINGLE_SEQUENCE_LENGTH, + max_length=self.MAX_SINGLE_SEQUENCE_LENGTH, + max_multimer_length=self.MAX_MULTIMER_LENGTH) + + descriptions = [ + '> ' + target_id + ' seq' + str(ii) + for ii in range(len(sequences)) + ] + + if is_multimer: + divide_multi_chains(target_id, self.output_dir_base, sequences, + descriptions) + + s = [] + for des, seq in zip(descriptions, sequences): + s += [des, seq] + + unique_sequences = [] + [ + unique_sequences.append(x) for x in sequences + if x not in unique_sequences + ] + + if len(unique_sequences) == 1: + homooligomers_num = len(sequences) + else: + homooligomers_num = 1 + + with open(f'{self.jobname}.fasta', 'w') as f: + f.write('\n'.join(s)) + + result_dir = Path(self.output_dir_base) + output_dir = os.path.join(self.output_dir_base, target_id) + + # msa_mode = 'single_sequence' + msa_mode = 'MMseqs2' + use_templates = True + + unpaired_msa, paired_msa, template_results = self.get_msa_and_templates( + target_id, + unique_sequences, + result_dir=result_dir, + msa_mode=msa_mode, + use_templates=use_templates, + homooligomers_num=homooligomers_num) + + features = [] + pair_features = [] + + for idx, seq in enumerate(unique_sequences): + chain_id = PDB_CHAIN_IDS[idx] + sequence_features = pipeline.make_sequence_features( + sequence=seq, + description=f'> {self.jobname} seq {chain_id}', + num_res=len(seq)) + monomer_msa = parsers.parse_a3m(unpaired_msa[idx]) + msa_features = pipeline.make_msa_features([monomer_msa]) + template_features = template_results[idx] + feature_dict = { + **sequence_features, + **msa_features, + **template_features + } + feature_dict = compress_features(feature_dict) + features_output_path = os.path.join( + output_dir, '{}.feature.pkl.gz'.format(chain_id)) + pickle.dump( + feature_dict, + gzip.GzipFile(features_output_path, 'wb'), + protocol=4) + features.append(feature_dict) + + if is_multimer: + multimer_msa = parsers.parse_a3m(paired_msa[idx]) + pair_features = pipeline.make_msa_features([multimer_msa]) + pair_feature_dict = compress_features(pair_features) + uniprot_output_path = os.path.join( + output_dir, '{}.uniprot.pkl.gz'.format(chain_id)) + pickle.dump( + pair_feature_dict, + gzip.GzipFile(uniprot_output_path, 'wb'), + protocol=4, + ) + pair_features.append(pair_feature_dict) + + # return features, pair_features, target_id + return { + 'features': features, + 'pair_features': pair_features, + 'target_id': target_id, + 'is_multimer': is_multimer, + } + + +if __name__ == '__main__': + proc = UniFoldPreprocessor() + protein_example = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + \ + 'TVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI' + features, pair_features = proc.__call__(protein_example) + import ipdb + ipdb.set_trace() diff --git a/modelscope/preprocessors/space/fields/__init__.py b/modelscope/preprocessors/space/fields/__init__.py deleted file mode 100644 index 925eac71..00000000 --- a/modelscope/preprocessors/space/fields/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .gen_field import MultiWOZBPETextField -from .intent_field import IntentBPETextField diff --git a/modelscope/preprocessors/space/fields/dst_processors.py b/modelscope/preprocessors/space/fields/dst_processors.py deleted file mode 100644 index 22e06eec..00000000 --- 
a/modelscope/preprocessors/space/fields/dst_processors.py +++ /dev/null @@ -1,1523 +0,0 @@ -# -# Copyright 2020 Heinrich Heine University Duesseldorf -# -# Part of this code is based on the source code of BERT-DST -# (arXiv:1907.03040) -# Part of this code is based on the source code of Transformers -# (arXiv:1910.03771) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re - -import json -import numpy as np -import six -from tqdm import tqdm - -logger = logging.getLogger(__name__) -USER_NAME = 'User' -SYSTEM_NAME = 'System' -DIALOG_ACT = 'Dialog_Act' - -utter1 = { - 'User-1': - "I'd really like to take my client out to a nice restaurant that serves indian food." -} -history_states1 = [ - {}, -] -utter2 = { - 'User-1': - "I'd really like to take my client out to a nice restaurant that serves indian food.", - 'System-1': - 'I show many restaurants that serve Indian food in that price range. What area would you like to travel to?', - 'Dialog_Act-1': { - 'Restaurant-Inform': [['choice', 'many'], ['food', 'Indian'], - ['pricerange', 'that price range']] - }, - 'User-2': - 'I am looking for an expensive indian restaurant in the area of centre.', -} - -history_states2 = [{}, { - 'attraction': { - 'book': { - 'booked': [] - }, - 'semi': { - 'area': '', - 'name': '', - 'type': '' - } - }, - 'hospital': { - 'book': { - 'booked': [] - }, - 'semi': { - 'department': '' - } - }, - 'hotel': { - 'book': { - 'booked': [{ - 'name': 'alexander bed and breakfast', - 'reference': 'JXVKZ7KV' - }], - 'day': - 'sunday', - 'people': - '6', - 'stay': - '4' - }, - 'semi': { - 'area': '', - 'internet': 'yes', - 'name': 'alexander bed and breakfast', - 'parking': 'yes', - 'pricerange': 'cheap', - 'stars': '', - 'type': 'guesthouse' - } - }, - 'police': { - 'book': { - 'booked': [] - }, - 'semi': {} - }, - 'restaurant': { - 'book': { - 'booked': [{ - 'name': 'ask', - 'reference': 'Y2Y8QYBY' - }], - 'day': 'sunday', - 'people': '6', - 'time': '18:45' - }, - 'semi': { - 'area': 'centre', - 'food': 'italian', - 'name': 'ask', - 'pricerange': 'cheap' - } - }, - 'taxi': { - 'book': { - 'booked': [] - }, - 'semi': { - 'arriveBy': '', - 'departure': '', - 'destination': '', - 'leaveAt': '' - } - }, - 'train': { - 'book': { - 'booked': [], - 'people': '' - }, - 'semi': { - 'arriveBy': '', - 'day': '', - 'departure': '', - 'destination': '', - 'leaveAt': '' - } - } -}, {}] - -utter3 = { - 'User-1': - "I'd really like to take my client out to a nice restaurant that serves indian food.", - 'System-1': - 'I show many restaurants that serve Indian food in that price range. What area would you like to travel to?', - 'Dialog_Act-1': { - 'Restaurant-Inform': [['choice', 'many'], ['food', 'Indian'], - ['pricerange', 'that price range']] - }, - 'User-2': - 'I am looking for an expensive indian restaurant in the area of centre.', - 'System-2': - 'Might I recommend Saffron Brasserie? That is an expensive Indian restaurant ' - 'in the center of town. 
I can book a table for you, if you like.', - 'Dialog_Act-2': { - 'Restaurant-Recommend': [['area', 'center of town'], - ['food', 'Indian'], - ['name', 'Saffron Brasserie'], - ['pricerange', 'expensive']] - }, - 'User-3': - 'Sure thing, please book for 6 people at 19:30 on Saturday.' -} - -history_states3 = [{}, { - 'attraction': { - 'book': { - 'booked': [] - }, - 'semi': { - 'area': '', - 'name': '', - 'type': '' - } - }, - 'hospital': { - 'book': { - 'booked': [] - }, - 'semi': { - 'department': '' - } - }, - 'hotel': { - 'book': { - 'booked': [{ - 'name': 'alexander bed and breakfast', - 'reference': 'JXVKZ7KV' - }], - 'day': - 'sunday', - 'people': - '6', - 'stay': - '4' - }, - 'semi': { - 'area': '', - 'internet': 'yes', - 'name': 'alexander bed and breakfast', - 'parking': 'yes', - 'pricerange': 'cheap', - 'stars': '', - 'type': 'guesthouse' - } - }, - 'police': { - 'book': { - 'booked': [] - }, - 'semi': {} - }, - 'restaurant': { - 'book': { - 'booked': [{ - 'name': 'ask', - 'reference': 'Y2Y8QYBY' - }], - 'day': 'sunday', - 'people': '6', - 'time': '18:45' - }, - 'semi': { - 'area': 'centre', - 'food': 'italian', - 'name': 'ask', - 'pricerange': 'cheap' - } - }, - 'taxi': { - 'book': { - 'booked': [] - }, - 'semi': { - 'arriveBy': '', - 'departure': '', - 'destination': '', - 'leaveAt': '' - } - }, - 'train': { - 'book': { - 'booked': [], - 'people': '' - }, - 'semi': { - 'arriveBy': '', - 'day': '', - 'departure': '', - 'destination': '', - 'leaveAt': '' - } - } -}, {}, { - 'attraction': { - 'book': { - 'booked': [] - }, - 'semi': { - 'area': '', - 'name': '', - 'type': '' - } - }, - 'hospital': { - 'book': { - 'booked': [] - }, - 'semi': { - 'department': '' - } - }, - 'hotel': { - 'book': { - 'booked': [{ - 'name': 'alexander bed and breakfast', - 'reference': 'JXVKZ7KV' - }], - 'day': - 'sunday', - 'people': - '6', - 'stay': - '4' - }, - 'semi': { - 'area': '', - 'internet': 'yes', - 'name': 'alexander bed and breakfast', - 'parking': 'yes', - 'pricerange': 'cheap', - 'stars': '', - 'type': 'guesthouse' - } - }, - 'police': { - 'book': { - 'booked': [] - }, - 'semi': {} - }, - 'restaurant': { - 'book': { - 'booked': [{ - 'name': 'ask', - 'reference': 'Y2Y8QYBY' - }], - 'day': 'sunday', - 'people': '6', - 'time': '18:45' - }, - 'semi': { - 'area': 'centre', - 'food': 'italian', - 'name': 'ask', - 'pricerange': 'cheap' - } - }, - 'taxi': { - 'book': { - 'booked': [] - }, - 'semi': { - 'arriveBy': '', - 'departure': '', - 'destination': '', - 'leaveAt': '' - } - }, - 'train': { - 'book': { - 'booked': [], - 'people': '' - }, - 'semi': { - 'arriveBy': '', - 'day': '', - 'departure': '', - 'destination': '', - 'leaveAt': '' - } - } -}, {}] - - -class DSTProcessor(object): - ACTS_DICT = { - 'taxi-depart': 'taxi-departure', - 'taxi-dest': 'taxi-destination', - 'taxi-leaveat': 'taxi-leaveAt', - 'taxi-arriveby': 'taxi-arriveBy', - 'train-depart': 'train-departure', - 'train-dest': 'train-destination', - 'train-leaveat': 'train-leaveAt', - 'train-arriveby': 'train-arriveBy', - 'train-bookpeople': 'train-book_people', - 'restaurant-price': 'restaurant-pricerange', - 'restaurant-bookpeople': 'restaurant-book_people', - 'restaurant-bookday': 'restaurant-book_day', - 'restaurant-booktime': 'restaurant-book_time', - 'hotel-price': 'hotel-pricerange', - 'hotel-bookpeople': 'hotel-book_people', - 'hotel-bookday': 'hotel-book_day', - 'hotel-bookstay': 'hotel-book_stay', - 'booking-bookpeople': 'booking-book_people', - 'booking-bookday': 'booking-book_day', - 'booking-bookstay': 
'booking-book_stay', - 'booking-booktime': 'booking-book_time', - } - - LABEL_MAPS = {} # Loaded from file - - def __init__(self): - # Required for mapping slot names in dialogue_acts.json file - # to proper designations. - pass - - def _convert_inputs_to_utterances(self, inputs: dict, - history_states: list): - """This method is to generate the utterances with user, sys, dialog_acts and metadata, - while metadata is from the history_states or the output from the inference pipline""" - - utterances = [] - user_inputs = [] - sys_gen_inputs = [] - dialog_acts_inputs = [] - for i, item in enumerate(inputs): - name, turn = item.split('-') - if name == USER_NAME: - user_inputs.insert(int(turn) - 1, inputs[item]) - elif name == SYSTEM_NAME: - sys_gen_inputs.insert(int(turn) - 1, inputs[item]) - else: - dialog_acts_inputs.insert(int(turn) - 1, inputs[item]) - - # user is leading the topic should aways larger than sys and dialog acts - assert len(user_inputs) - 1 == len(sys_gen_inputs) - assert len(user_inputs) - 1 == len(dialog_acts_inputs) - # the history states record both user and sys states - assert len(history_states) == len(user_inputs) + len(sys_gen_inputs) - - # the dialog_act at user turn is useless - for i, item in enumerate(history_states): - utterance = {} - # the dialog_act at user turn is useless - utterance['dialog_act'] = dialog_acts_inputs[ - i // 2] if i % 2 == 1 else {} - utterance['text'] = sys_gen_inputs[ - i // 2] if i % 2 == 1 else user_inputs[i // 2] - utterance['metadata'] = item - utterance['span_info'] = [] - utterances.append(utterance) - - return utterances - - def _load_acts(self, inputs: dict, dialog_id='example.json'): - dialog_acts_inputs = [] - for i, item in enumerate(inputs): - name, turn = item.split('-') - if name == DIALOG_ACT: - dialog_acts_inputs.insert(int(turn) - 1, inputs[item]) - s_dict = {} - - for j, item in enumerate(dialog_acts_inputs): - if isinstance(item, dict): - for a in item: - aa = a.lower().split('-') - if aa[1] == 'inform' or aa[1] == 'recommend' or \ - aa[1] == 'select' or aa[1] == 'book': - for i in item[a]: - s = i[0].lower() - v = i[1].lower().strip() - if s == 'none' or v == '?' or v == 'none': - continue - slot = aa[0] + '-' + s - if slot in self.ACTS_DICT: - slot = self.ACTS_DICT[slot] - key = dialog_id, str(int(j) + 1), slot - # In case of multiple mentioned values... - # ... Option 1: Keep first informed value - if key not in s_dict: - s_dict[key] = list([v]) - # ... 
Option 2: Keep last informed value - # s_dict[key] = list([v]) - - return s_dict - - -class multiwoz22Processor(DSTProcessor): - - def __init__(self): - super().__init__() - - def normalize_time(self, text): - text = re.sub(r'(\d{1})(a\.?m\.?|p\.?m\.?)', r'\1 \2', - text) # am/pm without space - text = re.sub(r'(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)', r'\1\2:00 \3', - text) # am/pm short to long form - text = re.sub( - r'(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$)', - r'\1\2 \3:\4\5', text) # Missing separator - text = re.sub(r'(^| )(\d{2})[;.,](\d{2})', r'\1\2:\3', - text) # Wrong separator - text = re.sub(r'(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)', - r'\1\2 \3:00\4', text) # normalize simple full hour time - text = re.sub(r'(^| )(\d{1}:\d{2})', r'\g<1>0\2', - text) # Add missing leading 0 - # Map 12 hour times to 24 hour times - text = \ - re.sub( - r'(\d{2})(:\d{2}) ?p\.?m\.?', - lambda x: str(int(x.groups()[0]) + 12 - if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups()[1], text) - text = re.sub(r'(^| )24:(\d{2})', r'\g<1>00:\2', - text) # Correct times that use 24 as hour - return text - - def normalize_text(self, text): - text = self.normalize_time(text) - text = re.sub("n't", ' not', text) - text = re.sub('(^| )zero(-| )star([s.,? ]|$)', r'\g<1>0 star\3', text) - text = re.sub('(^| )one(-| )star([s.,? ]|$)', r'\g<1>1 star\3', text) - text = re.sub('(^| )two(-| )star([s.,? ]|$)', r'\g<1>2 star\3', text) - text = re.sub('(^| )three(-| )star([s.,? ]|$)', r'\g<1>3 star\3', text) - text = re.sub('(^| )four(-| )star([s.,? ]|$)', r'\g<1>4 star\3', text) - text = re.sub('(^| )five(-| )star([s.,? ]|$)', r'\g<1>5 star\3', text) - text = re.sub('archaelogy', 'archaeology', text) # Systematic typo - text = re.sub('guesthouse', 'guest house', text) # Normalization - text = re.sub('(^| )b ?& ?b([.,? ]|$)', r'\1bed and breakfast\2', - text) # Normalization - text = re.sub('bed & breakfast', 'bed and breakfast', - text) # Normalization - return text - - # Loads the dialogue_acts.json and returns a list - # of slot-value pairs. - def load_acts(self, input_file): - with open(input_file) as f: - acts = json.load(f) - s_dict = {} - for d in acts: - for t in acts[d]: - if int(t) % 2 == 0: - continue - # Only process, if turn has annotation - if isinstance(acts[d][t]['dialog_act'], dict): - for a in acts[d][t]['dialog_act']: - aa = a.lower().split('-') - if aa[1] == 'inform' or aa[1] == 'recommend' \ - or aa[1] == 'select' or aa[1] == 'book': - for i in acts[d][t]['dialog_act'][a]: - s = i[0].lower() - v = i[1].lower().strip() - if s == 'none' or v == '?' or v == 'none': - continue - slot = aa[0] + '-' + s - if slot in self.ACTS_DICT: - slot = self.ACTS_DICT[slot] - key = d, str(int(t) // 2 + 1), slot - # In case of multiple mentioned values... - # ... Option 1: Keep first informed value - if key not in s_dict: - s_dict[key] = list([v]) - # ... Option 2: Keep last informed value - # s_dict[key] = list([v]) - return s_dict - - # This should only contain label normalizations. All other mappings should - # be defined in LABEL_MAPS. 
- def normalize_label(self, slot, value_label): - # Normalization of empty slots - if value_label == '' or value_label == 'not mentioned': - return 'none' - - # Normalization of time slots - if 'leaveAt' in slot or 'arriveBy' in slot or slot == 'restaurant-book_time': - return self.normalize_time(value_label) - - # Normalization - if 'type' in slot or 'name' in slot or 'destination' in slot or 'departure' in slot: - value_label = re.sub('guesthouse', 'guest house', value_label) - - # Map to boolean slots - if slot == 'hotel-parking' or slot == 'hotel-internet': - if value_label == 'yes' or value_label == 'free': - return 'true' - if value_label == 'no': - return 'false' - if slot == 'hotel-type': - if value_label == 'hotel': - return 'true' - if value_label == 'guest house': - return 'false' - - return value_label - - def tokenize(self, utt): - utt_lower = convert_to_unicode(utt).lower() - utt_lower = self.normalize_text(utt_lower) - utt_tok = [ - tok for tok in map(str.strip, re.split(r'(\W+)', utt_lower)) - if len(tok) > 0 - ] - return utt_tok - - def delex_utt(self, utt, values, unk_token='[UNK]'): - utt_norm = self.tokenize(utt) - for s, vals in values.items(): - for v in vals: - if v != 'none': - v_norm = self.tokenize(v) - v_len = len(v_norm) - for i in range(len(utt_norm) + 1 - v_len): - if utt_norm[i:i + v_len] == v_norm: - utt_norm[i:i + v_len] = [unk_token] * v_len - return utt_norm - - def get_token_pos(self, tok_list, value_label): - find_pos = [] - found = False - label_list = [ - item for item in map(str.strip, re.split(r'(\W+)', value_label)) - if len(item) > 0 - ] - len_label = len(label_list) - for i in range(len(tok_list) + 1 - len_label): - if tok_list[i:i + len_label] == label_list: - find_pos.append((i, i + len_label)) # start, exclusive_end - found = True - return found, find_pos - - def check_label_existence(self, value_label, usr_utt_tok): - in_usr, usr_pos = self.get_token_pos(usr_utt_tok, value_label) - # If no hit even though there should be one, check for value label variants - if not in_usr and value_label in self.LABEL_MAPS: - for value_label_variant in self.LABEL_MAPS[value_label]: - in_usr, usr_pos = self.get_token_pos(usr_utt_tok, - value_label_variant) - if in_usr: - break - return in_usr, usr_pos - - def check_slot_referral(self, value_label, slot, seen_slots): - referred_slot = 'none' - if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking': - return referred_slot - for s in seen_slots: - # Avoid matches for slots that share values with different meaning. - # hotel-internet and -parking are handled separately as Boolean slots. 
- if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking': - continue - if re.match('(hotel|restaurant)-book_people', - s) and slot == 'hotel-book_stay': - continue - if re.match('(hotel|restaurant)-book_people', - slot) and s == 'hotel-book_stay': - continue - if slot != s and (slot not in seen_slots - or seen_slots[slot] != value_label): - if seen_slots[s] == value_label: - referred_slot = s - break - elif value_label in self.LABEL_MAPS: - for value_label_variant in self.LABEL_MAPS[value_label]: - if seen_slots[s] == value_label_variant: - referred_slot = s - break - return referred_slot - - def is_in_list(self, tok, value): - found = False - tok_list = [ - item for item in map(str.strip, re.split(r'(\W+)', tok)) - if len(item) > 0 - ] - value_list = [ - item for item in map(str.strip, re.split(r'(\W+)', value)) - if len(item) > 0 - ] - tok_len = len(tok_list) - value_len = len(value_list) - for i in range(tok_len + 1 - value_len): - if tok_list[i:i + value_len] == value_list: - found = True - break - return found - - # Fuzzy matching to label informed slot values - def check_slot_inform(self, value_label, inform_label): - result = False - informed_value = 'none' - vl = ' '.join(self.tokenize(value_label)) - for il in inform_label: - if vl == il: - result = True - elif self.is_in_list(il, vl): - result = True - elif self.is_in_list(vl, il): - result = True - elif il in self.LABEL_MAPS: - for il_variant in self.LABEL_MAPS[il]: - if vl == il_variant: - result = True - break - elif self.is_in_list(il_variant, vl): - result = True - break - elif self.is_in_list(vl, il_variant): - result = True - break - elif vl in self.LABEL_MAPS: - for value_label_variant in self.LABEL_MAPS[vl]: - if value_label_variant == il: - result = True - break - elif self.is_in_list(il, value_label_variant): - result = True - break - elif self.is_in_list(value_label_variant, il): - result = True - break - if result: - informed_value = il - break - return result, informed_value - - def get_turn_label(self, value_label, inform_label, sys_utt_tok, - usr_utt_tok, slot, seen_slots, slot_last_occurrence): - usr_utt_tok_label = [0 for _ in usr_utt_tok] - informed_value = 'none' - referred_slot = 'none' - if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false': - class_type = value_label - else: - in_usr, usr_pos = self.check_label_existence( - value_label, usr_utt_tok) - is_informed, informed_value = self.check_slot_inform( - value_label, inform_label) - if in_usr: - class_type = 'copy_value' - if slot_last_occurrence: - (s, e) = usr_pos[-1] - for i in range(s, e): - usr_utt_tok_label[i] = 1 - else: - for (s, e) in usr_pos: - for i in range(s, e): - usr_utt_tok_label[i] = 1 - elif is_informed: - class_type = 'inform' - else: - referred_slot = self.check_slot_referral( - value_label, slot, seen_slots) - if referred_slot != 'none': - class_type = 'refer' - else: - class_type = 'unpointable' - return informed_value, referred_slot, usr_utt_tok_label, class_type - - def _create_example(self, - utterances, - sys_inform_dict, - set_type, - slot_list, - label_maps={}, - append_history=False, - use_history_labels=False, - swap_utterances=False, - label_value_repetitions=False, - delexicalize_sys_utts=False, - unk_token='[UNK]', - analyze=False, - dialog_id='example.json'): - - # Collects all slot changes throughout the dialog - cumulative_labels = {slot: 'none' for slot in slot_list} - - # First system utterance is empty, since multiwoz starts with user input - 
utt_tok_list = [[]] - mod_slots_list = [] - - # Collect all utterances and their metadata - usr_sys_switch = True - turn_itr = 0 - - for utt in utterances: - # Assert that system and user utterances alternate - is_sys_utt = utt['metadata'] != {} - if usr_sys_switch == is_sys_utt: - print( - 'WARN: Wrong order of system and user utterances. Skipping rest of the dialog %s' - % (dialog_id)) - break - usr_sys_switch = is_sys_utt - - if is_sys_utt: - turn_itr += 1 - - # Delexicalize sys utterance - if delexicalize_sys_utts and is_sys_utt: - inform_dict = {slot: 'none' for slot in slot_list} - for slot in slot_list: - if (str(dialog_id), str(turn_itr), - slot) in sys_inform_dict: - inform_dict[slot] = sys_inform_dict[(str(dialog_id), - str(turn_itr), - slot)] - utt_tok_list.append( - self.delex_utt(utt['text'], inform_dict, - unk_token)) # normalize utterances - else: - utt_tok_list.append(self.tokenize( - utt['text'])) # normalize utterances - - modified_slots = {} - - # If sys utt, extract metadata (identify and collect modified slots) - if is_sys_utt: - for d in utt['metadata']: - booked = utt['metadata'][d]['book']['booked'] - booked_slots = {} - # Check the booked section - if booked != []: - for s in booked[0]: - booked_slots[s] = self.normalize_label( - '%s-%s' % (d, s), - booked[0][s]) # normalize labels - # Check the semi and the inform slots - for category in ['book', 'semi']: - for s in utt['metadata'][d][category]: - cs = '%s-book_%s' % ( - d, s) if category == 'book' else '%s-%s' % (d, - s) - value_label = self.normalize_label( - cs, utt['metadata'][d][category] - [s]) # normalize labels - # Prefer the slot value as stored in the booked section - if s in booked_slots: - value_label = booked_slots[s] - # Remember modified slots and entire dialog state - if cs in slot_list and cumulative_labels[ - cs] != value_label: - modified_slots[cs] = value_label - cumulative_labels[cs] = value_label - - mod_slots_list.append(modified_slots.copy()) - - # Form proper (usr, sys) turns - turn_itr = 0 - diag_seen_slots_dict = {} - diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} - diag_state = {slot: 'none' for slot in slot_list} - sys_utt_tok = [] - usr_utt_tok = [] - hst_utt_tok = [] - hst_utt_tok_label_dict = {slot: [] for slot in slot_list} - new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() - new_diag_state = diag_state.copy() - - for i in range(0, len(utt_tok_list) - 1, 2): - sys_utt_tok_label_dict = {} - usr_utt_tok_label_dict = {} - value_dict = {} - inform_dict = {} - inform_slot_dict = {} - referral_dict = {} - class_type_dict = {} - - # Collect turn data - if append_history: - if swap_utterances: - hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok - else: - hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok - sys_utt_tok = utt_tok_list[i] - usr_utt_tok = utt_tok_list[i + 1] - turn_slots = mod_slots_list[ - i + 1] if len(mod_slots_list) > 1 else {} - - guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr)) - - if analyze: - print('%15s %2s %s ||| %s' % - (dialog_id, turn_itr, ' '.join(sys_utt_tok), - ' '.join(usr_utt_tok))) - print('%15s %2s [' % (dialog_id, turn_itr), end='') - - new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() - new_diag_state = diag_state.copy() - for slot in slot_list: - value_label = 'none' - if slot in turn_slots: - value_label = turn_slots[slot] - # We keep the original labels so as to not - # overlook unpointable values, as well as to not - # modify any of the original labels for test sets, - # since this would make 
comparison difficult. - value_dict[slot] = value_label - elif label_value_repetitions and slot in diag_seen_slots_dict: - value_label = diag_seen_slots_value_dict[slot] - - # Get dialog act annotations - inform_label = list(['none']) - inform_slot_dict[slot] = 0 - if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: - inform_label = list([ - self.normalize_label(slot, i) - for i in sys_inform_dict[(str(dialog_id), - str(turn_itr), slot)] - ]) - inform_slot_dict[slot] = 1 - elif (str(dialog_id), str(turn_itr), - 'booking-' + slot.split('-')[1]) in sys_inform_dict: - inform_label = list([ - self.normalize_label(slot, i) - for i in sys_inform_dict[(str(dialog_id), - str(turn_itr), 'booking-' - + slot.split('-')[1])] - ]) - inform_slot_dict[slot] = 1 - - (informed_value, referred_slot, usr_utt_tok_label, - class_type) = self.get_turn_label( - value_label, - inform_label, - sys_utt_tok, - usr_utt_tok, - slot, - diag_seen_slots_value_dict, - slot_last_occurrence=True) - - inform_dict[slot] = informed_value - - # Generally don't use span prediction on sys utterance (but inform prediction instead). - sys_utt_tok_label = [0 for _ in sys_utt_tok] - - # Determine what to do with value repetitions. - # If value is unique in seen slots, then tag it, otherwise not, - # since correct slot assignment can not be guaranteed anymore. - if label_value_repetitions and slot in diag_seen_slots_dict: - if class_type == 'copy_value' and list( - diag_seen_slots_value_dict.values()).count( - value_label) > 1: - class_type = 'none' - usr_utt_tok_label = [0 for _ in usr_utt_tok_label] - - sys_utt_tok_label_dict[slot] = sys_utt_tok_label - usr_utt_tok_label_dict[slot] = usr_utt_tok_label - - if append_history: - if use_history_labels: - if swap_utterances: - new_hst_utt_tok_label_dict[ - slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[ - slot] - else: - new_hst_utt_tok_label_dict[ - slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[ - slot] - else: - new_hst_utt_tok_label_dict[slot] = [ - 0 for _ in sys_utt_tok_label + usr_utt_tok_label - + new_hst_utt_tok_label_dict[slot] - ] - - # For now, we map all occurences of unpointable slot values - # to none. However, since the labels will still suggest - # a presence of unpointable slot values, the task of the - # DST is still to find those values. It is just not - # possible to do that via span prediction on the current input. - if class_type == 'unpointable': - class_type_dict[slot] = 'none' - referral_dict[slot] = 'none' - if analyze: - if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[ - slot]: - print('(%s): %s, ' % (slot, value_label), end='') - elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] \ - and class_type != 'copy_value' and class_type != 'inform': - # If slot has seen before and its class type did not change, label this slot a not present, - # assuming that the slot has not actually been mentioned in this turn. - # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, - # this must mean there is evidence in the original labels, therefore consider - # them as mentioned again. - class_type_dict[slot] = 'none' - referral_dict[slot] = 'none' - else: - class_type_dict[slot] = class_type - referral_dict[slot] = referred_slot - # Remember that this slot was mentioned during this dialog already. 
- if class_type != 'none': - diag_seen_slots_dict[slot] = class_type - diag_seen_slots_value_dict[slot] = value_label - new_diag_state[slot] = class_type - # Unpointable is not a valid class, therefore replace with - # some valid class for now... - if class_type == 'unpointable': - new_diag_state[slot] = 'copy_value' - - if analyze: - print(']') - - if swap_utterances: - txt_a = usr_utt_tok - txt_b = sys_utt_tok - txt_a_lbl = usr_utt_tok_label_dict - txt_b_lbl = sys_utt_tok_label_dict - else: - txt_a = sys_utt_tok - txt_b = usr_utt_tok - txt_a_lbl = sys_utt_tok_label_dict - txt_b_lbl = usr_utt_tok_label_dict - - example = DSTExample( - guid=guid, - text_a=txt_a, - text_b=txt_b, - history=hst_utt_tok, - text_a_label=txt_a_lbl, - text_b_label=txt_b_lbl, - history_label=hst_utt_tok_label_dict, - values=diag_seen_slots_value_dict.copy(), - inform_label=inform_dict, - inform_slot_label=inform_slot_dict, - refer_label=referral_dict, - diag_state=diag_state, - class_label=class_type_dict) - # Update some variables. - hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy() - diag_state = new_diag_state.copy() - - turn_itr += 1 - return example - - def create_example(self, - inputs, - history_states, - set_type, - slot_list, - label_maps={}, - append_history=False, - use_history_labels=False, - swap_utterances=False, - label_value_repetitions=False, - delexicalize_sys_utts=False, - unk_token='[UNK]', - analyze=False, - dialog_id='0'): - utterances = self._convert_inputs_to_utterances(inputs, history_states) - sys_inform_dict = self._load_acts(inputs) - self.LABEL_MAPS = label_maps - example = self._create_example(utterances, sys_inform_dict, set_type, - slot_list, label_maps, append_history, - use_history_labels, swap_utterances, - label_value_repetitions, - delexicalize_sys_utts, unk_token, - analyze) - - return example - - def create_examples(self, - input_file, - acts_file, - set_type, - slot_list, - label_maps={}, - append_history=False, - use_history_labels=False, - swap_utterances=False, - label_value_repetitions=False, - delexicalize_sys_utts=False, - unk_token='[UNK]', - analyze=False): - """Read a DST json file into a list of DSTExample.""" - - sys_inform_dict = self.load_acts(acts_file) - - with open(input_file, 'r', encoding='utf-8') as reader: - input_data = json.load(reader) - - self.LABEL_MAPS = label_maps - - examples = [] - for dialog_id in tqdm(input_data): - entry = input_data[dialog_id] - utterances = entry['log'] - - example = self._create_example( - utterances, sys_inform_dict, set_type, slot_list, label_maps, - append_history, use_history_labels, swap_utterances, - label_value_repetitions, delexicalize_sys_utts, unk_token, - analyze) - examples.append(example) - - return examples - - -class DSTExample(object): - """ - A single training/test example for the DST dataset. 
- """ - - def __init__(self, - guid, - text_a, - text_b, - history, - text_a_label=None, - text_b_label=None, - history_label=None, - values=None, - inform_label=None, - inform_slot_label=None, - refer_label=None, - diag_state=None, - class_label=None): - self.guid = guid - self.text_a = text_a - self.text_b = text_b - self.history = history - self.text_a_label = text_a_label - self.text_b_label = text_b_label - self.history_label = history_label - self.values = values - self.inform_label = inform_label - self.inform_slot_label = inform_slot_label - self.refer_label = refer_label - self.diag_state = diag_state - self.class_label = class_label - - def __str__(self): - return self.__repr__() - - def __repr__(self): - s = '' - s += 'guid: %s' % (self.guid) - s += ', text_a: %s' % (self.text_a) - s += ', text_b: %s' % (self.text_b) - s += ', history: %s' % (self.history) - if self.text_a_label: - s += ', text_a_label: %d' % (self.text_a_label) - if self.text_b_label: - s += ', text_b_label: %d' % (self.text_b_label) - if self.history_label: - s += ', history_label: %d' % (self.history_label) - if self.values: - s += ', values: %d' % (self.values) - if self.inform_label: - s += ', inform_label: %d' % (self.inform_label) - if self.inform_slot_label: - s += ', inform_slot_label: %d' % (self.inform_slot_label) - if self.refer_label: - s += ', refer_label: %d' % (self.refer_label) - if self.diag_state: - s += ', diag_state: %d' % (self.diag_state) - if self.class_label: - s += ', class_label: %d' % (self.class_label) - return s - - -class InputFeatures(object): - """A single set of features of data.""" - - def __init__(self, - input_ids, - input_ids_unmasked, - input_mask, - segment_ids, - start_pos=None, - end_pos=None, - values=None, - inform=None, - inform_slot=None, - refer_id=None, - diag_state=None, - class_label_id=None, - guid='NONE'): - self.guid = guid - self.input_ids = input_ids - self.input_ids_unmasked = input_ids_unmasked - self.input_mask = input_mask - self.segment_ids = segment_ids - self.start_pos = start_pos - self.end_pos = end_pos - self.values = values - self.inform = inform - self.inform_slot = inform_slot - self.refer_id = refer_id - self.diag_state = diag_state - self.class_label_id = class_label_id - - -def convert_examples_to_features(examples, - slot_list, - class_types, - model_type, - tokenizer, - max_seq_length, - slot_value_dropout=0.0): - """Loads a data file into a list of `InputBatch`s.""" - - if model_type == 'bert': - model_specs = { - 'MODEL_TYPE': 'bert', - 'CLS_TOKEN': '[CLS]', - 'UNK_TOKEN': '[UNK]', - 'SEP_TOKEN': '[SEP]', - 'TOKEN_CORRECTION': 4 - } - else: - logger.error('Unknown model type (%s). Aborting.' 
% (model_type)) - exit(1) - - def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, - model_specs, slot_value_dropout): - joint_text_label = [0 for _ in text_label_dict[slot] - ] # joint all slots' label - for slot_text_label in text_label_dict.values(): - for idx, label in enumerate(slot_text_label): - if label == 1: - joint_text_label[idx] = 1 - - text_label = text_label_dict[slot] - tokens = [] - tokens_unmasked = [] - token_labels = [] - for token, token_label, joint_label in zip(text, text_label, - joint_text_label): - token = convert_to_unicode(token) - sub_tokens = tokenizer.tokenize(token) # Most time intensive step - tokens_unmasked.extend(sub_tokens) - if slot_value_dropout == 0.0 or joint_label == 0: - tokens.extend(sub_tokens) - else: - rn_list = np.random.random_sample((len(sub_tokens), )) - for rn, sub_token in zip(rn_list, sub_tokens): - if rn > slot_value_dropout: - tokens.append(sub_token) - else: - tokens.append(model_specs['UNK_TOKEN']) - token_labels.extend([token_label for _ in sub_tokens]) - assert len(tokens) == len(token_labels) - assert len(tokens_unmasked) == len(token_labels) - return tokens, tokens_unmasked, token_labels - - def _truncate_seq_pair(tokens_a, tokens_b, history, max_length): - """Truncates a sequence pair in place to the maximum length. - Copied from bert/run_classifier.py - """ - # This is a simple heuristic which will always truncate the longer sequence - # one token at a time. This makes more sense than truncating an equal percent - # of tokens from each, since if one sequence is very short then each token - # that's truncated likely contains more information than a longer sequence. - while True: - total_length = len(tokens_a) + len(tokens_b) + len(history) - if total_length <= max_length: - break - if len(history) > 0: - history.pop() - elif len(tokens_a) > len(tokens_b): - tokens_a.pop() - else: - tokens_b.pop() - - def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, - model_specs, guid): - # Modifies `tokens_a` and `tokens_b` in place so that the total - # length is less than the specified length. - # Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT) - if len(tokens_a) + len(tokens_b) + len( - history) > max_seq_length - model_specs['TOKEN_CORRECTION']: - logger.info('Truncate Example %s. Total len=%d.' % - (guid, len(tokens_a) + len(tokens_b) + len(history))) - input_text_too_long = True - else: - input_text_too_long = False - _truncate_seq_pair(tokens_a, tokens_b, history, - max_seq_length - model_specs['TOKEN_CORRECTION']) - return input_text_too_long - - def _get_token_label_ids(token_labels_a, token_labels_b, - token_labels_history, max_seq_length, - model_specs): - token_label_ids = [] - token_label_ids.append(0) # [CLS] - for token_label in token_labels_a: - token_label_ids.append(token_label) - token_label_ids.append(0) # [SEP] - for token_label in token_labels_b: - token_label_ids.append(token_label) - token_label_ids.append(0) # [SEP] - for token_label in token_labels_history: - token_label_ids.append(token_label) - token_label_ids.append(0) # [SEP] - while len(token_label_ids) < max_seq_length: - token_label_ids.append(0) # padding - assert len(token_label_ids) == max_seq_length - return token_label_ids - - def _get_start_end_pos(class_type, token_label_ids, max_seq_length): - if class_type == 'copy_value' and 1 not in token_label_ids: - # logger.warn("copy_value label, but token_label not detected. 
Setting label to 'none'.") - class_type = 'none' - start_pos = 0 - end_pos = 0 - if 1 in token_label_ids: - start_pos = token_label_ids.index(1) - # Parsing is supposed to find only first location of wanted value - if 0 not in token_label_ids[start_pos:]: - end_pos = len(token_label_ids[start_pos:]) + start_pos - 1 - else: - end_pos = token_label_ids[start_pos:].index(0) + start_pos - 1 - for i in range(max_seq_length): - if i >= start_pos and i <= end_pos: - assert token_label_ids[i] == 1 - return class_type, start_pos, end_pos - - def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, - tokenizer, model_specs): - # The convention in BERT is: - # (a) For sequence pairs: - # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] - # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 - # (b) For single sequences: - # tokens: [CLS] the dog is hairy . [SEP] - # type_ids: 0 0 0 0 0 0 0 - # - # Where "type_ids" are used to indicate whether this is the first - # sequence or the second sequence. The embedding vectors for `type=0` and - # `type=1` were learned during pre-training and are added to the wordpiece - # embedding vector (and position vector). This is not *strictly* necessary - # since the [SEP] token unambiguously separates the sequences, but it makes - # it easier for the model to learn the concept of sequences. - # - # For classification tasks, the first vector (corresponding to [CLS]) is - # used as the "sentence vector". Note that this only makes sense because - # the entire model is fine-tuned. - tokens = [] - segment_ids = [] - tokens.append(model_specs['CLS_TOKEN']) - segment_ids.append(0) - for token in tokens_a: - tokens.append(token) - segment_ids.append(0) - tokens.append(model_specs['SEP_TOKEN']) - segment_ids.append(0) - for token in tokens_b: - tokens.append(token) - segment_ids.append(1) - tokens.append(model_specs['SEP_TOKEN']) - segment_ids.append(1) - for token in history: - tokens.append(token) - segment_ids.append(1) - tokens.append(model_specs['SEP_TOKEN']) - segment_ids.append(1) - input_ids = tokenizer.convert_tokens_to_ids(tokens) - # The mask has 1 for real tokens and 0 for padding tokens. Only real - # tokens are attended to. - input_mask = [1] * len(input_ids) - # Zero-pad up to the sequence length. 
- while len(input_ids) < max_seq_length: - input_ids.append(0) - input_mask.append(0) - segment_ids.append(0) - assert len(input_ids) == max_seq_length - assert len(input_mask) == max_seq_length - assert len(segment_ids) == max_seq_length - return tokens, input_ids, input_mask, segment_ids - - total_cnt = 0 - too_long_cnt = 0 - - refer_list = ['none'] + slot_list - - features = [] - # Convert single example - for (example_index, example) in enumerate(examples): - if example_index % 1000 == 0: - logger.info('Writing example %d of %d' % - (example_index, len(examples))) - - total_cnt += 1 - - value_dict = {} - inform_dict = {} - inform_slot_dict = {} - refer_id_dict = {} - diag_state_dict = {} - class_label_id_dict = {} - start_pos_dict = {} - end_pos_dict = {} - for slot in slot_list: - tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label( - example.text_a, example.text_a_label, slot, tokenizer, - model_specs, slot_value_dropout) - tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label( - example.text_b, example.text_b_label, slot, tokenizer, - model_specs, slot_value_dropout) - tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label( - example.history, example.history_label, slot, tokenizer, - model_specs, slot_value_dropout) - - input_text_too_long = _truncate_length_and_warn( - tokens_a, tokens_b, tokens_history, max_seq_length, - model_specs, example.guid) - - if input_text_too_long: - if example_index < 10: - if len(token_labels_a) > len(tokens_a): - logger.info(' tokens_a truncated labels: %s' - % str(token_labels_a[len(tokens_a):])) - if len(token_labels_b) > len(tokens_b): - logger.info(' tokens_b truncated labels: %s' - % str(token_labels_b[len(tokens_b):])) - if len(token_labels_history) > len(tokens_history): - logger.info( - ' tokens_history truncated labels: %s' - % str(token_labels_history[len(tokens_history):])) - - token_labels_a = token_labels_a[:len(tokens_a)] - token_labels_b = token_labels_b[:len(tokens_b)] - token_labels_history = token_labels_history[:len(tokens_history - )] - tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)] - tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)] - tokens_history_unmasked = tokens_history_unmasked[:len( - tokens_history)] - - assert len(token_labels_a) == len(tokens_a) - assert len(token_labels_b) == len(tokens_b) - assert len(token_labels_history) == len(tokens_history) - assert len(token_labels_a) == len(tokens_a_unmasked) - assert len(token_labels_b) == len(tokens_b_unmasked) - assert len(token_labels_history) == len(tokens_history_unmasked) - token_label_ids = _get_token_label_ids(token_labels_a, - token_labels_b, - token_labels_history, - max_seq_length, model_specs) - - value_dict[slot] = example.values[slot] - inform_dict[slot] = example.inform_label[slot] - - class_label_mod, start_pos_dict[slot], end_pos_dict[ - slot] = _get_start_end_pos(example.class_label[slot], - token_label_ids, max_seq_length) - if class_label_mod != example.class_label[slot]: - example.class_label[slot] = class_label_mod - inform_slot_dict[slot] = example.inform_slot_label[slot] - refer_id_dict[slot] = refer_list.index(example.refer_label[slot]) - diag_state_dict[slot] = class_types.index(example.diag_state[slot]) - class_label_id_dict[slot] = class_types.index( - example.class_label[slot]) - - if input_text_too_long: - too_long_cnt += 1 - - tokens, input_ids, input_mask, segment_ids = _get_transformer_input( - tokens_a, tokens_b, tokens_history, max_seq_length, tokenizer, - 
model_specs) - if slot_value_dropout > 0.0: - _, input_ids_unmasked, _, _ = _get_transformer_input( - tokens_a_unmasked, tokens_b_unmasked, tokens_history_unmasked, - max_seq_length, tokenizer, model_specs) - else: - input_ids_unmasked = input_ids - - assert (len(input_ids) == len(input_ids_unmasked)) - - if example_index < 10: - logger.info('*** Example ***') - logger.info('guid: %s' % (example.guid)) - logger.info('tokens: %s' % ' '.join(tokens)) - logger.info('input_ids: %s' % ' '.join([str(x) - for x in input_ids])) - logger.info('input_mask: %s' - % ' '.join([str(x) for x in input_mask])) - logger.info('segment_ids: %s' - % ' '.join([str(x) for x in segment_ids])) - logger.info('start_pos: %s' % str(start_pos_dict)) - logger.info('end_pos: %s' % str(end_pos_dict)) - logger.info('values: %s' % str(value_dict)) - logger.info('inform: %s' % str(inform_dict)) - logger.info('inform_slot: %s' % str(inform_slot_dict)) - logger.info('refer_id: %s' % str(refer_id_dict)) - logger.info('diag_state: %s' % str(diag_state_dict)) - logger.info('class_label_id: %s' % str(class_label_id_dict)) - - features.append( - InputFeatures( - guid=example.guid, - input_ids=input_ids, - input_ids_unmasked=input_ids_unmasked, - input_mask=input_mask, - segment_ids=segment_ids, - start_pos=start_pos_dict, - end_pos=end_pos_dict, - values=value_dict, - inform=inform_dict, - inform_slot=inform_slot_dict, - refer_id=refer_id_dict, - diag_state=diag_state_dict, - class_label_id=class_label_id_dict)) - - logger.info('========== %d out of %d examples have text too long' % - (too_long_cnt, total_cnt)) - - return features - - -# From bert.tokenization (TF code) -def convert_to_unicode(text): - """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode('utf-8', 'ignore') - else: - raise ValueError('Unsupported string type: %s' % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text.decode('utf-8', 'ignore') - elif isinstance(text, unicode): - return text - else: - raise ValueError('Unsupported string type: %s' % (type(text))) - else: - raise ValueError('Not running on Python2 or Python 3?') - - -if __name__ == '__main__': - processor = multiwoz22Processor() - set_type = 'test' - slot_list = [ - 'taxi-leaveAt', 'taxi-destination', 'taxi-departure', 'taxi-arriveBy', - 'restaurant-book_people', 'restaurant-book_day', - 'restaurant-book_time', 'restaurant-food', 'restaurant-pricerange', - 'restaurant-name', 'restaurant-area', 'hotel-book_people', - 'hotel-book_day', 'hotel-book_stay', 'hotel-name', 'hotel-area', - 'hotel-parking', 'hotel-pricerange', 'hotel-stars', 'hotel-internet', - 'hotel-type', 'attraction-type', 'attraction-name', 'attraction-area', - 'train-book_people', 'train-leaveAt', 'train-destination', 'train-day', - 'train-arriveBy', 'train-departure' - ] - append_history = True - use_history_labels = True - swap_utterances = True - label_value_repetitions = True - delexicalize_sys_utts = True, - unk_token = '[UNK]' - analyze = False - example = processor.create_example(utter1, history_states1, set_type, - slot_list, {}, append_history, - use_history_labels, swap_utterances, - label_value_repetitions, - delexicalize_sys_utts, unk_token, - analyze) - print(f'utterances is {example}') diff --git a/modelscope/preprocessors/video.py b/modelscope/preprocessors/video.py index f693cd9e..794033b5 100644 --- a/modelscope/preprocessors/video.py +++ 
b/modelscope/preprocessors/video.py @@ -1,5 +1,10 @@ import math +import os import random +import uuid +from os.path import exists +from tempfile import TemporaryDirectory +from urllib.parse import urlparse import numpy as np import torch @@ -9,6 +14,7 @@ import torchvision.transforms._transforms_video as transforms from decord import VideoReader from torchvision.transforms import Compose +from modelscope.hub.file_download import http_get_file from modelscope.metainfo import Preprocessors from modelscope.utils.constant import Fields, ModeKeys from modelscope.utils.type_assert import type_assert @@ -30,7 +36,22 @@ def ReadVideoData(cfg, Returns: data (Tensor): the normalized video clips for model inputs """ - data = _decode_video(cfg, video_path, num_temporal_views_override) + url_parsed = urlparse(video_path) + if url_parsed.scheme in ('file', '') and exists( + url_parsed.path): # Possibly a local file + data = _decode_video(cfg, video_path, num_temporal_views_override) + else: + with TemporaryDirectory() as temporary_cache_dir: + random_str = uuid.uuid4().hex + http_get_file( + url=video_path, + local_dir=temporary_cache_dir, + file_name=random_str, + cookies=None) + temp_file_path = os.path.join(temporary_cache_dir, random_str) + data = _decode_video(cfg, temp_file_path, + num_temporal_views_override) + if num_spatial_crops_override is not None: num_spatial_crops = num_spatial_crops_override transform = kinetics400_tranform(cfg, num_spatial_crops_override) diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py index a632642a..d914489c 100644 --- a/modelscope/trainers/__init__.py +++ b/modelscope/trainers/__init__.py @@ -9,10 +9,10 @@ if TYPE_CHECKING: from .builder import build_trainer from .cv import (ImageInstanceSegmentationTrainer, ImagePortraitEnhancementTrainer, - MovieSceneSegmentationTrainer) + MovieSceneSegmentationTrainer, ImageInpaintingTrainer) from .multi_modal import CLIPTrainer - from .nlp import SequenceClassificationTrainer, PassageRankingTrainer - from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer + from .nlp import SequenceClassificationTrainer, TextRankingTrainer + from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments from .trainer import EpochBasedTrainer else: @@ -22,11 +22,13 @@ else: 'builder': ['build_trainer'], 'cv': [ 'ImageInstanceSegmentationTrainer', - 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer' + 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer', + 'ImageInpaintingTrainer' ], 'multi_modal': ['CLIPTrainer'], - 'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'], - 'nlp_trainer': ['NlpEpochBasedTrainer', 'VecoTrainer'], + 'nlp': ['SequenceClassificationTrainer', 'TextRankingTrainer'], + 'nlp_trainer': + ['NlpEpochBasedTrainer', 'VecoTrainer', 'NlpTrainerArguments'], 'trainer': ['EpochBasedTrainer'] } diff --git a/modelscope/trainers/audio/kws_farfield_trainer.py b/modelscope/trainers/audio/kws_farfield_trainer.py new file mode 100644 index 00000000..a720ced5 --- /dev/null +++ b/modelscope/trainers/audio/kws_farfield_trainer.py @@ -0,0 +1,279 @@ +import datetime +import math +import os +from typing import Callable, Dict, Optional + +import numpy as np +import torch +from torch import nn as nn +from torch import optim as optim + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models import Model, TorchModel +from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, 
KWSDataset +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.audio.audio_utils import update_conf +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.data_utils import to_device +from modelscope.utils.device import create_device +from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, + init_dist, is_master) + +logger = get_logger() + +BASETRAIN_CONF_EASY = 'basetrain_easy' +BASETRAIN_CONF_NORMAL = 'basetrain_normal' +BASETRAIN_CONF_HARD = 'basetrain_hard' +FINETUNE_CONF_EASY = 'finetune_easy' +FINETUNE_CONF_NORMAL = 'finetune_normal' +FINETUNE_CONF_HARD = 'finetune_hard' + +EASY_RATIO = 0.1 +NORMAL_RATIO = 0.6 +HARD_RATIO = 0.3 +BASETRAIN_RATIO = 0.5 + + +@TRAINERS.register_module(module_name=Trainers.speech_dfsmn_kws_char_farfield) +class KWSFarfieldTrainer(BaseTrainer): + DEFAULT_WORK_DIR = './work_dir' + conf_keys = (BASETRAIN_CONF_EASY, FINETUNE_CONF_EASY, + BASETRAIN_CONF_NORMAL, FINETUNE_CONF_NORMAL, + BASETRAIN_CONF_HARD, FINETUNE_CONF_HARD) + + def __init__(self, + model: str, + work_dir: str, + cfg_file: Optional[str] = None, + arg_parse_fn: Optional[Callable] = None, + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + custom_conf: Optional[dict] = None, + **kwargs): + + if isinstance(model, str): + if os.path.exists(model): + self.model_dir = model if os.path.isdir( + model) else os.path.dirname(model) + else: + self.model_dir = snapshot_download( + model, revision=model_revision) + if cfg_file is None: + cfg_file = os.path.join(self.model_dir, + ModelFile.CONFIGURATION) + else: + assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' 
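# Reviewer annotation (not part of this patch): `model` can be either a local
# checkpoint/config directory or a hub model id; in the latter case the snapshot is
# downloaded first and configuration.json is read from the downloaded copy. A minimal
# sketch of driving this trainer through the registry; the hub id, work_dir and the
# empty per-stage overrides are hypothetical placeholders, not values from this diff:
#
#     from modelscope.metainfo import Trainers
#     from modelscope.trainers import build_trainer
#
#     custom_conf = {k: {} for k in KWSFarfieldTrainer.conf_keys}  # fill in real overrides
#     trainer = build_trainer(
#         Trainers.speech_dfsmn_kws_char_farfield,
#         default_args=dict(
#             model='damo/speech_dfsmn_kws_char_farfield',  # hypothetical hub id
#             work_dir='./kws_work_dir',
#             custom_conf=custom_conf))
#     trainer.train()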
+ self.model_dir = os.path.dirname(cfg_file) + + super().__init__(cfg_file, arg_parse_fn) + + self.model = self.build_model() + self.work_dir = work_dir + # the number of model output dimension + # should update config outside the trainer, if user need more wake word + self._num_classes = self.cfg.model.num_syn + + if kwargs.get('launcher', None) is not None: + init_dist(kwargs['launcher']) + + _, world_size = get_dist_info() + self._dist = world_size > 1 + + device_name = kwargs.get('device', 'gpu') + if self._dist: + local_rank = get_local_rank() + device_name = f'cuda:{local_rank}' + + self.device = create_device(device_name) + # model placement + if self.device.type == 'cuda': + self.model.to(self.device) + + if 'max_epochs' not in kwargs: + assert hasattr( + self.cfg.train, 'max_epochs' + ), 'max_epochs is missing from the configuration file' + self._max_epochs = self.cfg.train.max_epochs + else: + self._max_epochs = kwargs['max_epochs'] + self._train_iters = kwargs.get('train_iters_per_epoch', None) + self._val_iters = kwargs.get('val_iters_per_epoch', None) + if self._train_iters is None: + self._train_iters = self.cfg.train.train_iters_per_epoch + if self._val_iters is None: + self._val_iters = self.cfg.evaluation.val_iters_per_epoch + dataloader_config = self.cfg.train.dataloader + self._threads = kwargs.get('workers', None) + if self._threads is None: + self._threads = dataloader_config.workers_per_gpu + self._single_rate = BASETRAIN_RATIO + if 'single_rate' in kwargs: + self._single_rate = kwargs['single_rate'] + self._batch_size = dataloader_config.batch_size_per_gpu + if 'model_bin' in kwargs: + model_bin_file = os.path.join(self.model_dir, kwargs['model_bin']) + checkpoint = torch.load(model_bin_file) + self.model.load_state_dict(checkpoint) + # build corresponding optimizer and loss function + lr = self.cfg.train.optimizer.lr + self.optimizer = optim.Adam(self.model.parameters(), lr) + self.loss_fn = nn.CrossEntropyLoss() + self.data_val = None + self.json_log_path = os.path.join(self.work_dir, + '{}.log.json'.format(self.timestamp)) + self.conf_files = [] + for conf_key in self.conf_keys: + template_file = os.path.join(self.model_dir, conf_key) + conf_file = os.path.join(self.model_dir, f'{conf_key}.conf') + update_conf(template_file, conf_file, custom_conf[conf_key]) + self.conf_files.append(conf_file) + self._current_epoch = 0 + self.stages = (math.floor(self._max_epochs * EASY_RATIO), + math.floor(self._max_epochs * NORMAL_RATIO), + math.floor(self._max_epochs * HARD_RATIO)) + + def build_model(self) -> nn.Module: + """ Instantiate a pytorch model and return. + + By default, we will create a model using config from configuration file. You can + override this method in a subclass. 
+ + """ + model = Model.from_pretrained( + self.model_dir, cfg_dict=self.cfg, training=True) + if isinstance(model, TorchModel) and hasattr(model, 'model'): + return model.model + elif isinstance(model, nn.Module): + return model + + def train(self, *args, **kwargs): + if not self.data_val: + self.gen_val() + logger.info('Start training...') + totaltime = datetime.datetime.now() + + for stage, num_epoch in enumerate(self.stages): + self.run_stage(stage, num_epoch) + + # total time spent + totaltime = datetime.datetime.now() - totaltime + logger.info('Total time spent: {:.2f} hours\n'.format( + totaltime.total_seconds() / 3600.0)) + + def run_stage(self, stage, num_epoch): + """ + Run training stages with correspond data + + Args: + stage: id of stage + num_epoch: the number of epoch to run in this stage + """ + if num_epoch <= 0: + logger.warning(f'Invalid epoch number, stage {stage} exit!') + return + logger.info(f'Starting stage {stage}...') + dataset, dataloader = self.create_dataloader( + self.conf_files[stage * 2], self.conf_files[stage * 2 + 1]) + it = iter(dataloader) + for _ in range(num_epoch): + self._current_epoch += 1 + epochtime = datetime.datetime.now() + logger.info('Start epoch %d...', self._current_epoch) + loss_train_epoch = 0.0 + validbatchs = 0 + for bi in range(self._train_iters): + # prepare data + feat, label = next(it) + label = torch.reshape(label, (-1, )) + feat = to_device(feat, self.device) + label = to_device(label, self.device) + # apply model + self.optimizer.zero_grad() + predict = self.model(feat) + # calculate loss + loss = self.loss_fn( + torch.reshape(predict, (-1, self._num_classes)), label) + if not np.isnan(loss.item()): + loss.backward() + self.optimizer.step() + loss_train_epoch += loss.item() + validbatchs += 1 + train_result = 'Epoch: {:04d}/{:04d}, batch: {:04d}/{:04d}, loss: {:.4f}'.format( + self._current_epoch, self._max_epochs, bi + 1, + self._train_iters, loss.item()) + logger.info(train_result) + self._dump_log(train_result) + + # average training loss in one epoch + loss_train_epoch /= validbatchs + loss_val_epoch = self.evaluate('') + val_result = 'Evaluate epoch: {:04d}, loss_train: {:.4f}, loss_val: {:.4f}'.format( + self._current_epoch, loss_train_epoch, loss_val_epoch) + logger.info(val_result) + self._dump_log(val_result) + # check point + ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format( + self._current_epoch, loss_train_epoch, loss_val_epoch) + torch.save(self.model, os.path.join(self.work_dir, ckpt_name)) + # time spent per epoch + epochtime = datetime.datetime.now() - epochtime + logger.info('Epoch {:04d} time spent: {:.2f} hours'.format( + self._current_epoch, + epochtime.total_seconds() / 3600.0)) + dataloader.stop() + dataset.release() + logger.info(f'Stage {stage} is finished.') + + def gen_val(self): + """ + generate validation set + """ + logger.info('Start generating validation set...') + dataset, dataloader = self.create_dataloader(self.conf_files[2], + self.conf_files[3]) + it = iter(dataloader) + + self.data_val = [] + for bi in range(self._val_iters): + logger.info('Iterating validation data %d', bi) + feat, label = next(it) + label = torch.reshape(label, (-1, )) + self.data_val.append([feat, label]) + + dataloader.stop() + dataset.release() + logger.info('Finish generating validation set!') + + def create_dataloader(self, base_path, finetune_path): + dataset = KWSDataset(base_path, finetune_path, self._threads, + self._single_rate, self._num_classes) + dataloader = KWSDataLoader( + dataset, 
batchsize=self._batch_size, numworkers=self._threads) + dataloader.start() + return dataset, dataloader + + def evaluate(self, checkpoint_path: str, *args, + **kwargs) -> Dict[str, float]: + logger.info('Start validation...') + loss_val_epoch = 0.0 + + with torch.no_grad(): + for feat, label in self.data_val: + feat = to_device(feat, self.device) + label = to_device(label, self.device) + # apply model + predict = self.model(feat) + # calculate loss + loss = self.loss_fn( + torch.reshape(predict, (-1, self._num_classes)), label) + loss_val_epoch += loss.item() + logger.info('Finish validation.') + return loss_val_epoch / self._val_iters + + def _dump_log(self, msg): + if is_master(): + with open(self.json_log_path, 'a+') as f: + f.write(msg) + f.write('\n') diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py index 4c65870e..d09fd75c 100644 --- a/modelscope/trainers/cv/__init__.py +++ b/modelscope/trainers/cv/__init__.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: ImageInstanceSegmentationTrainer from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer + from .image_inpainting_trainer import ImageInpaintingTrainer else: _import_structure = { @@ -15,7 +16,8 @@ else: ['ImageInstanceSegmentationTrainer'], 'image_portrait_enhancement_trainer': ['ImagePortraitEnhancementTrainer'], - 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'] + 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], + 'image_inpainting_trainer': ['ImageInpaintingTrainer'] } import sys diff --git a/modelscope/trainers/cv/card_detection_scrfd_trainer.py b/modelscope/trainers/cv/card_detection_scrfd_trainer.py new file mode 100644 index 00000000..e1f81bcf --- /dev/null +++ b/modelscope/trainers/cv/card_detection_scrfd_trainer.py @@ -0,0 +1,18 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.cv.face_detection_scrfd_trainer import \ + FaceDetectionScrfdTrainer + + +@TRAINERS.register_module(module_name=Trainers.card_detection_scrfd) +class CardDetectionScrfdTrainer(FaceDetectionScrfdTrainer): + + def __init__(self, cfg_file: str, *args, **kwargs): + """ High-level finetune api for SCRFD. + + Args: + cfg_file: Path to configuration file. + """ + # card/face dataset use different img folder names + super().__init__(cfg_file, imgdir_name='', **kwargs) diff --git a/modelscope/trainers/cv/face_detection_scrfd_trainer.py b/modelscope/trainers/cv/face_detection_scrfd_trainer.py new file mode 100644 index 00000000..9cfae7dd --- /dev/null +++ b/modelscope/trainers/cv/face_detection_scrfd_trainer.py @@ -0,0 +1,154 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import copy +import os +import os.path as osp +import time +from typing import Callable, Dict, Optional + +from modelscope.metainfo import Trainers +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS + + +@TRAINERS.register_module(module_name=Trainers.face_detection_scrfd) +class FaceDetectionScrfdTrainer(BaseTrainer): + + def __init__(self, + cfg_file: str, + cfg_modify_fn: Optional[Callable] = None, + *args, + **kwargs): + """ High-level finetune api for SCRFD. + + Args: + cfg_file: Path to configuration file. + cfg_modify_fn: An input fn which is used to modify the cfg read out of the file. 
+ """ + import mmcv + from mmcv.runner import get_dist_info, init_dist + from mmcv.utils import get_git_hash + from mmdet.utils import collect_env, get_root_logger + from mmdet.apis import set_random_seed + from mmdet.models import build_detector + from mmdet.datasets import build_dataset + from mmdet import __version__ + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import DefaultFormatBundleV2 + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import LoadAnnotationsV2 + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RotateV2 + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD + super().__init__(cfg_file) + cfg = self.cfg + if 'work_dir' in kwargs: + cfg.work_dir = kwargs['work_dir'] + else: + # use config filename as default work_dir if work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(cfg_file))[0]) + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + + if 'resume_from' in kwargs: # pretrain model for finetune + cfg.resume_from = kwargs['resume_from'] + cfg.device = 'cuda' + if 'gpu_ids' in kwargs: + cfg.gpu_ids = kwargs['gpu_ids'] + else: + cfg.gpu_ids = range(1) + labelfile_name = kwargs.pop('labelfile_name', 'labelv2.txt') + imgdir_name = kwargs.pop('imgdir_name', 'images/') + if 'train_root' in kwargs: + cfg.data.train.ann_file = kwargs['train_root'] + labelfile_name + cfg.data.train.img_prefix = kwargs['train_root'] + imgdir_name + if 'val_root' in kwargs: + cfg.data.val.ann_file = kwargs['val_root'] + labelfile_name + cfg.data.val.img_prefix = kwargs['val_root'] + imgdir_name + if 'total_epochs' in kwargs: + cfg.total_epochs = kwargs['total_epochs'] + if cfg_modify_fn is not None: + cfg = cfg_modify_fn(cfg) + if 'launcher' in kwargs: + distributed = True + init_dist(kwargs['launcher'], **cfg.dist_params) + # re-set gpu_ids with distributed training mode + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + else: + distributed = False + # no_validate=True will not evaluate checkpoint during training + cfg.no_validate = kwargs.get('no_validate', False) + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(cfg.work_dir, f'{timestamp}.log') + logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) + dash_line = '-' * 60 + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' + + dash_line) + meta['env_info'] = env_info + meta['config'] = cfg.pretty_text + # log some basic info + logger.info(f'Distributed training: {distributed}') + logger.info(f'Config:\n{cfg.pretty_text}') + + # set random seeds + if 'seed' in kwargs: + cfg.seed = kwargs['seed'] + _deterministic = kwargs.get('deterministic', False) + logger.info(f'Set random seed to {kwargs["seed"]}, ' + 
f'deterministic: {_deterministic}') + set_random_seed(kwargs['seed'], deterministic=_deterministic) + else: + cfg.seed = None + meta['seed'] = cfg.seed + meta['exp_name'] = osp.basename(cfg_file) + + model = build_detector(cfg.model) + model.init_weights() + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + val_dataset.pipeline = cfg.data.train.pipeline + datasets.append(build_dataset(val_dataset)) + if cfg.checkpoint_config is not None: + # save mmdet version, config file content and class names in + # checkpoints as meta data + cfg.checkpoint_config.meta = dict( + mmdet_version=__version__ + get_git_hash()[:7], + CLASSES=datasets[0].CLASSES) + # add an attribute for visualization convenience + model.CLASSES = datasets[0].CLASSES + + self.cfg = cfg + self.datasets = datasets + self.model = model + self.distributed = distributed + self.timestamp = timestamp + self.meta = meta + self.logger = logger + + def train(self, *args, **kwargs): + from mmdet.apis import train_detector + train_detector( + self.model, + self.datasets, + self.cfg, + distributed=self.distributed, + validate=(not self.cfg.no_validate), + timestamp=self.timestamp, + meta=self.meta) + + def evaluate(self, + checkpoint_path: str = None, + *args, + **kwargs) -> Dict[str, float]: + cfg = self.cfg.evaluation + logger.info(f'eval cfg {cfg}') + logger.info(f'checkpoint_path {checkpoint_path}') diff --git a/modelscope/trainers/cv/image_inpainting_trainer.py b/modelscope/trainers/cv/image_inpainting_trainer.py new file mode 100644 index 00000000..74d1ed9f --- /dev/null +++ b/modelscope/trainers/cv/image_inpainting_trainer.py @@ -0,0 +1,111 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import time +from collections.abc import Mapping + +from torch import distributed as dist + +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, + ConfigKeys, Hubs, ModeKeys, ModelFile, + Tasks, TrainerStages) +from modelscope.utils.data_utils import to_device +from modelscope.utils.file_utils import func_receive_dict_inputs + + +@TRAINERS.register_module(module_name=Trainers.image_inpainting) +class ImageInpaintingTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def train(self, *args, **kwargs): + super().train(*args, **kwargs) + + def evaluate(self, *args, **kwargs): + metric_values = super().evaluate(*args, **kwargs) + return metric_values + + def prediction_step(self, model, inputs): + pass + + def train_loop(self, data_loader): + """ Training loop used by `EpochBasedTrainer.train()` + """ + self.invoke_hook(TrainerStages.before_run) + self._epoch = 0 + self.model.train() + for _ in range(self._epoch, self._max_epochs): + self.invoke_hook(TrainerStages.before_train_epoch) + for i, data_batch in enumerate(data_loader): + data_batch = to_device(data_batch, self.device) + self.data_batch = data_batch + self._inner_iter = i + for idx in range(2): + self.invoke_hook(TrainerStages.before_train_iter) + self.train_step(self.model, data_batch, idx) + self.invoke_hook(TrainerStages.after_train_iter) + del self.data_batch + self._iter += 1 + self._mode = ModeKeys.TRAIN + + if i + 1 >= self.iters_per_epoch: + break + + self.invoke_hook(TrainerStages.after_train_epoch) + self._epoch += 1 + + self.invoke_hook(TrainerStages.after_run) + + def 
train_step(self, model, inputs, idx): + """ Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`TorchModel`): The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + # EvaluationHook will do evaluate and change mode to val, return to train mode + # TODO: find a prettier way to change mode + model.train() + self._mode = ModeKeys.TRAIN + # call model forward but not __call__ to skip postprocess + if isinstance(inputs, + Mapping) and not func_receive_dict_inputs(model.forward): + train_outputs = model.model._do_step(**inputs, optimizer_idx=idx) + else: + train_outputs = model.model._do_step(inputs, optimizer_idx=idx) + + if not isinstance(train_outputs, dict): + raise TypeError('"model.forward()" must return a dict') + + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + + self.train_outputs = train_outputs diff --git a/modelscope/trainers/default_config.py b/modelscope/trainers/default_config.py index c8f0c7b0..a02478b9 100644 --- a/modelscope/trainers/default_config.py +++ b/modelscope/trainers/default_config.py @@ -22,7 +22,8 @@ def merge_cfg(cfg: Config): This function will pop the default CheckpointHook when the BestCkptSaverHook exists in the input cfg. - @param cfg: The input cfg to be merged into. + Args: + cfg: The input cfg to be merged into.
""" cfg.merge_from_dict(DEFAULT_CONFIG, force=False) # pop duplicate hook diff --git a/modelscope/trainers/hooks/__init__.py b/modelscope/trainers/hooks/__init__.py index f133041b..a2e0cf4b 100644 --- a/modelscope/trainers/hooks/__init__.py +++ b/modelscope/trainers/hooks/__init__.py @@ -6,10 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .builder import HOOKS, build_hook from .checkpoint_hook import BestCkptSaverHook, CheckpointHook + from .compression import SparsityHook from .evaluation_hook import EvaluationHook from .hook import Hook from .iter_timer_hook import IterTimerHook - from .logger import TextLoggerHook, TensorboardHook + from .logger import TensorboardHook, TextLoggerHook from .lr_scheduler_hook import LrSchedulerHook from .optimizer import (ApexAMPOptimizerHook, NoneOptimizerHook, OptimizerHook, TorchAMPOptimizerHook) @@ -19,6 +20,7 @@ else: _import_structure = { 'builder': ['HOOKS', 'build_hook'], 'checkpoint_hook': ['BestCkptSaverHook', 'CheckpointHook'], + 'compression': ['SparsityHook'], 'evaluation_hook': ['EvaluationHook'], 'hook': ['Hook'], 'iter_timer_hook': ['IterTimerHook'], diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py index 220929b8..9b86d5b5 100644 --- a/modelscope/trainers/hooks/checkpoint_hook.py +++ b/modelscope/trainers/hooks/checkpoint_hook.py @@ -69,7 +69,7 @@ class CheckpointHook(Hook): self.rng_state = meta.get('rng_state') self.need_load_rng_state = True - def before_train_epoch(self, trainer): + def before_train_iter(self, trainer): if self.need_load_rng_state: if self.rng_state is not None: random.setstate(self.rng_state['random']) @@ -84,13 +84,6 @@ class CheckpointHook(Hook): 'this may cause a random data order or model initialization.' ) - self.rng_state = { - 'random': random.getstate(), - 'numpy': np.random.get_state(), - 'cpu': torch.random.get_rng_state(), - 'cuda': torch.cuda.get_rng_state_all(), - } - def after_train_epoch(self, trainer): if not self.by_epoch: return @@ -142,6 +135,12 @@ class CheckpointHook(Hook): cur_save_name = os.path.join( self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') + self.rng_state = { + 'random': random.getstate(), + 'numpy': np.random.get_state(), + 'cpu': torch.random.get_rng_state(), + 'cuda': torch.cuda.get_rng_state_all(), + } meta = { 'epoch': trainer.epoch, 'iter': trainer.iter + 1, @@ -216,6 +215,7 @@ class BestCkptSaverHook(CheckpointHook): by_epoch (bool): Save best checkpoints by epoch or by iteration. save_optimizer (bool): Whether to save optimizer state dict. Default: True. save_dir (str): Output directory to save best checkpoint. + restore_best (bool): Whether to restore the best checkpoint after training. """ PRIORITY = Priority.LOW @@ -228,6 +228,7 @@ class BestCkptSaverHook(CheckpointHook): save_optimizer=True, save_dir=None, save_file_name=None, + restore_best=False, interval=0): assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.' super().__init__( @@ -241,6 +242,7 @@ class BestCkptSaverHook(CheckpointHook): self._best_metric = None self._best_ckpt_file = None self.save_file_name = save_file_name + self.restore_best = restore_best def _should_save(self, trainer): return self._is_best_metric(trainer.metric_values) @@ -305,3 +307,7 @@ class BestCkptSaverHook(CheckpointHook): self.logger.warn( 'The state_dict is not available, the best metric value will be affected.' 
) + + def after_run(self, trainer): + if self.restore_best: + self.load_checkpoint(self._best_ckpt_file, trainer) diff --git a/modelscope/trainers/hooks/compression/__init__.py b/modelscope/trainers/hooks/compression/__init__.py new file mode 100644 index 00000000..f755b2ca --- /dev/null +++ b/modelscope/trainers/hooks/compression/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .sparsity_hook import SparsityHook + from .utils import SparseLinear, convert_sparse_network + +else: + _import_structure = { + 'sparsity_hook': ['SparsityHook'], + 'utils': ['convert_sparse_network', 'SparseLinear'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/hooks/compression/sparsity_hook.py b/modelscope/trainers/hooks/compression/sparsity_hook.py new file mode 100644 index 00000000..993488d8 --- /dev/null +++ b/modelscope/trainers/hooks/compression/sparsity_hook.py @@ -0,0 +1,131 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os + +from modelscope import __version__ +from modelscope.metainfo import Hooks +from modelscope.trainers.hooks.builder import HOOKS +from modelscope.trainers.hooks.hook import Hook +from modelscope.trainers.hooks.priority import Priority +from modelscope.utils.checkpoint import save_checkpoint +from modelscope.utils.torch_utils import is_master + + +@HOOKS.register_module(module_name=Hooks.SparsityHook) +class SparsityHook(Hook): + + PRIORITY = Priority.HIGHEST + + def __init__(self, pruning_method, config={}, save_dir=None): + self.pruning_method = pruning_method + self.save_dir = save_dir + + self.compress_module = config.get('compress_module', []) + self.weight_rank = config.get('weight_rank', 8) + self.weight_beta = config.get('weight_beta', 1) + self.mask_rank = config.get('mask_rank', 8) + self.mask_alpha1 = config.get('mask_alpha1', 1) + self.mask_alpha2 = config.get('mask_alpha2', 1) + + self.step = 0 + self.total_step = 0 + self.frequency = config.get('frequency', 1) + self.initial_warmup = config.get('initial_warmup', 0.1) + self.final_warmup = config.get('final_warmup', 0.3) + self.initial_sparsity = config.get('initial_sparsity', 0.0) + self.final_sparsity = config.get('final_sparsity', 0.0) + + def before_run(self, trainer): + import torch + + from .utils import SparseLinear, convert_sparse_network + + if self.save_dir is None: + self.save_dir = trainer.work_dir + + if len(self.compress_module) == 0: + convert_sparse_network( + trainer.model, + pruning_method=self.pruning_method, + weight_rank=self.weight_rank, + weight_beta=self.weight_beta, + mask_rank=self.mask_rank, + mask_alpha1=self.mask_alpha1, + mask_alpha2=self.mask_alpha2, + logger=trainer.logger, + ) + else: + for cm in self.compress_module: + for name, module in trainer.model.named_modules(): + if name != cm: + continue + convert_sparse_network( + module, + pruning_method=self.pruning_method, + weight_rank=self.weight_rank, + weight_beta=self.weight_beta, + mask_rank=self.mask_rank, + mask_alpha1=self.mask_alpha1, + mask_alpha2=self.mask_alpha2, + logger=trainer.logger, + ) + + for i in range(len(trainer.optimizer.param_groups)): + new_train_params = [] + for param in trainer.optimizer.param_groups[i]['params']: + is_find = False + for name, module in trainer.model.named_modules(): + if 
isinstance(module, SparseLinear): + if torch.equal(param.half(), + module.weight.data.half()): + is_find = True + break + + if not is_find: + new_train_params.append(param) + + trainer.optimizer.param_groups[i]['params'] = new_train_params + + new_params = [] + for name, module in trainer.model.named_modules(): + if isinstance(module, SparseLinear): + new_params.extend( + [p for p in module.parameters() if p.requires_grad]) + + trainer.optimizer.add_param_group({'params': new_params}) + + self.total_step = trainer.iters_per_epoch * trainer._max_epochs + + def before_train_iter(self, trainer): + from .utils import schedule_sparsity_ratio, update_network_sparsity + + cur_sparsity = schedule_sparsity_ratio( + self.step, + self.total_step, + self.frequency, + self.initial_warmup, + self.final_warmup, + self.initial_sparsity, + self.final_sparsity, + ) + + update_network_sparsity(trainer.model, cur_sparsity) + + if is_master(): + trainer.logger.info( + f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}' + ) + + self.step += 1 + + def after_run(self, trainer): + from .utils import generate_sparse_model + + generate_sparse_model(trainer.model, logger=trainer.logger) + + self._save_checkpoint(trainer) + + def _save_checkpoint(self, trainer): + if is_master(): + trainer.logger.info('Saving checkpoint at final compress') + cur_save_name = os.path.join(self.save_dir, 'compress_model.pth') + save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) diff --git a/modelscope/trainers/hooks/compression/utils.py b/modelscope/trainers/hooks/compression/utils.py new file mode 100644 index 00000000..59418201 --- /dev/null +++ b/modelscope/trainers/hooks/compression/utils.py @@ -0,0 +1,208 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn + +from modelscope.utils.torch_utils import is_master + + +class SparseBinarizer(torch.autograd.Function): + + @staticmethod + def forward(ctx, mask_scores, sparsity): + num_prune = int(mask_scores.numel() * sparsity) + prune_indices = torch.argsort(mask_scores.reshape(-1))[:num_prune] + mask = mask_scores.clone().fill_(1) + mask.reshape(-1)[prune_indices] = 0.0 + return mask + + @staticmethod + def backward(ctx, gradOutput): + return gradOutput, None + + +class SparseLinear(nn.Module): + """ + Fully Connected layer with on the fly adaptive mask. 
+ """ + + def __init__( + self, + module, + pruning_method='pst', + weight_rank=8, + weight_beta=1.0, + mask_rank=8, + mask_alpha1=1.0, + mask_alpha2=1.0, + ): + super(SparseLinear, self).__init__() + self.module = module + out_features = self.module.weight.shape[0] + in_features = self.module.weight.shape[1] + + self.weight = self.module.weight + self.module.weight = None + self.module._parameters.pop('weight') + + self.pruning_method = pruning_method + + self.cur_sparsity = 0.0 + + if self.pruning_method == 'pst': + self.weight_rank = weight_rank + self.weight_beta = weight_beta + self.mask_rank = mask_rank + self.mask_alpha1 = mask_alpha1 + self.mask_alpha2 = mask_alpha2 + + # create trainable params + self.weight_U = nn.Parameter( + torch.randn(out_features, self.weight_rank).to( + device=self.weight.device, dtype=self.weight.dtype)) + self.weight_V = nn.Parameter( + torch.zeros(self.weight_rank, in_features).to( + device=self.weight.device, dtype=self.weight.dtype)) + + self.mask_scores_A = nn.Parameter( + torch.randn(out_features, self.mask_rank).to( + device=self.weight.device, dtype=self.weight.dtype)) + self.mask_scores_B = nn.Parameter( + torch.zeros(self.mask_rank, in_features).to( + device=self.weight.device, dtype=self.weight.dtype)) + self.mask_scores_R = nn.Parameter( + torch.zeros(out_features).to( + device=self.weight.device, dtype=self.weight.dtype)) + self.mask_scores_C = nn.Parameter( + torch.zeros(in_features).to( + device=self.weight.device, dtype=self.weight.dtype)) + + self.weight.requires_grad = False + if self.module.bias is not None: + self.module.bias.requires_grad = False + + def forward(self, *inputs): + if self.pruning_method == 'pst': + weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V + mask_scores = ( + weight.abs() + + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B + + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1) + + self.mask_scores_C.unsqueeze(0))) + + mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity) + masked_weight = mask * weight + + self.module.weight = masked_weight + return self.module(*inputs) + else: + return self.module(*inputs) + + def convert(self): + if self.pruning_method == 'pst': + weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V + mask_scores = ( + weight.abs() + + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B + + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1) + + self.mask_scores_C.unsqueeze(0))) + + mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity) + + masked_weight = mask * weight + self.module.weight = nn.Parameter(masked_weight.data) + + +def _setattr(model, name, module): + name_list = name.split('.') + for name in name_list[:-1]: + model = getattr(model, name) + setattr(model, name_list[-1], module) + + +def convert_sparse_network( + model, + pruning_method, + weight_rank, + weight_beta, + mask_rank, + mask_alpha1, + mask_alpha2, + logger=None, +): + compress_module = [nn.Linear] + try: + from megatron import mpu + compress_module.extend( + [mpu.RowParallelLinear, mpu.ColumnParallelLinear]) + except ImportError: + pass + + for name, module in model.named_modules(): + if type(module) in compress_module: + new_module = SparseLinear( + module, + pruning_method, + weight_rank, + weight_beta, + mask_rank, + mask_alpha1, + mask_alpha2, + ) + + # replace original module by new sparse module + _setattr(model, name, new_module) + + if is_master(): + if logger: + logger.info(f'convert {name} to sparse module.') + else: + print(f'convert 
{name} to sparse module.') + + +def update_network_sparsity(model, sparsity): + for name, module in model.named_modules(): + if isinstance(module, SparseLinear): + module.cur_sparsity = sparsity + + +def schedule_sparsity_ratio( + step, + total_step, + frequency, + initial_warmup, + final_warmup, + initial_sparsity, + final_sparsity, +): + if step <= initial_warmup * total_step: + sparsity = initial_sparsity + elif step > (total_step - final_warmup * total_step): + sparsity = final_sparsity + else: + spars_warmup_steps = initial_warmup * total_step + spars_schedu_steps = (final_warmup + initial_warmup) * total_step + step = (step - spars_warmup_steps) // frequency * frequency + mul_coeff = 1 - step / (total_step - spars_schedu_steps) + sparsity = final_sparsity + (initial_sparsity - final_sparsity) * ( + mul_coeff**3) + return sparsity + + +def generate_sparse_model(model, logger=None): + # generate sparse weight for saving + for name, module in model.named_modules(): + if isinstance(module, SparseLinear): + module.convert() + + _setattr(model, name, module.module) + + if is_master(): + if logger: + logger.info(f'convert {name} weight to sparse weight, \ + sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.' + ) + else: + print(f'convert {name} weight to sparse, \ + sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.' + ) diff --git a/modelscope/trainers/hooks/logger/text_logger_hook.py b/modelscope/trainers/hooks/logger/text_logger_hook.py index 6629a0c9..8552ab4e 100644 --- a/modelscope/trainers/hooks/logger/text_logger_hook.py +++ b/modelscope/trainers/hooks/logger/text_logger_hook.py @@ -51,7 +51,7 @@ class TextLoggerHook(LoggerHook): if self.out_dir is None: self.out_dir = trainer.work_dir - if not osp.exists(self.out_dir): + if not osp.exists(self.out_dir) and is_master(): os.makedirs(self.out_dir) trainer.logger.info('Text logs will be saved to {}'.format( diff --git a/modelscope/trainers/hooks/lr_scheduler_hook.py b/modelscope/trainers/hooks/lr_scheduler_hook.py index ca0ec01b..ed018fef 100644 --- a/modelscope/trainers/hooks/lr_scheduler_hook.py +++ b/modelscope/trainers/hooks/lr_scheduler_hook.py @@ -47,7 +47,8 @@ class LrSchedulerHook(Hook): return lr def before_train_iter(self, trainer): - if not self.by_epoch: + if not self.by_epoch and trainer.iter >= getattr( + trainer, 'cumulative_iters', 1): if self.warmup_lr_scheduler is not None: self.warmup_lr_scheduler.step() else: diff --git a/modelscope/trainers/hooks/optimizer/base.py b/modelscope/trainers/hooks/optimizer/base.py index 8c61dfdb..0f38c67a 100644 --- a/modelscope/trainers/hooks/optimizer/base.py +++ b/modelscope/trainers/hooks/optimizer/base.py @@ -44,6 +44,7 @@ class OptimizerHook(Hook): def before_run(self, trainer): trainer.optimizer.zero_grad() + trainer.cumulative_iters = self.cumulative_iters def after_train_iter(self, trainer): for k in self.loss_keys: diff --git a/modelscope/trainers/multi_modal/clip/__init__.py b/modelscope/trainers/multi_modal/clip/__init__.py index 87f1040c..61a6664b 100644 --- a/modelscope/trainers/multi_modal/clip/__init__.py +++ b/modelscope/trainers/multi_modal/clip/__init__.py @@ -1 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
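# Reviewer annotation on the LrSchedulerHook/OptimizerHook changes above (not part of this
# patch): OptimizerHook.before_run now publishes `trainer.cumulative_iters`, and an
# iteration-based LrSchedulerHook skips its per-iteration step until `trainer.iter` reaches
# that value, so the schedule is not advanced while the very first batch of gradients is
# still being accumulated. As an illustration, assuming an OptimizerHook created with
# cumulative_iters=4 and an LrSchedulerHook with by_epoch=False:
#
#     iters 0-3 : gradients accumulate, learning rate stays at its initial value
#     iter  3   : the optimizer performs its first accumulated step
#     iter  4+  : the scheduler steps once per iteration, as before this change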
+ from .clip_trainer import CLIPTrainer diff --git a/modelscope/trainers/multi_modal/clip/clip_trainer.py b/modelscope/trainers/multi_modal/clip/clip_trainer.py index cccf4296..cbe83417 100644 --- a/modelscope/trainers/multi_modal/clip/clip_trainer.py +++ b/modelscope/trainers/multi_modal/clip/clip_trainer.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Dict, Optional diff --git a/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py b/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py index 1391a4fd..4e150fe7 100644 --- a/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py +++ b/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os import random diff --git a/modelscope/trainers/multi_modal/ofa/__init__.py b/modelscope/trainers/multi_modal/ofa/__init__.py new file mode 100644 index 00000000..34e4ec7a --- /dev/null +++ b/modelscope/trainers/multi_modal/ofa/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .ofa_trainer import OFATrainer diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py new file mode 100644 index 00000000..02853925 --- /dev/null +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py @@ -0,0 +1,154 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os +import shutil +from functools import partial +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from torch import distributed as dist +from torch import nn +from torch.utils.data import Dataset + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model, TorchModel +from modelscope.msdatasets.ms_dataset import MsDataset +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.multi_modal import OfaPreprocessor +from modelscope.preprocessors.ofa.utils.collate import collate_fn +from modelscope.trainers import EpochBasedTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.utils.config import Config +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, + ModeKeys) +from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, + get_schedule) + + +@TRAINERS.register_module(module_name=Trainers.ofa) +class OFATrainer(EpochBasedTrainer): + + def __init__( + self, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + cfg_file: Optional[str] = None, + arg_parse_fn: Optional[Callable] = None, + data_collator: Optional[Union[Callable, Dict[str, + Callable]]] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Union[Preprocessor, + Dict[str, Preprocessor]]] = None, + optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + seed: int = 42, + **kwargs): + model = Model.from_pretrained(model, revision=model_revision) + model_dir = model.model_dir + cfg = Config.from_file(cfg_file) + if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0: + work_dir = cfg.train.work_dir + else: + work_dir = kwargs['work_dir'] + tokenizer_files = { + 'zh': [ + 'tokenizer.json', 'tokenizer_config.json', 'vocab.txt', + 'config.json' + ], + 'en': + 
['tokenizer.json', 'vocab.json', 'merges.txt', 'config.json'], + } + for filename in tokenizer_files[cfg.model.get('language', 'en')]: + finetune_file = os.path.join(work_dir, filename) + pretrain_file = os.path.join(model_dir, filename) + if os.path.exists(finetune_file): + continue + if os.path.exists(pretrain_file): + shutil.copy(pretrain_file, finetune_file) + + if preprocessor is None: + preprocessor = { + ConfigKeys.train: + OfaPreprocessor( + model_dir=work_dir, mode=ModeKeys.TRAIN, no_collate=True), + ConfigKeys.val: + OfaPreprocessor( + model_dir=work_dir, mode=ModeKeys.EVAL, no_collate=True), + } + # use torchrun launch + world_size = int(os.environ.get('WORLD_SIZE', 1)) + epoch_steps = math.ceil( + len(train_dataset) / # noqa + (cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa + cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs + cfg.train.criterion.tokenizer = model.tokenizer + self.criterion = AdjustLabelSmoothedCrossEntropyCriterion( + cfg.train.criterion) + if optimizers[0] is None: + optimizer = build_optimizer(model, cfg=cfg.train.optimizer) + else: + optimizer = optimizers[0] + if optimizers[1] is None: + scheduler_class, scheduler_args = get_schedule( + cfg.train.lr_scheduler) + if scheduler_class is not None: + lr_scheduler = scheduler_class(**{'optimizer': optimizer}, + **scheduler_args) + else: + lr_scheduler = None + else: + lr_scheduler = optimizers[1] + optimizers = (optimizer, lr_scheduler) + if data_collator is None: + data_collator = partial( + collate_fn, + pad_idx=model.tokenizer.pad_token_id, + eos_idx=model.tokenizer.eos_token_id, + ) + if 'launcher' not in kwargs and cfg.train.get('launcher', None): + kwargs['launcher'] = cfg.train.launcher + if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False): + kwargs['use_fp16'] = cfg.train.use_fp16 + kwargs['to_tensor'] = False + super().__init__( + model=model, + cfg_file=cfg_file, + arg_parse_fn=arg_parse_fn, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + preprocessor=preprocessor, + optimizers=optimizers, + seed=seed, + **kwargs, + ) + + def train_step(self, model, inputs): + model.train() + model_outputs = model.forward(inputs) + loss, sample_size, logging_output = self.criterion( + model_outputs, inputs) + train_outputs = {'loss': loss} + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + self.train_outputs = train_outputs diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py new file mode 100644 index 00000000..2189a5db --- /dev/null +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py @@ -0,0 +1,243 @@ +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# found in the LICENSE file in the root directory. 
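Since OFATrainer is registered under Trainers.ofa, it is normally obtained through build_trainer rather than constructed directly. A minimal usage sketch; the model id, config path and dataset objects are placeholders, not values taken from this patch:

    from modelscope.metainfo import Trainers
    from modelscope.trainers import build_trainer

    trainer = build_trainer(
        name=Trainers.ofa,
        default_args=dict(
            model='<ofa-model-id>',                   # placeholder model id
            cfg_file='<path/to/configuration.json>',  # placeholder config path
            train_dataset=train_ds,                   # MsDataset or torch Dataset
            eval_dataset=eval_ds,
            work_dir='./ofa_work_dir'))
    trainer.train()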
+import math + +import numpy as np +import torch +import torch.nn.functional as F +import transformers +from torch.nn.modules.loss import _Loss + + +def construct_rdrop_sample(x): + if isinstance(x, dict): + for key in x: + x[key] = construct_rdrop_sample(x[key]) + return x + elif isinstance(x, torch.Tensor): + return x.repeat(2, *([1] * (x.dim() - 1))) + elif isinstance(x, int): + return x * 2 + elif isinstance(x, np.ndarray): + return x.repeat(2) + else: + raise NotImplementedError + + +def kl_loss(p, q): + p_loss = F.kl_div(p, torch.exp(q), reduction='sum') + q_loss = F.kl_div(q, torch.exp(p), reduction='sum') + loss = (p_loss + q_loss) / 2 + return loss + + +def label_smoothed_nll_loss(lprobs, + target, + epsilon, + update_num, + reduce=True, + drop_worst_ratio=0.0, + drop_worst_after=0, + use_rdrop=False, + reg_alpha=1.0, + constraint_masks=None, + constraint_start=None, + constraint_end=None): + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1) + if constraint_masks is not None: + smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum( + dim=-1, keepdim=True).squeeze(-1) + eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6) + elif constraint_start is not None and constraint_end is not None: + constraint_range = [0, 1, 2, 3] + list( + range(constraint_start, constraint_end)) + smooth_loss = -lprobs[:, constraint_range].sum( + dim=-1, keepdim=True).squeeze(-1) + eps_i = epsilon / (len(constraint_range) - 1 + 1e-6) + else: + smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1) + eps_i = epsilon / (lprobs.size(-1) - 1) + loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss + if drop_worst_ratio > 0 and update_num > drop_worst_after: + if use_rdrop: + true_batch_size = loss.size(0) // 2 + _, indices = torch.topk( + loss[:true_batch_size], + k=int(true_batch_size * (1 - drop_worst_ratio)), + largest=False) + loss = torch.cat([loss[indices], loss[indices + true_batch_size]]) + nll_loss = torch.cat( + [nll_loss[indices], nll_loss[indices + true_batch_size]]) + lprobs = torch.cat( + [lprobs[indices], lprobs[indices + true_batch_size]]) + else: + loss, indices = torch.topk( + loss, + k=int(loss.shape[0] * (1 - drop_worst_ratio)), + largest=False) + nll_loss = nll_loss[indices] + lprobs = lprobs[indices] + + ntokens = loss.numel() + nll_loss = nll_loss.sum() / ntokens # 后面在grads里面处理 + loss = loss.sum() / ntokens # 后面在grads里面处理 + if use_rdrop: + true_batch_size = lprobs.size(0) // 2 + p = lprobs[:true_batch_size] + q = lprobs[true_batch_size:] + if constraint_start is not None and constraint_end is not None: + constraint_range = [0, 1, 2, 3] + list( + range(constraint_start, constraint_end)) + p = p[:, constraint_range] + q = q[:, constraint_range] + loss += kl_loss(p, q) * reg_alpha + + return loss, nll_loss, ntokens + + +class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): + + def __init__(self, args): + super().__init__() + self.sentence_avg = args.sentence_avg + self.eps = args.label_smoothing + self.ignore_prefix_size = args.ignore_prefix_size + self.ignore_eos = args.ignore_eos + self.report_accuracy = args.report_accuracy + self.drop_worst_ratio = args.drop_worst_ratio + self.drop_worst_after = args.drop_worst_after + self.use_rdrop = args.use_rdrop + self.reg_alpha = args.reg_alpha + self.sample_patch_num = args.sample_patch_num + + self.constraint_start = None + self.constraint_end = None + if args.constraint_range: + constraint_start, constraint_end = 
args.constraint_range.split(',') + self.constraint_start = int(constraint_start) + self.constraint_end = int(constraint_end) + self.padding_idx = args.tokenizer.pad_token_id + self.args = args + + def forward(self, output, sample, update_num=0, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + if self.use_rdrop: + construct_rdrop_sample(sample) + + loss, nll_loss, ntokens = self.compute_loss( + output, sample, update_num, reduce=reduce) + sample_size = ( + sample['target'].size(0) if self.sentence_avg else ntokens) + logging_output = { + 'loss': loss.data, + 'nll_loss': nll_loss.data, + 'ntokens': sample['ntokens'], + 'nsentences': sample['nsentences'], + 'sample_size': sample_size, + } + return loss, sample_size, logging_output + + def get_lprobs_and_target(self, net_output, sample): + conf = sample['conf'][:, None, None] if 'conf' in sample and sample[ + 'conf'] is not None else 1 + constraint_masks = None + if 'constraint_masks' in sample and sample[ + 'constraint_masks'] is not None: + constraint_masks = sample['constraint_masks'] + net_output[0].masked_fill_(~constraint_masks, -math.inf) + if self.constraint_start is not None and self.constraint_end is not None: + net_output[0][:, :, 4:self.constraint_start] = -math.inf + net_output[0][:, :, self.constraint_end:] = -math.inf + lprobs = F.log_softmax( + net_output[0], dim=-1, dtype=torch.float32) * conf + target = sample['target'] + if self.ignore_prefix_size > 0: + lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous() + target = target[:, self.ignore_prefix_size:].contiguous() + if constraint_masks is not None: + constraint_masks = constraint_masks[:, self.ignore_prefix_size:, :].contiguous() # yapf: disable + if self.ignore_eos: + bsz, seq_len, embed_dim = lprobs.size() + eos_indices = target.eq(self.task.tgt_dict.eos()) + lprobs = lprobs[~eos_indices].reshape(bsz, seq_len - 1, embed_dim) + target = target[~eos_indices].reshape(bsz, seq_len - 1) + if constraint_masks is not None: + constraint_masks = constraint_masks[~eos_indices].reshape( + bsz, seq_len - 1, embed_dim) + if constraint_masks is not None: + constraint_masks = constraint_masks.view(-1, + constraint_masks.size(-1)) + return lprobs.view(-1, + lprobs.size(-1)), target.view(-1), constraint_masks + + def compute_loss(self, net_output, sample, update_num, reduce=True): + lprobs, target, constraint_masks = self.get_lprobs_and_target( + net_output, sample) + if constraint_masks is not None: + constraint_masks = constraint_masks[target != self.padding_idx] + lprobs = lprobs[target != self.padding_idx] + target = target[target != self.padding_idx] + loss, nll_loss, ntokens = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + update_num, + reduce=reduce, + drop_worst_ratio=self.drop_worst_ratio, + drop_worst_after=self.drop_worst_after, + use_rdrop=self.use_rdrop, + reg_alpha=self.reg_alpha, + constraint_masks=constraint_masks, + constraint_start=self.constraint_start, + constraint_end=self.constraint_end) + return loss, nll_loss, ntokens + + +def get_schedule(scheduler): + + if scheduler.name == 'const': + scheduler_class = transformers.get_constant_schedule_with_warmup + scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps) + } + elif scheduler.name == 'linear': + scheduler_class = transformers.get_linear_schedule_with_warmup + 
scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps), + 'num_training_steps': + scheduler.num_train_steps + } + elif scheduler.name == 'cosine': + scheduler_class = transformers.get_cosine_schedule_with_warmup + scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps), + 'num_training_steps': + scheduler.num_train_steps + } + elif scheduler.name == 'polynomial_decay': + scheduler_class = transformers.get_polynomial_decay_schedule_with_warmup + scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps), + 'num_training_steps': + scheduler.num_train_steps, + 'lr_end': + scheduler.lr_end + } + else: + raise NotImplementedError + + return scheduler_class, scheduler_args diff --git a/modelscope/trainers/nlp/__init__.py b/modelscope/trainers/nlp/__init__.py index 001cfefc..22f2cfe6 100644 --- a/modelscope/trainers/nlp/__init__.py +++ b/modelscope/trainers/nlp/__init__.py @@ -6,12 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .sequence_classification_trainer import SequenceClassificationTrainer from .csanmt_translation_trainer import CsanmtTranslationTrainer - from .passage_ranking_trainer import PassageRankingTranier + from .text_ranking_trainer import TextRankingTrainer else: _import_structure = { 'sequence_classification_trainer': ['SequenceClassificationTrainer'], 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], - 'passage_ranking_trainer': ['PassageRankingTrainer'] + 'text_ranking_trainer': ['TextRankingTrainer'] } import sys diff --git a/modelscope/trainers/nlp/space/dialog_intent_trainer.py b/modelscope/trainers/nlp/space/dialog_intent_trainer.py index 2e59cd80..4baaddfe 100644 --- a/modelscope/trainers/nlp/space/dialog_intent_trainer.py +++ b/modelscope/trainers/nlp/space/dialog_intent_trainer.py @@ -1,23 +1,22 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
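get_schedule only resolves the transformers factory function and its keyword arguments; instantiating the scheduler with an optimizer is left to the caller, as OFATrainer does above. A self-contained sketch with illustrative config values:

    import torch
    from torch import nn
    from modelscope.utils.config import ConfigDict
    from modelscope.trainers.multi_modal.ofa.ofa_trainer_utils import get_schedule

    model = nn.Linear(16, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    scheduler_cfg = ConfigDict({
        'name': 'linear',
        'warmup_proportion': 0.01,
        'num_train_steps': 10000,
    })
    scheduler_class, scheduler_args = get_schedule(scheduler_cfg)
    lr_scheduler = scheduler_class(optimizer=optimizer, **scheduler_args)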
import os -import time -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional import numpy as np from modelscope.metainfo import Trainers from modelscope.models.nlp.space.model.generator import SpaceGenerator from modelscope.models.nlp.space.model.model_base import SpaceModelBase -from modelscope.preprocessors.space.data_loader import \ +from modelscope.preprocessors.nlp.space.data_loader import \ get_sequential_data_loader -from modelscope.preprocessors.space.fields.intent_field import \ +from modelscope.preprocessors.nlp.space.fields.intent_field import \ IntentBPETextField -from modelscope.preprocessors.space.preprocess import intent_preprocess +from modelscope.preprocessors.nlp.space.preprocess import intent_preprocess from modelscope.trainers.base import BaseTrainer from modelscope.trainers.builder import TRAINERS from modelscope.trainers.nlp.space.trainer.intent_trainer import IntentTrainer -from modelscope.utils.config import Config +from modelscope.utils.config import Config, ModelFile from modelscope.utils.logger import get_logger PATH = None @@ -34,14 +33,6 @@ class DialogIntentTrainer(BaseTrainer): **kwargs): super().__init__(os.path.join(kwargs['model_dir'], kwargs['cfg_name'])) - def to_tensor(array): - """ - numpy array -> tensor - """ - import torch - array = torch.tensor(array) - return array.cuda() if self.cfg.use_gpu else array - def setup_seed(seed): import random import torch @@ -59,56 +50,70 @@ class DialogIntentTrainer(BaseTrainer): # preprocess data intent_preprocess(self.cfg.Model.init_checkpoint, self.cfg) # set reader and evaluator - bpe = IntentBPETextField(self.cfg.Model.init_checkpoint, self.cfg) + self.bpe = IntentBPETextField(self.cfg.Model.init_checkpoint, self.cfg) - self.cfg.Model.num_token_embeddings = bpe.vocab_size - self.cfg.Model.num_turn_embeddings = bpe.max_ctx_turn + 1 + self.cfg.Model.num_token_embeddings = self.bpe.vocab_size + self.cfg.Model.num_turn_embeddings = self.bpe.max_ctx_turn + 1 dataset_paths = [ os.path.join(self.cfg.Dataset.data_dir, self.cfg.Dataset.trigger_data) ] # set data and data status - collate_fn = bpe.collate_fn_multi_turn + collate_fn = self.bpe.collate_fn_multi_turn self.train_label_loader = get_sequential_data_loader( batch_size=self.cfg.Trainer.batch_size_label, - reader=bpe, + reader=self.bpe, hparams=self.cfg, data_paths=dataset_paths, collate_fn=collate_fn, data_type='train') self.valid_label_loader = get_sequential_data_loader( batch_size=self.cfg.Trainer.batch_size_label, - reader=bpe, + reader=self.bpe, hparams=self.cfg, data_paths=dataset_paths, collate_fn=collate_fn, data_type='valid') self.test_label_loader = get_sequential_data_loader( batch_size=self.cfg.Trainer.batch_size_label, - reader=bpe, + reader=self.bpe, hparams=self.cfg, data_paths=dataset_paths, collate_fn=collate_fn, data_type='test') # set generator - generator = SpaceGenerator.create(self.cfg, reader=bpe) + self.generator = SpaceGenerator.create(self.cfg, reader=self.bpe) + self._load_model(**kwargs) + + def _load_model(self, **kwargs): + + def to_tensor(array): + """ + numpy array -> tensor + """ + import torch + array = torch.tensor(array) + return array.cuda() if self.cfg.use_gpu else array + # construct model - self.model = SpaceModelBase.create( - self.cfg.Model.init_checkpoint, - self.cfg, - reader=bpe, - generator=generator) + if 'model' in kwargs: + self.model = kwargs['model'] + else: + self.model = SpaceModelBase.create( + kwargs['model_dir'], + self.cfg, + reader=self.bpe, + 
generator=self.generator) import torch - # multi-gpu if self.cfg.Trainer.gpu > 1 and torch.cuda.device_count() > 1: self.model = torch.nn.DataParallel(self.model) # construct trainer self.trainer = IntentTrainer( - self.model, to_tensor, self.cfg, reader=bpe) + self.model, to_tensor, self.cfg, reader=self.bpe) num_batches = len(self.train_label_loader) self.trainer.set_optimizers(num_training_steps_per_epoch=num_batches) # load model, optimizer and lr_scheduler @@ -131,6 +136,16 @@ class DialogIntentTrainer(BaseTrainer): *args, **kwargs) -> Dict[str, float]: logger.info('Evaluate') + self.cfg.do_infer = True + + # get best checkpoint path + pos = checkpoint_path.rfind('/') + checkpoint_name = checkpoint_path[pos + 1:] + checkpoint_dir = checkpoint_path[:pos] + + assert checkpoint_name == ModelFile.TORCH_MODEL_BIN_FILE + kwargs['model_dir'] = checkpoint_dir + self._load_model(**kwargs) self.trainer.infer( data_iter=self.test_label_loader, ex_data_iter=self.train_label_loader) diff --git a/modelscope/trainers/nlp/space/dialog_modeling_trainer.py b/modelscope/trainers/nlp/space/dialog_modeling_trainer.py index 726404d4..aa6bb69d 100644 --- a/modelscope/trainers/nlp/space/dialog_modeling_trainer.py +++ b/modelscope/trainers/nlp/space/dialog_modeling_trainer.py @@ -9,8 +9,7 @@ import numpy as np from modelscope.metainfo import Trainers from modelscope.models.nlp.space.model.generator import SpaceGenerator from modelscope.models.nlp.space.model.model_base import SpaceModelBase -from modelscope.preprocessors.space.fields.gen_field import \ - MultiWOZBPETextField +from modelscope.preprocessors.nlp import MultiWOZBPETextField from modelscope.trainers.base import BaseTrainer from modelscope.trainers.builder import TRAINERS from modelscope.trainers.nlp.space.eval import MultiWOZEvaluator diff --git a/modelscope/trainers/nlp/space/trainer/gen_trainer.py b/modelscope/trainers/nlp/space/trainer/gen_trainer.py index 34cd2f9b..05efa138 100644 --- a/modelscope/trainers/nlp/space/trainer/gen_trainer.py +++ b/modelscope/trainers/nlp/space/trainer/gen_trainer.py @@ -1,9 +1,6 @@ -""" -Trainer class. -""" -import logging +# Copyright (c) Alibaba, Inc. and its affiliates. + import os -import sys import time from collections import OrderedDict @@ -61,7 +58,7 @@ class Trainer(object): self.evaluator = evaluator self.tokenizer = reader.tokenizer - self.logger = get_logger() + self.logger = logger or get_logger() self.batch_metrics_tracker = MetricsTracker() self.token_metrics_tracker = MetricsTracker() diff --git a/modelscope/trainers/nlp/space/trainer/intent_trainer.py b/modelscope/trainers/nlp/space/trainer/intent_trainer.py index 1e6f4a2d..dc6b317b 100644 --- a/modelscope/trainers/nlp/space/trainer/intent_trainer.py +++ b/modelscope/trainers/nlp/space/trainer/intent_trainer.py @@ -1,10 +1,6 @@ -""" -Trainer class. -""" +# Copyright (c) Alibaba, Inc. and its affiliates. 
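With the evaluate change above, DialogIntentTrainer assumes checkpoint_path uses '/' separators and points at a file whose basename equals ModelFile.TORCH_MODEL_BIN_FILE; the enclosing directory is then handed back to _load_model. A minimal sketch of a call that satisfies the assertion, assuming the constant resolves to 'pytorch_model.bin' and using a placeholder work dir:

    # `trainer` is a DialogIntentTrainer built with model_dir/cfg_name kwargs
    trainer.evaluate(checkpoint_path='./work_dir/pytorch_model.bin')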
-import logging import os -import sys import time from collections import OrderedDict @@ -16,24 +12,8 @@ from transformers.optimization import AdamW, get_linear_schedule_with_warmup from modelscope.trainers.nlp.space.metrics.metrics_tracker import \ MetricsTracker - - -def get_logger(log_path, name='default'): - logger = logging.getLogger(name) - logger.propagate = False - logger.setLevel(logging.DEBUG) - - formatter = logging.Formatter('%(message)s') - - sh = logging.StreamHandler(sys.stdout) - sh.setFormatter(formatter) - logger.addHandler(sh) - - fh = logging.FileHandler(log_path, mode='w') - fh.setFormatter(formatter) - logger.addHandler(fh) - - return logger +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger class Trainer(object): @@ -76,11 +56,7 @@ class Trainer(object): self.lr_scheduler = lr_scheduler self.optimizer = optimizer - # if not os.path.exists(self.save_dir): - # os.makedirs(self.save_dir) - - # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") - self.logger = logger or get_logger('trainer.log', 'trainer') + self.logger = logger or get_logger() self.batch_metrics_tracker_label = MetricsTracker() self.token_metrics_tracker_label = MetricsTracker() @@ -201,9 +177,12 @@ class Trainer(object): # Save current best model if is_best: - best_model_file = os.path.join(self.save_dir, 'best.model') + best_model_file = os.path.join(self.save_dir, + ModelFile.TORCH_MODEL_BIN_FILE) torch.save(self.model.state_dict(), best_model_file) - best_train_file = os.path.join(self.save_dir, 'best.train') + best_train_file = os.path.join( + self.save_dir, + '{}.train'.format(ModelFile.TORCH_MODEL_BIN_FILE)) torch.save(train_state, best_train_file) self.logger.info( f"Saved best model state to '{best_model_file}' with new best valid metric " @@ -215,7 +194,7 @@ class Trainer(object): def _load_model_state(): model_state_dict = torch.load( - f'{self.func_model.init_checkpoint}.model', + f'{self.func_model.init_checkpoint}', map_location=lambda storage, loc: storage) if 'module.' in list(model_state_dict.keys())[0]: @@ -303,8 +282,13 @@ class Trainer(object): self.logger.info('Loaded no model !!!') return - _load_model_state() - _load_train_state() + if self.do_train: + _load_model_state() + return + + if self.do_infer: + _load_model_state() + _load_train_state() class IntentTrainer(Trainer): @@ -719,104 +703,3 @@ class IntentTrainer(Trainer): assert 'loss' in metrics return metrics['loss'], metrics - - def load(self): - """ load """ - - def _load_model_state(): - model_state_dict = torch.load( - f'{self.func_model.init_checkpoint}', - map_location=lambda storage, loc: storage) - - if 'module.' in list(model_state_dict.keys())[0]: - new_model_state_dict = OrderedDict() - for k, v in model_state_dict.items(): - assert k[:7] == 'module.' 
- new_model_state_dict[k[7:]] = v - model_state_dict = new_model_state_dict - - new_model_state_dict = OrderedDict() - parameters = { - name: param - for name, param in self.func_model.named_parameters() - } - for name, param in model_state_dict.items(): - if name in parameters: - if param.shape != parameters[name].shape: - assert hasattr(param, 'numpy') - arr = param.numpy() - z = np.random.normal( - scale=self.func_model.initializer_range, - size=parameters[name].shape).astype('float32') - if name == 'embedder.token_embedding.weight': - z[-param.shape[0]:] = arr - print( - f'part of parameter({name}) random normlize initialize' - ) - else: - if z.shape[0] < param.shape[0]: - z = arr[:z.shape[0]] - print(f'part of parameter({name}) are dropped') - else: - z[:param.shape[0]] = arr - print( - f'part of parameter({name}) random normlize initialize' - ) - dtype, device = param.dtype, param.device - z = torch.tensor(z, dtype=dtype, device=device) - new_model_state_dict[name] = z - else: - new_model_state_dict[name] = param - else: - print(f'parameter({name}) are dropped') - model_state_dict = new_model_state_dict - - for name in parameters: - if name not in model_state_dict: - if parameters[name].requires_grad: - print(f'parameter({name}) random normlize initialize') - z = np.random.normal( - scale=self.func_model.initializer_range, - size=parameters[name].shape).astype('float32') - dtype, device = parameters[name].dtype, parameters[ - name].device - model_state_dict[name] = torch.tensor( - z, dtype=dtype, device=device) - else: - model_state_dict[name] = parameters[name] - - self.func_model.load_state_dict(model_state_dict) - self.logger.info( - f"Loaded model state from '{self.func_model.init_checkpoint}.model'" - ) - - def _load_train_state(): - train_file = f'{self.func_model.init_checkpoint}.train' - if os.path.exists(train_file): - train_state_dict = torch.load( - train_file, map_location=lambda storage, loc: storage) - self.epoch = train_state_dict['epoch'] - self.best_valid_metric = train_state_dict['best_valid_metric'] - if self.optimizer is not None and 'optimizer' in train_state_dict: - self.optimizer.load_state_dict( - train_state_dict['optimizer']) - if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: - self.lr_scheduler.load_state_dict( - train_state_dict['lr_scheduler']) - self.logger.info( - f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " - f'best_valid_metric={self.best_valid_metric:.3f})') - else: - self.logger.info('Loaded no train state') - - if self.func_model.init_checkpoint is None: - self.logger.info('Loaded no model !!!') - return - - if self.do_train: - _load_model_state() - return - - if self.do_infer: - _load_model_state() - _load_train_state() diff --git a/modelscope/trainers/nlp/passage_ranking_trainer.py b/modelscope/trainers/nlp/text_ranking_trainer.py similarity index 92% rename from modelscope/trainers/nlp/passage_ranking_trainer.py rename to modelscope/trainers/nlp/text_ranking_trainer.py index 711fd0c4..610c36b5 100644 --- a/modelscope/trainers/nlp/passage_ranking_trainer.py +++ b/modelscope/trainers/nlp/text_ranking_trainer.py @@ -8,12 +8,13 @@ import numpy as np import torch from torch import nn from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm from modelscope.metainfo import Trainers from modelscope.models.base import Model, TorchModel +from modelscope.models.nlp import BertForTextRanking from modelscope.msdatasets.ms_dataset import MsDataset from modelscope.preprocessors.base import 
Preprocessor -from modelscope.trainers.base import BaseTrainer from modelscope.trainers.builder import TRAINERS from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer from modelscope.utils.constant import DEFAULT_MODEL_REVISION @@ -42,8 +43,8 @@ class GroupCollator(): return batch -@TRAINERS.register_module(module_name=Trainers.nlp_passage_ranking_trainer) -class PassageRankingTrainer(NlpEpochBasedTrainer): +@TRAINERS.register_module(module_name=Trainers.nlp_text_ranking_trainer) +class TextRankingTrainer(NlpEpochBasedTrainer): def __init__( self, @@ -117,7 +118,6 @@ class PassageRankingTrainer(NlpEpochBasedTrainer): Example: {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} """ - from modelscope.models.nlp import PassageRanking # get the raw online dataset self.eval_dataloader = self._build_dataloader_with_dataset( self.eval_dataset, @@ -126,7 +126,7 @@ class PassageRankingTrainer(NlpEpochBasedTrainer): # generate a standard dataloader # generate a model if checkpoint_path is not None: - model = PassageRanking.from_pretrained(checkpoint_path) + model = BertForTextRanking.from_pretrained(checkpoint_path) else: model = self.model @@ -141,7 +141,7 @@ class PassageRankingTrainer(NlpEpochBasedTrainer): total_spent_time = 0.0 device = 'cuda:0' if torch.cuda.is_available() else 'cpu' model.to(device) - for _step, batch in enumerate(self.eval_dataloader): + for _step, batch in enumerate(tqdm(self.eval_dataloader)): try: batch = { key: @@ -155,13 +155,16 @@ class PassageRankingTrainer(NlpEpochBasedTrainer): with torch.no_grad(): label_ids = batch.pop('labels').detach().cpu().numpy() qids = batch.pop('qid').detach().cpu().numpy() - outputs = model(batch) + outputs = model(**batch) infer_end_time = time.time() total_spent_time += infer_end_time - infer_start_time total_samples += self.eval_dataloader.batch_size - assert 'scores' in outputs - logits = outputs['scores'] + def sigmoid(logits): + return np.exp(logits) / (1 + np.exp(logits)) + + logits = outputs['logits'].squeeze(-1).detach().cpu().numpy() + logits = sigmoid(logits).tolist() label_list.extend(label_ids) logits_list.extend(logits) diff --git a/modelscope/trainers/nlp_trainer.py b/modelscope/trainers/nlp_trainer.py index b54aa666..a19e7c7b 100644 --- a/modelscope/trainers/nlp_trainer.py +++ b/modelscope/trainers/nlp_trainer.py @@ -1,7 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os -from typing import Callable, Optional, Tuple, Union +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -13,15 +15,416 @@ from modelscope.metainfo import Trainers from modelscope.metrics.builder import build_metric from modelscope.models.base import Model, TorchModel from modelscope.msdatasets import MsDataset -from modelscope.preprocessors import Preprocessor, build_preprocessor -from modelscope.utils.config import Config +from modelscope.preprocessors import Preprocessor +from modelscope.utils.config import Config, ConfigDict from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModeKeys, - ModelFile, Tasks) + ModelFile) from modelscope.utils.hub import parse_label_mapping from .base import TRAINERS from .trainer import EpochBasedTrainer +@dataclass +class NlpTrainerArguments: + """The arguments for the nlp trainer. + + All the arguments listed here have None default values, which means follow the default value in the input + cfg dict. 
+ """ + + work_dir: Optional[str] = field( + default=None, metadata={'help': 'The work dir(key: train.work_dir)'}) + + task: Optional[str] = field( + default=None, metadata={'help': 'The task type(key: task)'}) + + preprocessor_type: Optional[str] = field( + default=None, + metadata={'help': 'The preprocessor type(key: preprocessor.type)'}) + + train_first_sequence: str = field( + default=None, + metadata={ + 'help': + 'The key of first sentence for the training dataset(key:preprocessor.train.' + 'first_sequence/dataset.train.first_sequence)' + }) + + train_second_sequence: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The key of second sentence for the training dataset(key:preprocessor.train.' + 'second_sequence/dataset.train.second_sequence)' + }) + + train_label: str = field( + default=None, + metadata={ + 'help': + 'The key of label for the training dataset(key:preprocessor.train.' + 'second_sequence/dataset.train.second_sequence)' + }) + + eval_first_sequence: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The key of first sentence for the eval dataset(key:preprocessor.val.' + 'first_sequence/dataset.val.first_sequence), ' + 'if not provided, the trainer will use the train_first_sequence for evaluation' + }) + + eval_second_sequence: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The key of second sentence for the eval dataset(key:preprocessor.val.' + 'second_sequence/dataset.val.second_sequence),' + 'if not provided, the trainer will use the train_second_sequence for evaluation' + }) + + eval_label: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The key of label for the eval dataset(key:preprocessor.val.' + 'second_sequence/dataset.val.second_sequence),' + 'if not provided, the trainer will use the train_label for evaluation' + }) + + labels: Optional[List] = field( + default=None, + metadata={ + 'help': + 'The labels list of the dataset(key:dataset.train.labels),' + 'This parameter has the same effect with "label2id"' + }) + + max_epochs: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The max_epochs of the training loop(key: train.max_epochs)' + }) + + train_batch_size_per_gpu: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The train batch size per gpu(key: train.dataloader.batch_size_per_gpu)' + }) + + train_workers_per_gpu: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The number of workers per gpu(key: train.dataloader.workers_per_gpu)' + }) + + train_shuffle: Optional[bool] = field( + default=None, + metadata={ + 'help': + 'Shuffle the train dataset or not(key: train.dataloader.shuffle)' + }) + + eval_batch_size_per_gpu: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The eval batch size per gpu(key: evaluation.dataloader.batch_size_per_gpu)' + }) + + eval_workers_per_gpu: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The number of workers per gpu(key: evaluation.dataloader.workers_per_gpu)' + }) + + eval_shuffle: Optional[bool] = field( + default=None, + metadata={ + 'help': + 'Shuffle the eval dataset or not(key: evaluation.dataloader.shuffle)' + }) + + optimizer_args: Optional[Dict] = field( + default=None, + metadata={'help': 'The optimizer config dict(key: train.optimizer)'}) + + lr_scheduler_args: Optional[Dict] = field( + default=None, + metadata={ + 'help': 'The lr_scheduler config dict(key: train.lr_scheduler)' + }) + + checkpoint_saving_type: Optional[str] = field( + default=None, + metadata={ + 'help': + 
'The checkpoint saving type(key: The ckpt hook dict in train.hooks), ' + 'valid options: "BestCkptSaverHook", "CheckpointHook"' + }) + + checkpoint_by_epoch: Optional[bool] = field( + default=None, + metadata={ + 'help': + 'Saving checkpoint by epoch or not(key: The by_epoch key in ' + 'ckpt hook dict in train.hooks)' + }) + + checkpoint_interval: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The checkpoint saving interval(key: The interval key in ' + 'ckpt hook dict in train.hooks)' + }) + + metric_key: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The metric key for the BestCkptSaverHook(key: The metric_key key in ' + 'ckpt hook dict in train.hooks), if the checkpoint_saving_type is "CheckpointHook" or ' + '"None", the metric_key key has no effects' + }) + + evaluation_type: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The evaluation type(key: The evaluation hook dict in train.hooks), ' + 'valid options: "EvaluationHook", "None"' + }) + + evaluation_by_epoch: Optional[bool] = field( + default=None, + metadata={ + 'help': + 'Evaluating by epoch or not(key: The by_epoch key in ' + 'evaluation hook dict in train.hooks)' + }) + + evaluation_interval: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The evaluating interval(key: The interval key in ' + 'evaluation hook dict in train.hooks)' + }) + + metrics: Optional[List[str]] = field( + default=None, + metadata={'help': 'The metrics class keys(key: evaluation.metrics)'}) + + default_train_config = ConfigDict({ + 'work_dir': + '/tmp', + 'max_epochs': + 5, + 'dataloader': { + 'batch_size_per_gpu': 32, + 'workers_per_gpu': 0 + }, + 'optimizer': { + 'type': 'AdamW', + 'lr': 2e-5, + 'options': {} + }, + 'lr_scheduler': { + 'type': 'LinearLR', + 'start_factor': 1.0, + 'end_factor': 0.0, + 'total_iters': 10000, + 'options': { + 'by_epoch': False + } + }, + 'hooks': [{ + 'type': 'CheckpointHook', + 'by_epoch': False, + 'interval': 100 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 100 + }] + }) + + def __call__(self, cfg): + """ + + Args: + cfg(`Config`): The cfg to be modified. + + Returns: + The cfg after modification. 
+ """ + + if self.task is not None: + cfg.task = self.task + + if self.preprocessor_type is not None: + if not hasattr(cfg, 'preprocessor'): + cfg.preprocessor = ConfigDict() + cfg.preprocessor.type = self.preprocessor_type + + if self.train_first_sequence is not None or self.train_second_sequence \ + is not None or self.train_label is not None or self.labels is not None: + if not hasattr(cfg, 'dataset'): + cfg.dataset = ConfigDict() + if not hasattr(cfg.dataset, 'train'): + cfg.dataset.train = ConfigDict() + if self.train_first_sequence is not None: + cfg.dataset.train.first_sequence = self.train_first_sequence + if self.train_second_sequence is not None: + cfg.dataset.train.second_sequence = self.train_second_sequence + if self.train_label is not None: + cfg.dataset.train.label = self.train_label + if self.labels is not None: + cfg.dataset.train.labels = self.labels + + if self.eval_first_sequence is not None or self.eval_second_sequence \ + is not None or self.eval_label is not None: + if not hasattr(cfg, 'dataset'): + cfg.dataset = ConfigDict() + if not hasattr(cfg.dataset, 'val'): + cfg.dataset.val = ConfigDict() + if self.eval_first_sequence is not None: + cfg.dataset.val.first_sequence = self.eval_first_sequence + if self.eval_second_sequence is not None: + cfg.dataset.val.second_sequence = self.eval_second_sequence + if self.eval_label is not None: + cfg.dataset.val.label = self.eval_label + + if self.max_epochs is not None or self.train_batch_size_per_gpu is not None \ + or self.train_shuffle is not None or self.optimizer_args is not None \ + or self.work_dir is not None or self.lr_scheduler_args is not None\ + or self.train_workers_per_gpu is not None: + if not hasattr(cfg, 'train'): + cfg.train = deepcopy(self.default_train_config) + if not hasattr(cfg.train, 'dataloader'): + cfg.train.dataloader = deepcopy( + self.default_train_config.dataloader) + if not hasattr(cfg.train, 'optimizer'): + cfg.train.optimizer = deepcopy( + self.default_train_config.optimizer) + if not hasattr(cfg.train, 'lr_scheduler'): + cfg.train.lr_scheduler = deepcopy( + self.default_train_config.lr_scheduler) + if self.work_dir is not None: + cfg.train.work_dir = self.work_dir + if self.max_epochs is not None: + cfg.train.max_epochs = self.max_epochs + if self.train_batch_size_per_gpu is not None: + cfg.train.dataloader.batch_size_per_gpu = self.train_batch_size_per_gpu + if self.train_workers_per_gpu is not None: + cfg.train.dataloader.workers_per_gpu = self.train_workers_per_gpu + if self.train_shuffle is not None: + cfg.train.dataloader.shuffle = self.train_shuffle + if self.optimizer_args is not None: + if cfg.train.optimizer.type != self.optimizer_args.get( + 'type', cfg.train.optimizer.type): + cfg.train.optimizer = ConfigDict( + deepcopy(self.optimizer_args)) + else: + cfg.train.optimizer = Config._merge_a_into_b( + self.optimizer_args, cfg.train.optimizer, force=True) + if self.lr_scheduler_args is not None: + if cfg.train.lr_scheduler.type != self.lr_scheduler_args.get( + 'type', cfg.train.lr_scheduler.type): + cfg.train.lr_scheduler = ConfigDict( + deepcopy(self.lr_scheduler_args)) + else: + cfg.train.lr_scheduler = Config._merge_a_into_b( + self.lr_scheduler_args, + cfg.train.lr_scheduler, + force=True) + + if self.checkpoint_saving_type is not None or self.checkpoint_by_epoch is not None \ + or self.checkpoint_interval is not None or self.metric_key is not None: + if not any([ + self.checkpoint_saving_type == hook['type'] + for hook in cfg.train.hooks + ]): + cfg.train.hooks = list( + filter( + 
lambda hook: hook['type'] not in + ['CheckpointHook', 'BestCkptSaverHook'], + cfg.train.hooks)) + cfg.train.hooks.append( + deepcopy(self.default_train_config.hooks[0])) + cfg.train.hooks[-1].type = self.checkpoint_saving_type + checkpoint_hook = list( + filter( + lambda hook: hook[ + 'type'] in ['CheckpointHook', 'BestCkptSaverHook'], + cfg.train.hooks))[0] + if self.checkpoint_by_epoch is not None: + checkpoint_hook['by_epoch'] = self.checkpoint_by_epoch + if self.checkpoint_interval is not None: + checkpoint_hook['interval'] = self.checkpoint_interval + if checkpoint_hook['type'] == 'BestCkptSaverHook': + assert self.metric_key is not None, 'The metric_key must be provided ' \ + 'if the ckpt saving hook is "BestCkptSaverHook"' + checkpoint_hook['metric_key'] = self.metric_key + + if self.evaluation_type is not None or self.evaluation_by_epoch is not None \ + or self.evaluation_interval is not None or self.eval_batch_size_per_gpu is not None or \ + self.eval_shuffle is not None or self.metrics is not None: + if self.evaluation_type is not None and not any([ + self.evaluation_type == hook['type'] + for hook in cfg.train.hooks + ]): + cfg.train.hooks = list( + filter(lambda hook: hook['type'] not in ['EvaluationHook'], + cfg.train.hooks)) + if self.evaluation_type != 'None': + cfg.train.hooks.append( + deepcopy(self.default_train_config.hooks[3])) + cfg.train.hooks[-1].type = self.evaluation_type + + evaluation_hook = list( + filter(lambda hook: hook['type'] in ['EvaluationHook'], + cfg.train.hooks)) + evaluation_hook = evaluation_hook[0] if len( + evaluation_hook) > 0 else None + + if evaluation_hook is not None and self.evaluation_by_epoch is not None: + evaluation_hook['by_epoch'] = self.evaluation_by_epoch + if evaluation_hook is not None and self.evaluation_interval is not None: + evaluation_hook['interval'] = self.evaluation_interval + + if not hasattr(cfg, 'evaluation'): + cfg.evaluation = ConfigDict({ + 'dataloader': { + 'batch_size_per_gpu': 32, + 'workers_per_gpu': 0, + 'shuffle': False + } + }) + + if self.metrics is not None: + cfg.evaluation.metrics = self.metrics + if self.eval_batch_size_per_gpu is not None: + cfg.evaluation.dataloader.batch_size_per_gpu = self.eval_batch_size_per_gpu + if self.eval_workers_per_gpu is not None: + cfg.evaluation.dataloader.workers_per_gpu = self.eval_workers_per_gpu + if self.eval_shuffle is not None: + cfg.evaluation.dataloader.shuffle = self.eval_shuffle + + return cfg + + @TRAINERS.register_module(module_name=Trainers.nlp_base_trainer) class NlpEpochBasedTrainer(EpochBasedTrainer): @@ -80,9 +483,10 @@ class NlpEpochBasedTrainer(EpochBasedTrainer): model) else: model_dir = snapshot_download(model, revision=model_revision) - cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) + if cfg_file is None: + cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) else: - assert cfg_file is not None, 'Config file should not be None if model is an nn.Module class' + assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' 
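NlpTrainerArguments is a callable that patches an existing Config, so it can be applied directly or handed to NlpEpochBasedTrainer as cfg_modify_fn. A minimal sketch with a placeholder config path and illustrative values:

    from modelscope.utils.config import Config
    from modelscope.trainers.nlp_trainer import NlpTrainerArguments

    cfg = Config.from_file('<path/to/configuration.json>')  # placeholder path
    args = NlpTrainerArguments(
        work_dir='./nlp_work_dir',
        max_epochs=3,
        train_batch_size_per_gpu=16,
        train_shuffle=True)
    cfg = args(cfg)  # returns the modified cfg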
model_dir = os.path.dirname(cfg_file) self.label2id = None @@ -91,26 +495,17 @@ class NlpEpochBasedTrainer(EpochBasedTrainer): self.cfg_modify_fn = cfg_modify_fn self.cfg = self.rebuild_config(Config.from_file(cfg_file)) - label2id = parse_label_mapping(model_dir) - if label2id is not None: - self.label2id = label2id - self.id2label = {id: label for label, id in label2id.items()} - self.num_labels = len(label2id) - else: - try: - labels = self.cfg.dataset.train.labels - if labels is not None and len(labels) > 0: - self.label2id = { - label: idx - for idx, label in enumerate(labels) - } - self.id2label = { - idx: label - for idx, label in enumerate(labels) - } - self.num_labels = len(labels) - except AttributeError: - pass + try: + labels = self.cfg.dataset.train.labels + self.label2id = {label: idx for idx, label in enumerate(labels)} + self.id2label = {idx: label for idx, label in enumerate(labels)} + self.num_labels = len(labels) + except AttributeError: + label2id = parse_label_mapping(model_dir) + if label2id is not None: + self.label2id = label2id + self.id2label = {id: label for label, id in label2id.items()} + self.num_labels = len(label2id) def build_dataset_keys(cfg): if cfg is not None: @@ -185,36 +580,20 @@ class NlpEpochBasedTrainer(EpochBasedTrainer): 'label2id': self.label2id } - field_name = Tasks.find_field_by_task(self.cfg.task) - train_preprocessor, eval_preprocessor = None, None - _train_cfg, _eval_cfg = {}, {} - - if 'type' not in self.cfg.preprocessor and ( - 'train' in self.cfg.preprocessor - or 'val' in self.cfg.preprocessor): - if 'train' in self.cfg.preprocessor: - _train_cfg = self.cfg.preprocessor.train - if 'val' in self.cfg.preprocessor: - _eval_cfg = self.cfg.preprocessor.val - else: - _train_cfg = self.cfg.preprocessor - _eval_cfg = self.cfg.preprocessor - - if len(_train_cfg): - _train_cfg.update({ - 'model_dir': self.model_dir, - **model_args, - **self.train_keys, 'mode': ModeKeys.TRAIN - }) - train_preprocessor = build_preprocessor(_train_cfg, field_name) - if len(_eval_cfg): - _eval_cfg.update({ - 'model_dir': self.model_dir, - **model_args, - **self.eval_keys, 'mode': ModeKeys.EVAL - }) - eval_preprocessor = build_preprocessor(_eval_cfg, field_name) - + train_preprocessor = Preprocessor.from_pretrained( + self.model_dir, + cfg_dict=self.cfg, + preprocessor_mode=ModeKeys.TRAIN, + **model_args, + **self.train_keys, + mode=ModeKeys.TRAIN) + eval_preprocessor = Preprocessor.from_pretrained( + self.model_dir, + cfg_dict=self.cfg, + preprocessor_mode=ModeKeys.EVAL, + **model_args, + **self.eval_keys, + mode=ModeKeys.EVAL) return train_preprocessor, eval_preprocessor diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index a01d9b59..605136e5 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -4,7 +4,7 @@ import time from collections.abc import Mapping from distutils.version import LooseVersion from functools import partial -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import json import torch @@ -22,18 +22,18 @@ from modelscope.msdatasets.ms_dataset import MsDataset from modelscope.msdatasets.task_datasets.builder import build_task_dataset from modelscope.msdatasets.task_datasets.torch_base_dataset import \ TorchTaskDataset +from modelscope.outputs import ModelOutputBase from modelscope.preprocessors.base import Preprocessor -from modelscope.preprocessors.builder import build_preprocessor from 
modelscope.trainers.hooks.builder import HOOKS from modelscope.trainers.hooks.priority import Priority, get_priority from modelscope.trainers.lrscheduler.builder import build_lr_scheduler from modelscope.trainers.optimizer.builder import build_optimizer from modelscope.utils.config import Config, ConfigDict from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, - ConfigKeys, Hubs, ModeKeys, ModelFile, - Tasks, TrainerStages) + ConfigKeys, ModeKeys, ModelFile, + TrainerStages) from modelscope.utils.data_utils import to_device -from modelscope.utils.device import create_device, verify_device +from modelscope.utils.device import create_device from modelscope.utils.file_utils import func_receive_dict_inputs from modelscope.utils.logger import get_logger from modelscope.utils.registry import build_from_cfg @@ -146,7 +146,8 @@ class EpochBasedTrainer(BaseTrainer): if ConfigKeys.val in preprocessor: assert isinstance(preprocessor[ConfigKeys.val], Preprocessor) self.eval_preprocessor = preprocessor[ConfigKeys.val] - elif hasattr(self.cfg, ConfigFields.preprocessor): + elif hasattr(self.cfg, ConfigFields.preprocessor + ) and self.cfg.preprocessor is not None: self.train_preprocessor, self.eval_preprocessor = self.build_preprocessor( ) @@ -167,19 +168,20 @@ class EpochBasedTrainer(BaseTrainer): device_name = f'cuda:{local_rank}' self.device = create_device(device_name) - self.train_dataset = self.to_task_dataset( train_dataset, mode=ModeKeys.TRAIN, task_data_config=self.cfg.dataset.get('train', None) if hasattr( self.cfg, 'dataset') else None, - preprocessor=self.train_preprocessor) + preprocessor=self.train_preprocessor, + **kwargs) self.eval_dataset = self.to_task_dataset( eval_dataset, mode=ModeKeys.EVAL, task_data_config=self.cfg.dataset.get('val', None) if hasattr( self.cfg, 'dataset') else None, - preprocessor=self.eval_preprocessor) + preprocessor=self.eval_preprocessor, + **kwargs) self.train_data_collator, self.eval_default_collate = None, None if isinstance(data_collator, Mapping): @@ -215,7 +217,6 @@ class EpochBasedTrainer(BaseTrainer): self._max_epochs = self.cfg.train.max_epochs else: self._max_epochs = kwargs['max_epochs'] - self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None) self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None) if self._train_iters_per_epoch is None and hasattr( @@ -305,13 +306,15 @@ class EpochBasedTrainer(BaseTrainer): datasets: Union[Dataset, List[Dataset]], mode: str, task_data_config: Config = None, - preprocessor: Optional[Preprocessor] = None): + preprocessor: Optional[Preprocessor] = None, + **kwargs): """Build the task specific dataset processor for this trainer. Returns: The task dataset processor for the task. If no result for the very model-type and task, the default TaskDataset will be returned. 
""" try: + to_tensor = kwargs.get('to_tensor', True) if not datasets: return datasets if isinstance(datasets, TorchTaskDataset): @@ -327,7 +330,8 @@ class EpochBasedTrainer(BaseTrainer): return datasets.to_torch_dataset( task_data_config=task_data_config, task_name=self.cfg.task, - preprocessors=preprocessor) + preprocessors=preprocessor, + to_tensor=to_tensor) elif isinstance(datasets, List) and isinstance( datasets[0], MsDataset): if task_data_config is None: @@ -341,26 +345,39 @@ class EpochBasedTrainer(BaseTrainer): d.to_torch_dataset( task_data_config=task_data_config, task_name=self.cfg.task, - preprocessors=preprocessor) for d in datasets + preprocessors=preprocessor, + to_tensor=to_tensor) for d in datasets ] cfg = ConfigDict( - type=self.cfg.task, mode=mode, datasets=datasets) - return build_task_dataset(cfg, self.cfg.task) + type=self.cfg.model.type, mode=mode, datasets=datasets) + task_dataset = build_task_dataset(cfg, self.cfg.task) + task_dataset.trainer = self + return task_dataset else: + if task_data_config is None: + # adapt to some special models + task_data_config = {} # avoid add no str value datasets, preprocessors in cfg task_data_build_config = ConfigDict( - mode=mode, datasets=datasets, preprocessor=preprocessor) + type=self.cfg.model.type, + mode=mode, + datasets=datasets, + preprocessor=preprocessor) task_data_build_config.update(task_data_config) - return build_task_dataset(task_data_build_config, - self.cfg.task) + task_dataset = build_task_dataset(task_data_build_config, + self.cfg.task) + task_dataset.trainer = self + return task_dataset except Exception: if isinstance(datasets, (List, Tuple)) or preprocessor is not None: - return TorchTaskDataset( + task_dataset = TorchTaskDataset( datasets, mode=mode, preprocessor=preprocessor, **(dict(type=self.cfg.model.type) if hasattr( self.cfg, 'model') else {})) + task_dataset.trainer = self + return task_dataset else: return datasets @@ -372,35 +389,12 @@ class EpochBasedTrainer(BaseTrainer): Returns: The train preprocessor and eval preprocessor instance. 
""" - field_name = Tasks.find_field_by_task(self.cfg.task) - train_preprocessor, eval_preprocessor = None, None - _train_cfg, _eval_cfg = {}, {} - _dafault_args = {'model_dir': self.model_dir} - - if 'type' not in self.cfg.preprocessor and ( - 'train' in self.cfg.preprocessor - or 'val' in self.cfg.preprocessor): - if 'train' in self.cfg.preprocessor: - _train_cfg = self.cfg.preprocessor.train - if 'val' in self.cfg.preprocessor: - _eval_cfg = self.cfg.preprocessor.val - else: - _train_cfg = self.cfg.preprocessor - _eval_cfg = self.cfg.preprocessor - - if len(_train_cfg): - if isinstance(_train_cfg, Sequence): - # TODO: for Sequence, need adapt to `mode` and `mode_dir` args, - # and add mode for Compose or other plans - raise NotImplementedError('Not supported yet!') - _train_cfg.update(_dafault_args) - train_preprocessor = build_preprocessor(_train_cfg, field_name) - if len(_eval_cfg): - if isinstance(_eval_cfg, Sequence): - raise NotImplementedError('Not supported yet!') - _eval_cfg.update(_dafault_args) - eval_preprocessor = build_preprocessor(_eval_cfg, field_name) - + train_preprocessor = Preprocessor.from_pretrained( + self.model_dir, + cfg_dict=self.cfg, + preprocessor_mode=ModeKeys.TRAIN) + eval_preprocessor = Preprocessor.from_pretrained( + self.model_dir, cfg_dict=self.cfg, preprocessor_mode=ModeKeys.EVAL) return train_preprocessor, eval_preprocessor def get_metrics(self) -> List[Union[str, Dict]]: @@ -428,13 +422,17 @@ class EpochBasedTrainer(BaseTrainer): return metrics def set_checkpoint_file_to_hook(self, checkpoint_path): - if checkpoint_path is not None and os.path.isfile(checkpoint_path): - from modelscope.trainers.hooks import CheckpointHook - checkpoint_hooks = list( - filter(lambda hook: isinstance(hook, CheckpointHook), - self.hooks)) - for hook in checkpoint_hooks: - hook.checkpoint_file = checkpoint_path + if checkpoint_path is not None: + if os.path.isfile(checkpoint_path): + from modelscope.trainers.hooks import CheckpointHook + checkpoint_hooks = list( + filter(lambda hook: isinstance(hook, CheckpointHook), + self.hooks)) + for hook in checkpoint_hooks: + hook.checkpoint_file = checkpoint_path + else: + self.logger.error( + f'No {checkpoint_path} found in local file system.') def train(self, checkpoint_path=None, *args, **kwargs): self._mode = ModeKeys.TRAIN @@ -510,6 +508,7 @@ class EpochBasedTrainer(BaseTrainer): dp_cfg = dict( type='DistributedDataParallel', module=model, + find_unused_parameters=True, device_ids=[torch.cuda.current_device()]) return build_parallel(dp_cfg) @@ -547,6 +546,8 @@ class EpochBasedTrainer(BaseTrainer): else: train_outputs = model.forward(inputs) + if isinstance(train_outputs, ModelOutputBase): + train_outputs = train_outputs.to_dict() if not isinstance(train_outputs, dict): raise TypeError('"model.forward()" must return a dict') @@ -650,8 +651,9 @@ class EpochBasedTrainer(BaseTrainer): """ # TODO: support MsDataset load for cv if hasattr(data_cfg, 'name'): + dataset_name = data_cfg.pop('name') dataset = MsDataset.load( - dataset_name=data_cfg.name, + dataset_name=dataset_name, **data_cfg, ) cfg = ConfigDict(type=self.cfg.model.type, mode=mode) @@ -664,6 +666,12 @@ class EpochBasedTrainer(BaseTrainer): dataset = self.to_task_dataset(torch_dataset, mode) return dataset + def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): + return build_optimizer(self.model, cfg=cfg, default_args=default_args) + + def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None): + return build_lr_scheduler(cfg=cfg, 
default_args=default_args) + def create_optimizer_and_scheduler(self): """ Create optimizer and lr scheduler @@ -680,7 +688,7 @@ class EpochBasedTrainer(BaseTrainer): optim_options = {} if optimizer_cfg is not None: optim_options = optimizer_cfg.pop('options', {}) - optimizer = build_optimizer(self.model, cfg=optimizer_cfg) + optimizer = self.build_optimizer(cfg=optimizer_cfg) if lr_scheduler is None: lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None) @@ -691,7 +699,7 @@ class EpochBasedTrainer(BaseTrainer): if lr_scheduler_cfg is not None: assert optimizer is not None lr_options = lr_scheduler_cfg.pop('options', {}) - lr_scheduler = build_lr_scheduler( + lr_scheduler = self.build_lr_scheduler( cfg=lr_scheduler_cfg, default_args={'optimizer': optimizer}) self.optimizer = optimizer @@ -783,7 +791,7 @@ class EpochBasedTrainer(BaseTrainer): batch_size = batch_size_per_gpu num_workers = workers_per_gpu - if dist: + if dist and not isinstance(dataset, torch.utils.data.IterableDataset): sampler = DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) else: @@ -822,7 +830,6 @@ class EpochBasedTrainer(BaseTrainer): self.model.train() for _ in range(self._epoch, self._max_epochs): self.invoke_hook(TrainerStages.before_train_epoch) - time.sleep(2) # Prevent possible deadlock during epoch transition for i, data_batch in enumerate(data_loader): if i < self.inner_iter: # inner_iter may be read out from the checkpoint file, so skip the trained iters in the epoch. @@ -846,7 +853,6 @@ class EpochBasedTrainer(BaseTrainer): self._inner_iter = 0 self._epoch += 1 - time.sleep(1) # wait for some hooks like loggers to finish self.invoke_hook(TrainerStages.after_run) def evaluation_loop(self, data_loader, metric_classes): diff --git a/modelscope/trainers/utils/inference.py b/modelscope/trainers/utils/inference.py index 7f5d4ec3..1f8f8ed0 100644 --- a/modelscope/trainers/utils/inference.py +++ b/modelscope/trainers/utils/inference.py @@ -69,7 +69,10 @@ def single_gpu_test(model, batch_size = 1 # iteration count else: if isinstance(data, dict): - batch_size = len(next(iter(data.values()))) + if 'nsentences' in data: + batch_size = data['nsentences'] + else: + batch_size = len(next(iter(data.values()))) else: batch_size = len(data) for _ in range(batch_size): @@ -152,21 +155,29 @@ def multi_gpu_test(model, result = model.forward(data) results.append(result) - if rank == 0: - if isinstance(data, dict): - batch_size = len(next(iter(data.values()))) + if isinstance(data, dict): + if 'nsentences' in data: + batch_size = data['nsentences'] else: - batch_size = len(data) - - if progress_with_iters: - total_samples += batch_size * world_size - batch_size = 1 # iteration count + batch_size = len(next(iter(data.values()))) + else: + batch_size = len(data) + if i >= (data_len // world_size) - 1: + total_samples = torch.LongTensor([batch_size]).to(model.device) + dist.all_reduce(total_samples, op=dist.reduce_op.SUM) + total_samples = total_samples.item() + else: + total_samples = batch_size * world_size + if progress_with_iters: + iter_cnt_all = world_size + else: + iter_cnt_all = total_samples + count += iter_cnt_all - batch_size_all = batch_size * world_size - count += batch_size_all + if rank == 0: if count > data_len: - batch_size_all = data_len - (count - batch_size_all) - for _ in range(batch_size_all): + iter_cnt_all = data_len - (count - iter_cnt_all) + for _ in range(iter_cnt_all): pbar.update() if progress_with_iters and (i + 1) >= data_len: diff --git 
a/modelscope/utils/audio/audio_utils.py b/modelscope/utils/audio/audio_utils.py index 4c2c45cc..32e2fa54 100644 --- a/modelscope/utils/audio/audio_utils.py +++ b/modelscope/utils/audio/audio_utils.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import re import struct from typing import Union from urllib.parse import urlparse @@ -37,8 +38,26 @@ def audio_norm(x): return x +def update_conf(origin_config_file, new_config_file, conf_item: [str, str]): + + def repl(matched): + key = matched.group(1) + if key in conf_item: + return conf_item[key] + else: + return None + + with open(origin_config_file) as f: + lines = f.readlines() + with open(new_config_file, 'w') as f: + for line in lines: + line = re.sub(r'\$\{(.*)\}', repl, line) + f.write(line) + + def extract_pcm_from_wav(wav: bytes) -> bytes: data = wav + sample_rate = None if len(data) > 44: frame_len = 44 file_len = len(data) @@ -52,29 +71,33 @@ def extract_pcm_from_wav(wav: bytes) -> bytes: 'Subchunk1ID'] == 'fmt ': header_fields['SubChunk1Size'] = struct.unpack( ' Union[bytes, str]: + sample_rate = None result = urlparse(url) if result.scheme is not None and len(result.scheme) > 0: storage = HTTPStorage() data = storage.read(url) - data = extract_pcm_from_wav(data) + data, sample_rate = extract_pcm_from_wav(data) else: data = url - return data + return data, sample_rate diff --git a/modelscope/utils/checkpoint.py b/modelscope/utils/checkpoint.py index a9d7f396..2a7520f2 100644 --- a/modelscope/utils/checkpoint.py +++ b/modelscope/utils/checkpoint.py @@ -207,6 +207,6 @@ def save_pretrained(model, # Dump the config to the configuration.json if ConfigFields.pipeline not in config: config[ConfigFields.pipeline] = {'type': config[ConfigFields.task]} - cfg_str = json.dumps(config, cls=JSONIteratorEncoder) + cfg_str = json.dumps(config, indent=4, cls=JSONIteratorEncoder) config_file = os.path.join(target_folder, ModelFile.CONFIGURATION) storage.write(cfg_str.encode(), config_file) diff --git a/modelscope/utils/config.py b/modelscope/utils/config.py index 0b966bef..e46da7df 100644 --- a/modelscope/utils/config.py +++ b/modelscope/utils/config.py @@ -609,11 +609,12 @@ class Config: return parse_fn(args) -def check_config(cfg: Union[str, ConfigDict]): +def check_config(cfg: Union[str, ConfigDict], is_training=False): """ Check whether configuration file is valid, If anything wrong, exception will be raised. Args: cfg (str or ConfigDict): Config file path or config object. 
+ is_training: indicate if checking training related elements """ if isinstance(cfg, str): @@ -627,12 +628,23 @@ def check_config(cfg: Union[str, ConfigDict]): check_attr(ConfigFields.task) check_attr(ConfigFields.pipeline) - if hasattr(cfg, ConfigFields.train): + if is_training: check_attr(ConfigFields.model) + check_attr(ConfigFields.train) check_attr(ConfigFields.preprocessor) check_attr(ConfigFields.evaluation) +def use_task_specific_params(model, task): + """Update config with summarization specific params.""" + task_specific_params = model.config.task_specific_params + + if task_specific_params is not None: + pars = task_specific_params.get(task, {}) + logger.info(f'using task specific params for {task}: {pars}') + model.config.update(pars) + + class JSONIteratorEncoder(json.JSONEncoder): """Implement this method in order that supporting arbitrary iterators, it returns a serializable object for ``obj``, or calls the base implementation diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 7968fcd1..6394ad8a 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -9,6 +9,7 @@ class Fields(object): nlp = 'nlp' audio = 'audio' multi_modal = 'multi-modal' + science = 'science' class CVTasks(object): @@ -19,6 +20,7 @@ class CVTasks(object): # human face body related animal_recognition = 'animal-recognition' face_detection = 'face-detection' + card_detection = 'card-detection' face_recognition = 'face-recognition' facial_expression_recognition = 'facial-expression-recognition' face_2d_keypoints = 'face-2d-keypoints' @@ -29,6 +31,7 @@ class CVTasks(object): body_3d_keypoints = 'body-3d-keypoints' hand_2d_keypoints = 'hand-2d-keypoints' general_recognition = 'general-recognition' + human_wholebody_keypoint = 'human-wholebody-keypoint' image_classification = 'image-classification' image_multilabel_classification = 'image-multilabel-classification' @@ -36,6 +39,7 @@ class CVTasks(object): image_classification_dailylife = 'image-classification-dailylife' image_object_detection = 'image-object-detection' + video_object_detection = 'video-object-detection' image_segmentation = 'image-segmentation' semantic_segmentation = 'semantic-segmentation' @@ -47,6 +51,8 @@ class CVTasks(object): face_emotion = 'face-emotion' product_segmentation = 'product-segmentation' + crowd_counting = 'crowd-counting' + # image editing skin_retouching = 'skin-retouching' image_super_resolution = 'image-super-resolution' @@ -54,13 +60,14 @@ class CVTasks(object): image_color_enhancement = 'image-color-enhancement' image_denoising = 'image-denoising' image_portrait_enhancement = 'image-portrait-enhancement' + image_inpainting = 'image-inpainting' # image generation image_to_image_translation = 'image-to-image-translation' image_to_image_generation = 'image-to-image-generation' image_style_transfer = 'image-style-transfer' image_portrait_stylization = 'image-portrait-stylization' - + image_body_reshaping = 'image-body-reshaping' image_embedding = 'image-embedding' product_retrieval_embedding = 'product-retrieval-embedding' @@ -72,9 +79,11 @@ class CVTasks(object): video_category = 'video-category' video_embedding = 'video-embedding' virtual_try_on = 'virtual-try-on' - crowd_counting = 'crowd-counting' movie_scene_segmentation = 'movie-scene-segmentation' + # video segmentation + referring_video_object_segmentation = 'referring-video-object-segmentation' + # video editing video_inpainting = 'video-inpainting' @@ -95,7 +104,7 @@ class NLPTasks(object): sentence_similarity = 
'sentence-similarity' text_classification = 'text-classification' sentence_embedding = 'sentence-embedding' - passage_ranking = 'passage-ranking' + text_ranking = 'text-ranking' relation_extraction = 'relation-extraction' zero_shot = 'zero-shot' translation = 'translation' @@ -107,15 +116,13 @@ class NLPTasks(object): dialog_intent_prediction = 'dialog-intent-prediction' dialog_state_tracking = 'dialog-state-tracking' table_question_answering = 'table-question-answering' - sentence_embedding = 'sentence-embedding' fill_mask = 'fill-mask' - summarization = 'summarization' + text_summarization = 'text-summarization' question_answering = 'question-answering' zero_shot_classification = 'zero-shot-classification' backbone = 'backbone' text_error_correction = 'text-error-correction' faq_question_answering = 'faq-question-answering' - conversational_text_to_sql = 'conversational-text-to-sql' information_extraction = 'information-extraction' document_segmentation = 'document-segmentation' feature_extraction = 'feature-extraction' @@ -145,6 +152,10 @@ class MultiModalTasks(object): image_text_retrieval = 'image-text-retrieval' +class ScienceTasks(object): + protein_structure = 'protein-structure' + + class TasksIODescriptions(object): image_to_image = 'image_to_image', images_to_image = 'images_to_image', @@ -161,7 +172,7 @@ class TasksIODescriptions(object): generative_multi_modal_embedding = 'generative_multi_modal_embedding' -class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): +class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks, ScienceTasks): """ Names for tasks supported by modelscope. Holds the standard task name to use for identifying different tasks. @@ -190,6 +201,10 @@ class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): getattr(Tasks, attr) for attr in dir(MultiModalTasks) if not attr.startswith('__') ], + Fields.science: [ + getattr(Tasks, attr) for attr in dir(ScienceTasks) + if not attr.startswith('__') + ], } for field, tasks in field_dict.items(): @@ -267,12 +282,14 @@ class ConfigFields(object): preprocessor = 'preprocessor' train = 'train' evaluation = 'evaluation' + postprocessor = 'postprocessor' class ConfigKeys(object): """Fixed keywords in configuration file""" train = 'train' val = 'val' + test = 'test' class Requirements(object): @@ -294,7 +311,9 @@ class Frameworks(object): kaldi = 'kaldi' -DEFAULT_MODEL_REVISION = 'master' +DEFAULT_MODEL_REVISION = None +MASTER_MODEL_BRANCH = 'master' +DEFAULT_REPOSITORY_REVISION = 'master' DEFAULT_DATASET_REVISION = 'master' DEFAULT_DATASET_NAMESPACE = 'modelscope' diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py index 98ba533e..34dc2348 100644 --- a/modelscope/utils/cv/image_utils.py +++ b/modelscope/utils/cv/image_utils.py @@ -80,7 +80,7 @@ def realtime_object_detection_bbox_vis(image, bboxes): def draw_keypoints(output, original_image): - poses = np.array(output[OutputKeys.POSES]) + poses = np.array(output[OutputKeys.KEYPOINTS]) scores = np.array(output[OutputKeys.SCORES]) boxes = np.array(output[OutputKeys.BOXES]) assert len(poses) == len(scores) and len(poses) == len(boxes) @@ -113,12 +113,9 @@ def draw_face_detection_no_lm_result(img_path, detection_result): def draw_facial_expression_result(img_path, facial_expression_result): - label_idx = facial_expression_result[OutputKeys.LABELS] - map_list = [ - 'Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral' - ] - label = map_list[label_idx] - + scores = facial_expression_result[OutputKeys.SCORES] + labels = 
facial_expression_result[OutputKeys.LABELS] + label = labels[np.argmax(scores)] img = cv2.imread(img_path) assert img is not None, f"Can't read img: {img_path}" cv2.putText( @@ -157,6 +154,54 @@ def draw_face_detection_result(img_path, detection_result): return img +def draw_card_detection_result(img_path, detection_result): + + def warp_img(src_img, kps, ratio): + short_size = 500 + if ratio > 1: + obj_h = short_size + obj_w = int(obj_h * ratio) + else: + obj_w = short_size + obj_h = int(obj_w / ratio) + input_pts = np.float32([kps[0], kps[1], kps[2], kps[3]]) + output_pts = np.float32([[0, obj_h - 1], [0, 0], [obj_w - 1, 0], + [obj_w - 1, obj_h - 1]]) + M = cv2.getPerspectiveTransform(input_pts, output_pts) + obj_img = cv2.warpPerspective(src_img, M, (obj_w, obj_h)) + return obj_img + + bboxes = np.array(detection_result[OutputKeys.BOXES]) + kpss = np.array(detection_result[OutputKeys.KEYPOINTS]) + scores = np.array(detection_result[OutputKeys.SCORES]) + img_list = [] + ver_col = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255)] + img = cv2.imread(img_path) + img_list += [img] + assert img is not None, f"Can't read img: {img_path}" + for i in range(len(scores)): + bbox = bboxes[i].astype(np.int32) + kps = kpss[i].reshape(-1, 2).astype(np.int32) + _w = (kps[0][0] - kps[3][0])**2 + (kps[0][1] - kps[3][1])**2 + _h = (kps[0][0] - kps[1][0])**2 + (kps[0][1] - kps[1][1])**2 + ratio = 1.59 if _w >= _h else 1 / 1.59 + card_img = warp_img(img, kps, ratio) + img_list += [card_img] + score = scores[i] + x1, y1, x2, y2 = bbox + cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 4) + for k, kp in enumerate(kps): + cv2.circle(img, tuple(kp), 1, color=ver_col[k], thickness=10) + cv2.putText( + img, + f'{score:.2f}', (x1, y2), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + return img_list + + def created_boxed_image(image_in, box): image = load_image(image_in) img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) @@ -186,6 +231,66 @@ def show_video_tracking_result(video_in_path, bboxes, video_save_path): cap.release() +def show_video_object_detection_result(video_in_path, bboxes_list, labels_list, + video_save_path): + + PALETTE = { + 'person': [128, 0, 0], + 'bicycle': [128, 128, 0], + 'car': [64, 0, 0], + 'motorcycle': [0, 128, 128], + 'bus': [64, 128, 0], + 'truck': [192, 128, 0], + 'traffic light': [64, 0, 128], + 'stop sign': [192, 0, 128], + } + from tqdm import tqdm + import math + cap = cv2.VideoCapture(video_in_path) + with tqdm(total=len(bboxes_list)) as pbar: + pbar.set_description( + 'Writing results to video: {}'.format(video_save_path)) + for i in range(len(bboxes_list)): + bboxes = bboxes_list[i].astype(int) + labels = labels_list[i] + success, frame = cap.read() + if success is False: + raise Exception(video_in_path, + ' can not be correctly decoded by OpenCV.') + if i == 0: + size = (frame.shape[1], frame.shape[0]) + fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') + video_writer = cv2.VideoWriter(video_save_path, fourcc, + cap.get(cv2.CAP_PROP_FPS), size, + True) + + FONT_SCALE = 1e-3 # Adjust for larger font size in all images + THICKNESS_SCALE = 1e-3 # Adjust for larger thickness in all images + TEXT_Y_OFFSET_SCALE = 1e-2 # Adjust for larger Y-offset of text and bounding box + H, W, _ = frame.shape + zeros_mask = np.zeros((frame.shape)).astype(np.uint8) + for bbox, l in zip(bboxes, labels): + cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), + PALETTE[l], 1) + cv2.putText( + frame, + l, (bbox[0], bbox[1] - int(TEXT_Y_OFFSET_SCALE * H)), + 
fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=min(H, W) * FONT_SCALE, + thickness=math.ceil(min(H, W) * THICKNESS_SCALE), + color=PALETTE[l]) + zeros_mask = cv2.rectangle( + zeros_mask, (bbox[0], bbox[1]), (bbox[2], bbox[3]), + color=PALETTE[l], + thickness=-1) + + frame = cv2.addWeighted(frame, 1., zeros_mask, .65, 0) + video_writer.write(frame) + pbar.update(1) + video_writer.release + cap.release() + + def panoptic_seg_masks_to_image(masks): draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3]) from mmdet.core.visualization.palette import get_palette @@ -237,3 +342,35 @@ def show_video_summarization_result(video_in_path, result, video_save_path): video_writer.write(frame) video_writer.release() cap.release() + + +def show_image_object_detection_auto_result(img_path, + detection_result, + save_path=None): + scores = detection_result[OutputKeys.SCORES] + labels = detection_result[OutputKeys.LABELS] + bboxes = detection_result[OutputKeys.BOXES] + img = cv2.imread(img_path) + assert img is not None, f"Can't read img: {img_path}" + + for (score, label, box) in zip(scores, labels, bboxes): + cv2.rectangle(img, (int(box[0]), int(box[1])), + (int(box[2]), int(box[3])), (0, 0, 255), 2) + cv2.putText( + img, + f'{score:.2f}', (int(box[0]), int(box[1])), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + cv2.putText( + img, + label, (int((box[0] + box[2]) * 0.5), int(box[1])), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + + if save_path is not None: + cv2.imwrite(save_path, img) + return img diff --git a/modelscope/utils/device.py b/modelscope/utils/device.py index 33c0910d..83faa261 100644 --- a/modelscope/utils/device.py +++ b/modelscope/utils/device.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - +import os from contextlib import contextmanager from modelscope.utils.constant import Devices, Frameworks @@ -61,8 +61,8 @@ def device_placement(framework, device_name='gpu:0'): if framework == Frameworks.tf: import tensorflow as tf if device_type == Devices.gpu and not tf.test.is_gpu_available(): - logger.warning( - 'tensorflow cuda is not available, using cpu instead.') + logger.debug( + 'tensorflow: cuda is not available, using cpu instead.') device_type = Devices.cpu if device_type == Devices.cpu: with tf.device('/CPU:0'): @@ -78,7 +78,8 @@ def device_placement(framework, device_name='gpu:0'): if torch.cuda.is_available(): torch.cuda.set_device(f'cuda:{device_id}') else: - logger.warning('cuda is not available, using cpu instead.') + logger.debug( + 'pytorch: cuda is not available, using cpu instead.') yield else: yield @@ -96,9 +97,7 @@ def create_device(device_name): if device_type == Devices.gpu: use_cuda = True if not torch.cuda.is_available(): - logger.warning( - 'cuda is not available, create gpu device failed, using cpu instead.' 
- ) + logger.info('cuda is not available, using cpu instead.') use_cuda = False if use_cuda: @@ -107,3 +106,17 @@ def create_device(device_name): device = torch.device('cpu') return device + + +def get_device(): + import torch + from torch import distributed as dist + if torch.cuda.is_available(): + if dist.is_available() and dist.is_initialized( + ) and 'LOCAL_RANK' in os.environ: + device_id = f"cuda:{os.environ['LOCAL_RANK']}" + else: + device_id = 'cuda:0' + else: + device_id = 'cpu' + return torch.device(device_id) diff --git a/modelscope/utils/error.py b/modelscope/utils/error.py index a6bbc8b3..a894063c 100644 --- a/modelscope/utils/error.py +++ b/modelscope/utils/error.py @@ -111,3 +111,12 @@ You can install it with pip on linux: On windows, please checkout the instructions on the installation page: https://github.com/facebookresearch/fairseq and follow the ones that match your environment. """ + +# docstyle-ignore +FASTTEXT_IMPORT_ERROR = """ +{0} requires the fasttext library but it was not found in your environment. +You can install it with pip on linux or mac: +`pip install fasttext` +Or you can checkout the instructions on the +installation page: https://github.com/facebookresearch/fastText and follow the ones that match your environment. +""" diff --git a/modelscope/utils/file_utils.py b/modelscope/utils/file_utils.py index 9b82f8d2..cf59dc57 100644 --- a/modelscope/utils/file_utils.py +++ b/modelscope/utils/file_utils.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import inspect +import os from pathlib import Path @@ -35,3 +36,10 @@ def get_default_cache_dir(): """ default_cache_dir = Path.home().joinpath('.cache', 'modelscope') return default_cache_dir + + +def read_file(path): + + with open(path, 'r') as f: + text = f.read() + return text diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index 2dbe7045..105b3ffa 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -82,7 +82,8 @@ def get_model_type(model_dir): this file does not exist, the method will try to get the 'model_type' field from the config.json. - @param model_dir: The local model dir to use. @return: The model type + Args: + model_dir: The local model dir to use. @return: The model type string, returns None if nothing is found. """ try: @@ -112,8 +113,11 @@ def parse_label_mapping(model_dir): 2. Try to read label-id mapping from the configuration.json 3. Try to read label-id mapping from the config.json - @param model_dir: The local model dir to use. - @return: The label2id mapping if found. + Args: + model_dir: The local model dir to use. + + Returns: + The label2id mapping if found. """ import json import os diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py index 2a6fdc80..5db5ea98 100644 --- a/modelscope/utils/import_utils.py +++ b/modelscope/utils/import_utils.py @@ -292,6 +292,7 @@ REQUIREMENTS_MAAPING = OrderedDict([ ('decord', (is_package_available('decord'), DECORD_IMPORT_ERROR)), ('deepspeed', (is_package_available('deepspeed'), DEEPSPEED_IMPORT_ERROR)), ('fairseq', (is_package_available('fairseq'), FAIRSEQ_IMPORT_ERROR)), + ('fasttext', (is_package_available('fasttext'), FASTTEXT_IMPORT_ERROR)), ]) SYSTEM_PACKAGE = set(['os', 'sys', 'typing']) diff --git a/modelscope/utils/multi_modal/fp16/__init__.py b/modelscope/utils/multi_modal/fp16/__init__.py new file mode 100644 index 00000000..81250858 --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .fp16 import FP16_Module, FP16_Optimizer diff --git a/modelscope/utils/multi_modal/fp16/fp16.py b/modelscope/utils/multi_modal/fp16/fp16.py new file mode 100755 index 00000000..37a80e65 --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/fp16.py @@ -0,0 +1,655 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Stable version of apex FP16 Optimizer""" +import torch +from torch import nn +from torch.autograd import Variable +from torch.nn.parameter import Parameter + +from .fp16util import (master_params_to_model_params, + model_grads_to_master_grads) +from .loss_scaler import DynamicLossScaler, LossScaler + +FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) + + +def conversion_helper(val, conversion): + """Apply conversion to val. 
Recursively apply conversion if `val` is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + + +def fp32_to_fp16(val): + """Convert fp32 `val` to fp16""" + + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, FLOAT_TYPES): + val = val.half() + return val + + return conversion_helper(val, half_conversion) + + +def fp16_to_fp32(val): + """Convert fp16 `val` to fp32""" + + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, HALF_TYPES): + val = val.float() + return val + + return conversion_helper(val, float_conversion) + + +class FP16_Module(nn.Module): + + def __init__(self, module): + super(FP16_Module, self).__init__() + self.add_module('module', module.half()) + + def forward(self, *inputs, **kwargs): + return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination, prefix, keep_vars) + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) + + +class FP16_Optimizer(object): + """ + :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, + and manage static or dynamic loss scaling and master weights in a manner transparent to the user. + For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, + and changing the call to ``backward``. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + # Name the FP16_Optimizer instance to replace the existing optimizer + # (recommended but not required): + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + # loss.backward() becomes: + optimizer.backward(loss) + ... + + Example with dynamic loss scaling:: + + ... + optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) + # optional arg to control dynamic loss scaling behavior + # dynamic_loss_args={'scale_window' : 500}) + # Usually, dynamic_loss_args is not necessary. + + Args: + init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. # noqa + static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. # noqa + dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. # noqa + dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. 
If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. # noqa + verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. # noqa + + ``init_optimizer`` is expected to have been constructed in the ordinary way. + It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be + named to replace ``init_optimizer``, for two reasons: + First, it means that references to the same name + later in the file will not have to change. + Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to + modify ``init_optimizer``. If you do choose a unique name for the new + :class:`FP16_Optimizer` instance, you should only work with this new instance, + because the preexisting optimizer might no longer behave as expected. + + ``init_optimizer`` may be any Pytorch optimizer. + It may contain a mixture of fp16 and fp32 parameters organized into any number of + ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will + ingest these ``param_groups`` and remember them. + + Calls to :: + + loss.backward() + + must be replaced with :: + + optimizer.backward(loss) + + because :class:`FP16_Optimizer` requires ownership of the backward pass to implement + loss scaling and copies to master gradients. + + .. note:: + Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients + are downscaled before being applied. This means that adjusting the loss scale, or using + dynamic loss scaling, should not require retuning the learning rate or any other + hyperparameters. + + + **Advanced options** + + **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. + See docstring for :attr:`step`. + + **Gradient clipping**: Use :attr:`clip_master_grads`. + + **Multiple losses**: If your model accumulates gradients from multiple losses, + this can be made more efficient by supplying ``update_master_grads=False`` + to :attr:`backward`. See docstring for :attr:`backward`. + + **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: + + print(optimizer.loss_scale) + optimizer.loss_scale = new_loss_scale + + For static loss scaling, manually adjusting the loss scale over time is a reasonable + thing to do. During later epochs, gradients may become smaller, and a + higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss + scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting + the loss scale is not recommended. + + **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in + Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` + should still work as intended. + """ + + def __init__(self, + init_optimizer, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=False): + if not torch.cuda.is_available: + raise SystemError('Cannot use fp16 without CUDA.') + + self.verbose = verbose + + self.optimizer = init_optimizer + # init_state_dict sets up an alternative way to cast per-param state tensors. 
+ # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. + # init_state_dict = init_optimizer.state_dict() + + self.fp16_groups = [] + self.fp32_from_fp16_groups = [] + self.fp32_from_fp32_groups = [] + for i, param_group in enumerate(self.optimizer.param_groups): + self.maybe_print( + 'FP16_Optimizer processing param group {}:'.format(i)) + fp16_params_this_group = [] + fp32_params_this_group = [] + fp32_from_fp16_params_this_group = [] + for i, param in enumerate(param_group['params']): + if param.requires_grad: + if param.type() == 'torch.cuda.HalfTensor': + self.maybe_print( + 'FP16_Optimizer received torch.cuda.HalfTensor with {}' + .format(param.size())) + fp16_params_this_group.append(param) + master_param = param.detach().clone().float() + master_param.requires_grad = True + # Copythe model parallel flag. + master_param.model_parallel = param.model_parallel + param_group['params'][i] = master_param + fp32_from_fp16_params_this_group.append(master_param) + # Reset existing state dict key to the new master param. + # We still need to recast per-param state tensors, if any, to FP32. + if param in self.optimizer.state: + self.optimizer.state[ + master_param] = self.optimizer.state.pop(param) + elif param.type() == 'torch.cuda.FloatTensor': + self.maybe_print( + 'FP16_Optimizer received torch.cuda.FloatTensor with {}' + .format(param.size())) + fp32_params_this_group.append(param) + param_group['params'][i] = param + else: + raise TypeError( + 'Wrapped parameters must be either ' + 'torch.cuda.FloatTensor or torch.cuda.HalfTensor. ' + 'Received {}'.format(param.type())) + + self.fp16_groups.append(fp16_params_this_group) + self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) + self.fp32_from_fp32_groups.append(fp32_params_this_group) + + # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors + self.optimizer.load_state_dict(self.optimizer.state_dict()) + # alternative way to cast per-param state tensors: + # self.optimizer.load_state_dict(init_state_dict) + + if dynamic_loss_scale: + self.dynamic_loss_scale = True + if dynamic_loss_args is not None: + self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) + else: + self.loss_scaler = DynamicLossScaler() + else: + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(static_loss_scale) + + self.overflow = False + self.first_closure_call_this_step = True + + self.clip_grad_norm = nn.utils.clip_grad.clip_grad_norm_ + + def maybe_print(self, msg): + if self.verbose: + print(msg) + + def __getstate__(self): + raise RuntimeError( + 'FP16_Optimizer should be serialized using state_dict().') + + def __setstate__(self, state): + raise RuntimeError( + 'FP16_Optimizer should be deserialized using load_state_dict().') + + def zero_grad(self, set_grads_to_None=False): + """ + Zero fp32 and fp16 parameter grads. + """ + # In principle, only the .grad attributes of the model params need to be zeroed, + # because gradients are copied into the FP32 master params. 
However, we zero + # all gradients owned by the optimizer, just to be safe: + for group in self.optimizer.param_groups: + for p in group['params']: + if set_grads_to_None: + p.grad = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + # Zero fp16 gradients owned by the model: + for fp16_group in self.fp16_groups: + for param in fp16_group: + if set_grads_to_None: + param.grad = None + else: + if param.grad is not None: + param.grad.detach_( + ) # as in torch.optim.optimizer.zero_grad() + param.grad.zero_() + + def _check_overflow(self): + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + for group in self.fp32_from_fp32_groups: + for param in group: + params.append(param) + self.overflow = self.loss_scaler.has_overflow(params) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + def _master_params_to_model_params(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp16_group, fp32_from_fp16_group) + + def _model_params_to_master_params(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp32_from_fp16_group, fp16_group) + + # To consider: Integrate distributed with this wrapper by registering a hook on each variable + # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. + def _model_grads_to_master_grads(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) + + def _downscale_master(self): + if self.loss_scale != 1.0: + for group in self.optimizer.param_groups: + for param in group['params']: + if param.grad is not None: + param.grad.data.mul_(1. / self.loss_scale) + + def clip_master_grads(self, max_norm, norm_type=2): + """ + Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the current fp32 gradients (viewed as a single vector). + + .. warning:: + Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). # noqa + """ + if not self.overflow: + fp32_params = [] + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + fp32_params.append(param) + return self.clip_grad_norm(fp32_params, max_norm, norm_type) + else: + return -1 + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. 
+ Example:: + + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + state_dict = {} + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict[ + 'first_closure_call_this_step'] = self.first_closure_call_this_step + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups + return state_dict + + def load_state_dict(self, state_dict): + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + self.first_closure_call_this_step = state_dict[ + 'first_closure_call_this_step'] + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + # At this point, the optimizer's references to the model's fp32 parameters are up to date. + # The optimizer's hyperparameters and internal buffers are also up to date. + # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still + # out of date. There are two options. + # 1: Refresh the master params from the model's fp16 params. + # This requires less storage but incurs precision loss. + # 2: Save and restore the fp32 master copies separately. + # We choose option 2. + # + # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device + # of their associated parameters, because it's possible those buffers might not exist yet in + # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been + # constructed in the same way as the one whose state_dict we are loading, the same master params + # are guaranteed to exist, so we can just copy_() from the saved master params. + for current_group, saved_group in zip(self.fp32_from_fp16_groups, + state_dict['fp32_from_fp16']): + for current, saved in zip(current_group, saved_group): + current.data.copy_(saved.data) + + def step(self, closure=None): # could add clip option. + """ + If no closure is supplied, :attr:`step` should be called after + ``fp16_optimizer_obj.backward(loss)``. + :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to + :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params + originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run + another forward pass using their model. + + If a closure is supplied, :attr:`step` may be called without a prior call to + :attr:`backward(loss)`. + This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. 
+ However, the user should take care that any ``loss.backward()`` call within the closure + has been replaced by ``fp16_optimizer_obj.backward(loss)``. + + Args: + closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. # noqa + + Example with closure:: + + # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an + # existing pytorch optimizer. + for input, target in dataset: + def closure(): + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + # loss.backward() becomes: + optimizer.backward(loss) + return loss + optimizer.step(closure) + + .. warning:: + Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. + + .. _`ordinary Pytorch optimizer use`: + http://pytorch.org/docs/master/optim.html#optimizer-step-closure + """ + + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + self.maybe_print( + 'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}' + .format(scale, self.loss_scale)) + return + + if closure is not None: + retval = self._step_with_closure(closure) + else: + retval = self.optimizer.step() + + self._master_params_to_model_params() + + return retval + + def _step_with_closure(self, closure): + + def wrapped_closure(): + # helpful for debugging + # print("Calling wrapped_closure, first_closure_call_this_step = {}" + # .format(self.first_closure_call_this_step)) + if self.first_closure_call_this_step: + # We expect that the fp16 params are initially fresh on entering self.step(), + # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() + # is called within self.optimizer.step(). + self.first_closure_call_this_step = False + else: + # If self.optimizer.step() internally calls wrapped_closure more than once, + # it may update the fp32 params after each call. However, self.optimizer + # doesn't know about the fp16 params at all. If the fp32 params get updated, + # we can't rely on self.optimizer to refresh the fp16 params. We need + # to handle that manually: + self._master_params_to_model_params() + # Our API expects the user to give us ownership of the backward() call by + # replacing all calls to loss.backward() with optimizer.backward(loss). + # This requirement holds whether or not the call to backward() is made within a closure. + # If the user is properly calling optimizer.backward(loss) within "closure," + # calling closure() here will give the fp32 master params fresh gradients + # for the optimizer to play with, so all wrapped_closure needs to do is call + # closure() and return the loss. + temp_loss = closure() + while (self.overflow): + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + self.maybe_print( + 'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, ' + 'reducing to {}'.format(scale, self.loss_scale)) + temp_loss = closure() + return temp_loss + + retval = self.optimizer.step(wrapped_closure) + + self.first_closure_call_this_step = True + + return retval + + def backward(self, loss, update_master_grads=True, retain_graph=False): + """ + :attr:`backward` performs the following conceptual steps: + + 1. fp32_loss = loss.float() (see first Note below) + 2. scaled_loss = fp32_loss*loss_scale + 3. 
scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). # noqa + 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. # noqa + 5. Finally, master grads are divided by loss_scale. + + In this way, after :attr:`backward`, the master params have fresh gradients, + and :attr:`step` may be called. + + .. note:: + :attr:`backward` internally converts the loss to fp32 before applying the loss scale. + This provides some additional safety against overflow if the user has supplied an + fp16 loss value. + However, for maximum overflow safety, the user should + compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to + :attr:`backward`. + + .. warning:: + The gradients found in a model's leaves after the call to + :attr:`backward` should not be regarded as valid in general, + because it's possible + they have been scaled (and in the case of dynamic loss scaling, + the scale factor may change over time). + If the user wants to inspect gradients after a call to :attr:`backward`, + only the master gradients should be regarded as valid. These can be retrieved via + :attr:`inspect_master_grad_data()`. + + Args: + loss: The loss output by the user's model. loss may be either float or half (but see first Note above). + update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. # noqa + retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). # noqa + + Example:: + + # Ordinary operation: + optimizer.backward(loss) + + # Naive operation with multiple losses (technically valid, but less efficient): + # fp32 grads will be correct after the second call, but + # the first call incurs an unnecessary fp16->fp32 grad copy. + optimizer.backward(loss1) + optimizer.backward(loss2) + + # More efficient way to handle multiple losses: + # The fp16->fp32 grad copy is delayed until fp16 grads from all + # losses have been accumulated. + optimizer.backward(loss1, update_master_grads=False) + optimizer.backward(loss2, update_master_grads=False) + optimizer.update_master_grads() + """ + # To consider: try multiple backward passes using retain_grad=True to find + # a loss scale that works. After you find a loss scale that works, do a final dummy + # backward pass with retain_graph=False to tear down the graph. Doing this would avoid + # discarding the iteration, but probably wouldn't improve overall efficiency. + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + if update_master_grads: + self.update_master_grads() + + def update_master_grads(self): + """ + Copy the ``.grad`` attribute from stored references to fp16 parameters to + the ``.grad`` attribute of the fp32 master parameters that are directly + updated by the optimizer. 
:attr:`update_master_grads` only needs to be called if + ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. + """ + if self.dynamic_loss_scale: + self._check_overflow() + if self.overflow: return # noqa + self._model_grads_to_master_grads() + self._downscale_master() + + def inspect_master_grad_data(self): + """ + When running with :class:`FP16_Optimizer`, + ``.grad`` attributes of a model's fp16 leaves should not be + regarded as truthful, because they might be scaled. + After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, + the fp32 master params' ``.grad`` + attributes will contain valid gradients properly divided by the loss scale. However, + because :class:`FP16_Optimizer` flattens some parameters, accessing them may be + nonintuitive. :attr:`inspect_master_grad_data` + allows those gradients to be viewed with shapes corresponding to their associated model leaves. + + Returns: + List of lists (one list for each parameter group). The list for each parameter group + is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. + """ + if self.overflow: + print( + 'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. ' + 'Gradients are currently invalid (may be inf, nan, or stale). Returning None.' + ) + return None + else: + # The optimizer owns only references to master params. + master_grads_data = [] + for param_group in self.optimizer.param_groups: + master_grads_this_group = [] + for param in param_group['params']: + if param.grad is not None: + master_grads_this_group.append(param.grad.data) + else: + master_grads_this_group.append(None) + master_grads_data.append(master_grads_this_group) + return master_grads_data + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + return self.loss_scaler.loss_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) diff --git a/modelscope/utils/multi_modal/fp16/fp16util.py b/modelscope/utils/multi_modal/fp16/fp16util.py new file mode 100644 index 00000000..29595a6c --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/fp16util.py @@ -0,0 +1,216 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.autograd import Variable + + +class tofp16(nn.Module): + """ + Utility module that implements:: + + def forward(self, input): + return input.half() + """ + + def __init__(self): + super(tofp16, self).__init__() + + def forward(self, input): + return input.half() + + +def BN_convert_float(module): + """ + Utility function for network_to_half(). + + Retained for legacy purposes. + """ + if isinstance( + module, + torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: + module.float() + for child in module.children(): + BN_convert_float(child) + return module + + +def network_to_half(network): + """ + Convert model to half precision in a batchnorm-safe way. + + Retained for legacy purposes. It is recommended to use FP16Model. + """ + return nn.Sequential(tofp16(), BN_convert_float(network.half())) + + +def convert_module(module, dtype): + """ + Converts a module's immediate parameters and buffers to dtype. + """ + for param in module.parameters(recurse=False): + if param is not None: + if param.data.dtype.is_floating_point: + param.data = param.data.to(dtype=dtype) + if param._grad is not None and param._grad.data.dtype.is_floating_point: + param._grad.data = param._grad.data.to(dtype=dtype) + + for buf in module.buffers(recurse=False): + if buf is not None and buf.data.dtype.is_floating_point: + buf.data = buf.data.to(dtype=dtype) + + +def convert_network(network, dtype): + """ + Converts a network's parameters and buffers to dtype. + """ + for module in network.modules(): + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm + ) and module.affine is True: + continue + convert_module(module, dtype) + return network + + +class FP16Model(nn.Module): + """ + Convert model to half precision in a batchnorm-safe way. + """ + + def __init__(self, network): + super(FP16Model, self).__init__() + self.network = convert_network(network, dtype=torch.half) + + def forward(self, *inputs): + inputs = tuple(t.half() for t in inputs) + return self.network(*inputs) + + +def backwards_debug_hook(grad): + raise RuntimeError( + 'master_params recieved a gradient in the backward pass!') + + +def prep_param_lists(model, flat_master=False): + """ + Creates a list of FP32 master parameters for a given model, as in + `Training Neural Networks with Mixed Precision: Real Examples`_. + + Args: + model (torch.nn.Module): Existing Pytorch model + flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. # noqa + Returns: + A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. # noqa + + Example:: + + model_params, master_params = prep_param_lists(model) + + .. warning:: + Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. # noqa + + .. 
_`Training Neural Networks with Mixed Precision: Real Examples`: + http://on-demand.gputechconf.com/gtc/2018/video/S81012/ + """ + model_params = [ + param for param in model.parameters() if param.requires_grad + ] + + if flat_master: + # Give the user some more useful error messages + try: + # flatten_dense_tensors returns a contiguous flat array. + # http://pytorch.org/docs/master/_modules/torch/_utils.html + master_params = _flatten_dense_tensors( + [param.data for param in model_params]).float() + except: # noqa + print( + 'Error in prep_param_lists: model may contain a mixture of parameters ' + 'of different types. Use flat_master=False, or use F16_Optimizer.' + ) + raise + master_params = torch.nn.Parameter(master_params) + master_params.requires_grad = True + # master_params.register_hook(backwards_debug_hook) + if master_params.grad is None: + master_params.grad = master_params.new(*master_params.size()) + return model_params, [master_params] + else: + master_params = [ + param.clone().float().detach() for param in model_params + ] + for param in master_params: + param.requires_grad = True + return model_params, master_params + + +def model_grads_to_master_grads(model_params, + master_params, + flat_master=False): + """ + Copy model gradients to master gradients. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. # noqa + """ + if flat_master: + # The flattening may incur one more deep copy than is necessary. + master_params[0].grad.data.copy_( + _flatten_dense_tensors([p.grad.data for p in model_params])) + else: + for model, master in zip(model_params, master_params): + if model.grad is not None: + if master.grad is None: + master.grad = Variable( + master.data.new(*master.data.size())) + master.grad.data.copy_(model.grad.data) + else: + master.grad = None + + +def master_params_to_model_params(model_params, + master_params, + flat_master=False): + """ + Copy master parameters to model parameters. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. # noqa + """ + if flat_master: + for model, master in zip( + model_params, + _unflatten_dense_tensors(master_params[0].data, model_params)): + model.data.copy_(master) + else: + for model, master in zip(model_params, master_params): + model.data.copy_(master.data) + + +# Backward compatibility fixes + + +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + + +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) diff --git a/modelscope/utils/multi_modal/fp16/loss_scaler.py b/modelscope/utils/multi_modal/fp16/loss_scaler.py new file mode 100755 index 00000000..fc55a4ed --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/loss_scaler.py @@ -0,0 +1,237 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +# item() is a recent addition, so this helps with backward compatibility. +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + + +class LossScaler: + """ + Class that manages a static loss scale. This class is intended to interact with + :class:`FP16_Optimizer`, and should not be directly manipulated by the user. + + Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to + :class:`FP16_Optimizer`'s constructor. + + Args: + scale (float, optional, default=1.0): The loss scale. + """ + + def __init__(self, scale=1): + self.cur_scale = scale + + # `params` is a list / generator of torch.Variable + def has_overflow(self, params): + return False + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + return False + + def update_scale(self, overflow): + pass + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss, retain_graph=False): + scaled_loss = loss * self.loss_scale + scaled_loss.backward(retain_graph=retain_graph) + + +class DynamicLossScaler: + """ + Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` + indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of + :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` + operates, because the default options can be changed using the + the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. + + Loss scaling is designed to combat the problem of underflowing gradients encountered at long + times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss + scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are + encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has + occurred. + :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, + and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. + If a certain number of iterations occur without overflowing gradients detected, + :class:`DynamicLossScaler` increases the loss scale once more. + In this way :class:`DynamicLossScaler` attempts to "ride the edge" of + always using the highest loss scale possible without incurring overflow. + + Args: + init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` + scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. # noqa + scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. 
# noqa + """ + + def __init__(self, + init_scale=2**32, + scale_factor=2., + scale_window=1000, + min_scale=1, + delayed_shift=1, + consecutive_hysteresis=False): + self.cur_scale = init_scale + self.cur_iter = 0 + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + self.min_scale = min_scale + self.delayed_shift = delayed_shift + self.cur_hysteresis = delayed_shift + self.consecutive_hysteresis = consecutive_hysteresis + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params): + for p in params: + if p.grad is not None and DynamicLossScaler._has_inf_or_nan( + p.grad.data): + return True + + return False + + def has_overflow(self, params): + overflow = self.has_overflow_serial(params) + overflow_gpu = torch.cuda.ByteTensor([overflow]) + overflow = overflow_gpu[0].item() + return bool(overflow) + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if 'value cannot be converted' not in instance.args[0]: + raise + return True + else: + if cpu_sum == float( + 'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + # `overflow` is boolean indicating whether the gradient overflowed + def update_scale(self, overflow): + + if not hasattr(self, 'min_scale'): + self.min_scale = 1 + if not hasattr(self, 'delayed_shift'): + self.delayed_shift = 1 + if not hasattr(self, 'cur_hysteresis'): + self.cur_hysteresis = 1 + if not hasattr(self, 'consecutive_hysteresis'): + self.consecutive_hysteresis = True + if overflow: + # self.cur_scale /= self.scale_factor + if self.delayed_shift == 1 or self.cur_hysteresis == 1: + self.cur_scale = max(self.cur_scale / self.scale_factor, + self.min_scale) + else: + self.cur_hysteresis -= 1 + self.last_overflow_iter = self.cur_iter + else: + if self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + if (self.cur_iter + - self.last_overflow_iter) % self.scale_window == 0: + if not self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + self.cur_scale *= self.scale_factor + self.cur_iter += 1 + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss, retain_graph=False): + scaled_loss = loss * self.loss_scale + scaled_loss.backward(retain_graph=retain_graph) + + +############################################################## +# Example usage below here -- assuming it's in a separate file +############################################################## +""" +TO-DO separate out into an example. +if __name__ == "__main__": + import torch + from torch.autograd import Variable + from dynamic_loss_scaler import DynamicLossScaler + + # N is batch size; D_in is input dimension; + # H is hidden dimension; D_out is output dimension. 
+ N, D_in, H, D_out = 64, 1000, 100, 10 + + # Create random Tensors to hold inputs and outputs, and wrap them in Variables. + x = Variable(torch.randn(N, D_in), requires_grad=False) + y = Variable(torch.randn(N, D_out), requires_grad=False) + + w1 = Variable(torch.randn(D_in, H), requires_grad=True) + w2 = Variable(torch.randn(H, D_out), requires_grad=True) + parameters = [w1, w2] + + learning_rate = 1e-6 + optimizer = torch.optim.SGD(parameters, lr=learning_rate) + loss_scaler = DynamicLossScaler() + + for t in range(500): + y_pred = x.mm(w1).clamp(min=0).mm(w2) + loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale + print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) + print('Iter {} scaled loss: {}'.format(t, loss.data[0])) + print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) + + # Run backprop + optimizer.zero_grad() + loss.backward() + + # Check for overflow + has_overflow = DynamicLossScaler.has_overflow(parameters) + + # If no overflow, unscale grad and update as usual + if not has_overflow: + for param in parameters: + param.grad.data.mul_(1. / loss_scaler.loss_scale) + optimizer.step() + # Otherwise, don't do anything -- ie, skip iteration + else: + print('OVERFLOW!') + + # Update loss scale for next iteration + loss_scaler.update_scale(has_overflow) + +""" diff --git a/modelscope/utils/nlp/distributed.py b/modelscope/utils/nlp/distributed.py index 2b590a10..53332c0f 100755 --- a/modelscope/utils/nlp/distributed.py +++ b/modelscope/utils/nlp/distributed.py @@ -35,7 +35,10 @@ def initialize_distributed(rank, mpu, world_size, model_parallel_size, init_method = 'tcp://' init_method += master_ip + ':' + master_port torch.distributed.init_process_group( - backend='nccl', world_size=8, rank=rank, init_method=init_method) + backend='nccl', + world_size=world_size, + rank=rank, + init_method=init_method) # Set the model-parallel communicators. 
mpu.initialize_model_parallel(model_parallel_size) diff --git a/modelscope/utils/nlp/nlp_utils.py b/modelscope/utils/nlp/nlp_utils.py deleted file mode 100644 index eba12103..00000000 --- a/modelscope/utils/nlp/nlp_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import List - -from modelscope.outputs import OutputKeys -from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline, - DialogStateTrackingPipeline, - TableQuestionAnsweringPipeline) - - -def text2sql_tracking_and_print_results( - test_case, pipelines: List[ConversationalTextToSqlPipeline]): - for p in pipelines: - last_sql, history = '', [] - for item in test_case['utterance']: - case = { - 'utterance': item, - 'history': history, - 'last_sql': last_sql, - 'database_id': test_case['database_id'], - 'local_db_path': test_case['local_db_path'] - } - results = p(case) - print({'question': item}) - print(results) - last_sql = results['text'] - history.append(item) - - -def tracking_and_print_dialog_states( - test_case, pipelines: List[DialogStateTrackingPipeline]): - import json - pipelines_len = len(pipelines) - history_states = [{}] - utter = {} - for step, item in enumerate(test_case): - utter.update(item) - result = pipelines[step % pipelines_len]({ - 'utter': - utter, - 'history_states': - history_states - }) - print(json.dumps(result)) - - history_states.extend([result[OutputKeys.OUTPUT], {}]) - - -def tableqa_tracking_and_print_results( - test_case, pipelines: List[TableQuestionAnsweringPipeline]): - for pipeline in pipelines: - historical_queries = None - for question in test_case['utterance']: - output_dict = pipeline({ - 'question': question, - 'history_sql': historical_queries - }) - print('output_dict', output_dict['output'].string, - output_dict['output'].query) - historical_queries = output_dict['history'] diff --git a/modelscope/utils/nlp/space/args.py b/modelscope/utils/nlp/space/args.py index d9e91e74..c92401c5 100644 --- a/modelscope/utils/nlp/space/args.py +++ b/modelscope/utils/nlp/space/args.py @@ -1,6 +1,4 @@ -""" -Parse argument. -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import argparse diff --git a/modelscope/utils/nlp/space/clean_dataset.py b/modelscope/utils/nlp/space/clean_dataset.py index 4578ccc4..2c971b10 100644 --- a/modelscope/utils/nlp/space/clean_dataset.py +++ b/modelscope/utils/nlp/space/clean_dataset.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os import re diff --git a/modelscope/utils/nlp/space/criterions.py b/modelscope/utils/nlp/space/criterions.py index 60f98457..82ef4ba5 100644 --- a/modelscope/utils/nlp/space/criterions.py +++ b/modelscope/utils/nlp/space/criterions.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import torch import torch.nn.functional as F from torch.nn.modules.loss import _Loss diff --git a/modelscope/utils/nlp/space/db_ops.py b/modelscope/utils/nlp/space/db_ops.py index 880b018b..d1d14ef9 100644 --- a/modelscope/utils/nlp/space/db_ops.py +++ b/modelscope/utils/nlp/space/db_ops.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os import random import sqlite3 diff --git a/modelscope/utils/nlp/space/ontology.py b/modelscope/utils/nlp/space/ontology.py index 99b084bb..c55d12e1 100644 --- a/modelscope/utils/nlp/space/ontology.py +++ b/modelscope/utils/nlp/space/ontology.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ all_domains = [ 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital' ] diff --git a/modelscope/utils/nlp/space/scores.py b/modelscope/utils/nlp/space/scores.py index fe0a8a17..eb6dd41c 100644 --- a/modelscope/utils/nlp/space/scores.py +++ b/modelscope/utils/nlp/space/scores.py @@ -1,3 +1,6 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + + def hierarchical_set_score(frame1, frame2): # deal with empty frame if not (frame1 and frame2): diff --git a/modelscope/utils/nlp/space/utils.py b/modelscope/utils/nlp/space/utils.py index 81d1b1c5..56e67671 100644 --- a/modelscope/utils/nlp/space/utils.py +++ b/modelscope/utils/nlp/space/utils.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import logging from collections import OrderedDict diff --git a/modelscope/utils/nlp/space/utils_dst.py b/modelscope/utils/nlp/space/utils_dst.py index 2a7e67d7..6277172e 100644 --- a/modelscope/utils/nlp/space/utils_dst.py +++ b/modelscope/utils/nlp/space/utils_dst.py @@ -1,3 +1,29 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import List + +from modelscope.outputs import OutputKeys +from modelscope.pipelines.nlp import DialogStateTrackingPipeline + + +def tracking_and_print_dialog_states( + test_case, pipelines: List[DialogStateTrackingPipeline]): + import json + pipelines_len = len(pipelines) + history_states = [{}] + utter = {} + for step, item in enumerate(test_case): + utter.update(item) + result = pipelines[step % pipelines_len]({ + 'utter': + utter, + 'history_states': + history_states + }) + print(json.dumps(result)) + + history_states.extend([result[OutputKeys.OUTPUT], {}]) + + def batch_to_device(batch, device): batch_on_device = [] for element in batch: diff --git a/modelscope/utils/nlp/space_T_en/__init__.py b/modelscope/utils/nlp/space_T_en/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/utils/nlp/space_T_en/utils.py b/modelscope/utils/nlp/space_T_en/utils.py new file mode 100644 index 00000000..d884c241 --- /dev/null +++ b/modelscope/utils/nlp/space_T_en/utils.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +from typing import List + +from modelscope.outputs import OutputKeys +from modelscope.pipelines.nlp import ConversationalTextToSqlPipeline + + +def text2sql_tracking_and_print_results( + test_case, pipelines: List[ConversationalTextToSqlPipeline]): + for p in pipelines: + last_sql, history = '', [] + for item in test_case['utterance']: + case = { + 'utterance': item, + 'history': history, + 'last_sql': last_sql, + 'database_id': test_case['database_id'], + 'local_db_path': test_case['local_db_path'] + } + results = p(case) + print({'question': item}) + print(results) + last_sql = results[OutputKeys.OUTPUT][OutputKeys.TEXT] + history.append(item) diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py index 7a9c79e2..5284aa43 100644 --- a/modelscope/utils/registry.py +++ b/modelscope/utils/registry.py @@ -74,6 +74,7 @@ class Registry(object): raise KeyError(f'{module_name} is already registered in ' f'{self._name}[{group_key}]') self._modules[group_key][module_name] = module_cls + module_cls.group_key = group_key def register_module(self, group_key: str = default_group, @@ -176,7 +177,7 @@ def build_from_cfg(cfg, raise TypeError('default_args must be a dict or None, ' f'but got {type(default_args)}') - # dynamic load installation reqruiements for this module + # dynamic load installation requirements for this module from modelscope.utils.import_utils import LazyImportModule sig = (registry.name.upper(), group_key, cfg['type']) LazyImportModule.import_module(sig) @@ -193,8 +194,10 @@ def build_from_cfg(cfg, if isinstance(obj_type, str): obj_cls = registry.get(obj_type, group_key=group_key) if obj_cls is None: - raise KeyError(f'{obj_type} is not in the {registry.name}' - f' registry group {group_key}') + raise KeyError( + f'{obj_type} is not in the {registry.name}' + f' registry group {group_key}. Please make' + f' sure the correct version of ModelScope library is used.') obj_cls.group_key = group_key elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): obj_cls = obj_type diff --git a/modelscope/utils/regress_test_utils.py b/modelscope/utils/regress_test_utils.py index 47bbadfe..8045d3e9 100644 --- a/modelscope/utils/regress_test_utils.py +++ b/modelscope/utils/regress_test_utils.py @@ -7,6 +7,7 @@ import pickle import random import shutil import tempfile +from collections import OrderedDict from collections.abc import Mapping from pathlib import Path from types import FunctionType @@ -14,6 +15,7 @@ from typing import Any, Dict, Union import json import numpy as np +import torch import torch.optim from torch import nn @@ -65,12 +67,14 @@ class RegressTool: def monitor_module_single_forward(self, module: nn.Module, file_name: str, - compare_fn=None): + compare_fn=None, + **kwargs): """Monitor a pytorch module in a single forward. - @param module: A torch module - @param file_name: The file_name to store or load file - @param compare_fn: A custom fn used to compare the results manually. + Args: + module: A torch module + file_name: The file_name to store or load file + compare_fn: A custom fn used to compare the results manually. >>> def compare_fn(v1, v2, key, type): >>> return None @@ -79,6 +83,10 @@ class RegressTool: v2 is the value of current version key is the key of submodules type is in one of 'input', 'output' + + kwargs: + atol: The absolute gap between two np arrays. + rtol: The relative gap between two np arrays. 
""" baseline = os.getenv('REGRESSION_BASELINE') if baseline is None or self.baseline is None: @@ -107,7 +115,7 @@ class RegressTool: baseline = os.path.join(tempfile.gettempdir(), name) self.load(baseline, name) with open(baseline, 'rb') as f: - baseline_json = pickle.load(f) + base = pickle.load(f) class NumpyEncoder(json.JSONEncoder): """Special json encoder for numpy types @@ -122,9 +130,9 @@ class RegressTool: return obj.tolist() return json.JSONEncoder.default(self, obj) - print(f'baseline: {json.dumps(baseline_json, cls=NumpyEncoder)}') + print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}') print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') - if not compare_io_and_print(baseline_json, io_json, compare_fn): + if not compare_io_and_print(base, io_json, compare_fn, **kwargs): raise ValueError('Result not match!') @contextlib.contextmanager @@ -136,26 +144,31 @@ class RegressTool: ignore_keys=None, compare_random=True, reset_dropout=True, - lazy_stop_callback=None): + lazy_stop_callback=None, + **kwargs): """Monitor a pytorch module's backward data and cfg data within a step of the optimizer. This is usually useful when you try to change some dangerous code which has the risk of affecting the training loop. - @param trainer: A dict or an object contains the model/optimizer/lr_scheduler - @param file_name: The file_name to store or load file - @param level: The regression level. + Args: + trainer: A dict or an object contains the model/optimizer/lr_scheduler + file_name: The file_name to store or load file + level: The regression level. 'strict' for matching every single tensor. Please make sure the parameters of head are fixed and the drop-out rate is zero. 'config' for matching the initial config, like cfg file, optimizer param_groups, lr_scheduler params and the random seed. 'metric' for compare the best metrics in the evaluation loop. - @param compare_fn: A custom fn used to compare the results manually. - @param ignore_keys: The keys to ignore of the named_parameters. - @param compare_random: If to compare random setttings, default True. - @param reset_dropout: Reset all dropout modules to 0.0. - @param lazy_stop_callback: A callback passed in, when the moniting is over, this callback will be called. + compare_fn: A custom fn used to compare the results manually. + ignore_keys: The keys to ignore of the named_parameters. + compare_random: If to compare random setttings, default True. + reset_dropout: Reset all dropout modules to 0.0. + lazy_stop_callback: A callback passed in, when the moniting is over, this callback will be called. + kwargs: + atol: The absolute gap between two np arrays. + rtol: The relative gap between two np arrays. 
>>> def compare_fn(v1, v2, key, type): >>> return None @@ -265,14 +278,15 @@ class RegressTool: baseline_json = pickle.load(f) if level == 'strict' and not compare_io_and_print( - baseline_json['forward'], io_json, compare_fn): + baseline_json['forward'], io_json, compare_fn, **kwargs): raise RuntimeError('Forward not match!') if not compare_backward_and_print( baseline_json['backward'], bw_json, compare_fn=compare_fn, ignore_keys=ignore_keys, - level=level): + level=level, + **kwargs): raise RuntimeError('Backward not match!') cfg_opt1 = { 'optimizer': baseline_json['optimizer'], @@ -286,7 +300,8 @@ class RegressTool: 'cfg': summary['cfg'], 'state': None if not compare_random else summary['state'] } - if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn): + if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn, + **kwargs): raise RuntimeError('Cfg or optimizers not match!') @@ -303,7 +318,8 @@ class MsRegressTool(RegressTool): compare_fn=None, ignore_keys=None, compare_random=True, - lazy_stop_callback=None): + lazy_stop_callback=None, + **kwargs): if lazy_stop_callback is None: @@ -319,7 +335,7 @@ class MsRegressTool(RegressTool): trainer.register_hook(EarlyStopHook()) - def _train_loop(trainer, *args, **kwargs): + def _train_loop(trainer, *args_train, **kwargs_train): with self.monitor_module_train( trainer, file_name, @@ -327,9 +343,11 @@ class MsRegressTool(RegressTool): compare_fn=compare_fn, ignore_keys=ignore_keys, compare_random=compare_random, - lazy_stop_callback=lazy_stop_callback): + lazy_stop_callback=lazy_stop_callback, + **kwargs): try: - return trainer.train_loop_origin(*args, **kwargs) + return trainer.train_loop_origin(*args_train, + **kwargs_train) except MsRegressTool.EarlyStopError: pass @@ -346,16 +364,22 @@ def compare_module(module1: nn.Module, module2: nn.Module): def numpify_tensor_nested(tensors, reduction=None, clip_value=10000): - import torch + try: + from modelscope.outputs import ModelOutputBase + except ImportError: + ModelOutputBase = dict "Numpify `tensors` (even if it's a nested list/tuple of tensors)." - if isinstance(tensors, (list, tuple)): - return type(tensors)( - numpify_tensor_nested(t, reduction, clip_value) for t in tensors) - if isinstance(tensors, Mapping): - return { + if isinstance(tensors, (Mapping, ModelOutputBase)): + return OrderedDict({ k: numpify_tensor_nested(t, reduction, clip_value) for k, t in tensors.items() - } + }) + if isinstance(tensors, list): + return list( + numpify_tensor_nested(t, reduction, clip_value) for t in tensors) + if isinstance(tensors, tuple): + return tuple( + numpify_tensor_nested(t, reduction, clip_value) for t in tensors) if isinstance(tensors, torch.Tensor): t: np.ndarray = tensors.cpu().numpy() if clip_value is not None: @@ -370,12 +394,19 @@ def numpify_tensor_nested(tensors, reduction=None, clip_value=10000): def detach_tensor_nested(tensors): - import torch + try: + from modelscope.outputs import ModelOutputBase + except ImportError: + ModelOutputBase = dict "Detach `tensors` (even if it's a nested list/tuple of tensors)." 
- if isinstance(tensors, (list, tuple)): - return type(tensors)(detach_tensor_nested(t) for t in tensors) - if isinstance(tensors, Mapping): - return {k: detach_tensor_nested(t) for k, t in tensors.items()} + if isinstance(tensors, (Mapping, ModelOutputBase)): + return OrderedDict( + {k: detach_tensor_nested(t) + for k, t in tensors.items()}) + if isinstance(tensors, list): + return list(detach_tensor_nested(t) for t in tensors) + if isinstance(tensors, tuple): + return tuple(detach_tensor_nested(t) for t in tensors) if isinstance(tensors, torch.Tensor): return tensors.detach() return tensors @@ -530,7 +561,8 @@ def compare_arguments_nested(print_content, ) return False if not all([ - compare_arguments_nested(None, sub_arg1, sub_arg2) + compare_arguments_nested( + None, sub_arg1, sub_arg2, rtol=rtol, atol=atol) for sub_arg1, sub_arg2 in zip(arg1, arg2) ]): if print_content is not None: @@ -551,7 +583,8 @@ def compare_arguments_nested(print_content, print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') return False if not all([ - compare_arguments_nested(None, arg1[key], arg2[key]) + compare_arguments_nested( + None, arg1[key], arg2[key], rtol=rtol, atol=atol) for key in keys1 ]): if print_content is not None: @@ -574,7 +607,7 @@ def compare_arguments_nested(print_content, raise ValueError(f'type not supported: {type1}') -def compare_io_and_print(baseline_json, io_json, compare_fn=None): +def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs): if compare_fn is None: def compare_fn(*args, **kwargs): @@ -602,10 +635,10 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): else: match = compare_arguments_nested( f'unmatched module {key} input args', v1input['args'], - v2input['args']) and match + v2input['args'], **kwargs) and match match = compare_arguments_nested( f'unmatched module {key} input kwargs', v1input['kwargs'], - v2input['kwargs']) and match + v2input['kwargs'], **kwargs) and match v1output = numpify_tensor_nested(v1['output']) v2output = numpify_tensor_nested(v2['output']) res = compare_fn(v1output, v2output, key, 'output') @@ -615,8 +648,11 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): ) match = match and res else: - match = compare_arguments_nested(f'unmatched module {key} outputs', - v1output, v2output) and match + match = compare_arguments_nested( + f'unmatched module {key} outputs', + arg1=v1output, + arg2=v2output, + **kwargs) and match return match @@ -624,7 +660,8 @@ def compare_backward_and_print(baseline_json, bw_json, level, ignore_keys=None, - compare_fn=None): + compare_fn=None, + **kwargs): if compare_fn is None: def compare_fn(*args, **kwargs): @@ -653,18 +690,26 @@ def compare_backward_and_print(baseline_json, data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][ 'grad'], bw_json[key]['data_after'] match = compare_arguments_nested( - f'unmatched module {key} tensor data', data1, data2) and match + f'unmatched module {key} tensor data', + arg1=data1, + arg2=data2, + **kwargs) and match if level == 'strict': match = compare_arguments_nested( - f'unmatched module {key} grad data', grad1, - grad2) and match + f'unmatched module {key} grad data', + arg1=grad1, + arg2=grad2, + **kwargs) and match match = compare_arguments_nested( f'unmatched module {key} data after step', data_after1, - data_after2) and match + data_after2, **kwargs) and match return match -def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): +def compare_cfg_and_optimizers(baseline_json, + cfg_json, + 
compare_fn=None, + **kwargs): if compare_fn is None: def compare_fn(*args, **kwargs): @@ -686,12 +731,12 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): print( f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}" ) - match = compare_arguments_nested('unmatched optimizer defaults', - optimizer1['defaults'], - optimizer2['defaults']) and match - match = compare_arguments_nested('unmatched optimizer state_dict', - optimizer1['state_dict'], - optimizer2['state_dict']) and match + match = compare_arguments_nested( + 'unmatched optimizer defaults', optimizer1['defaults'], + optimizer2['defaults'], **kwargs) and match + match = compare_arguments_nested( + 'unmatched optimizer state_dict', optimizer1['state_dict'], + optimizer2['state_dict'], **kwargs) and match res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler') if res is not None: @@ -703,16 +748,17 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): print( f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}" ) - match = compare_arguments_nested('unmatched lr_scheduler state_dict', - lr_scheduler1['state_dict'], - lr_scheduler2['state_dict']) and match + match = compare_arguments_nested( + 'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'], + lr_scheduler2['state_dict'], **kwargs) and match res = compare_fn(cfg1, cfg2, None, 'cfg') if res is not None: print(f'cfg compared with user compare_fn with result:{res}\n') match = match and res else: - match = compare_arguments_nested('unmatched cfg', cfg1, cfg2) and match + match = compare_arguments_nested( + 'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match res = compare_fn(state1, state2, None, 'state') if res is not None: @@ -721,6 +767,6 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): match = match and res else: match = compare_arguments_nested('unmatched random state', state1, - state2) and match + state2, **kwargs) and match return match diff --git a/modelscope/utils/tensor_utils.py b/modelscope/utils/tensor_utils.py index b68a639c..8f580d19 100644 --- a/modelscope/utils/tensor_utils.py +++ b/modelscope/utils/tensor_utils.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # Part of the implementation is borrowed from huggingface/transformers. -from collections import Mapping +from collections.abc import Mapping def torch_nested_numpify(tensors): @@ -8,8 +8,11 @@ def torch_nested_numpify(tensors): NOTE: If the type of input tensors is dict-like(Mapping, dict, OrderedDict, etc.), the return type will be dict. - @param tensors: Nested torch tensors. - @return: The numpify tensors. + Args: + tensors: Nested torch tensors. + + Returns: + The numpify tensors. """ import torch @@ -30,8 +33,11 @@ def torch_nested_detach(tensors): NOTE: If the type of input tensors is dict-like(Mapping, dict, OrderedDict, etc.), the return type will be dict. - @param tensors: Nested torch tensors. - @return: The detached tensors. + Args: + tensors: Nested torch tensors. + + Returns: + The detached tensors. """ import torch diff --git a/modelscope/version.py b/modelscope/version.py index 1e4826d6..541dfc57 100644 --- a/modelscope/version.py +++ b/modelscope/version.py @@ -1 +1,5 @@ -__version__ = '0.4.7' +# Make sure to modify __release_datetime__ to release time when making official release. 
+__version__ = '0.5.0' +# default release datetime for branches under active development is set +# to be a time far-far-away-into-the-future +__release_datetime__ = '2099-10-13 08:56:12' diff --git a/requirements/audio.txt b/requirements/audio.txt index d22ad8f1..bef32121 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -1,5 +1,5 @@ easyasr>=0.0.2 -espnet>=202204 +espnet==202204 h5py inflect keras @@ -14,7 +14,11 @@ nltk numpy<=1.18 # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. protobuf>3,<3.21.0 -py_sound_connect +ptflops +py_sound_connect>=0.1 +pytorch_wavelets +PyWavelets>=1.0.0 +scikit-learn SoundFile>0.10 sox torchaudio diff --git a/requirements/cv.txt b/requirements/cv.txt index 5a2d7763..f29b296b 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -1,4 +1,5 @@ albumentations>=1.0.3 +av>=9.2.0 easydict fairscale>=0.4.1 fastai>=1.0.51 @@ -7,14 +8,18 @@ ffmpeg-python>=0.2.0 ftfy imageio>=2.9.0 imageio-ffmpeg>=0.4.2 +imgaug>=0.4.0 +kornia>=0.5.0 lmdb lpips ml_collections mmcls>=0.21.0 mmdet>=2.25.0 +moviepy>=1.0.3 networkx>=2.5 +numba onnxruntime>=1.10 -pai-easycv>=0.6.3.6 +pai-easycv>=0.6.3.9 pandas psutil regex diff --git a/requirements/framework.txt b/requirements/framework.txt index b51faeda..2408cda6 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -1,4 +1,5 @@ addict +attrs datasets easydict einops diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt index 02e87baa..255f6155 100644 --- a/requirements/multi-modal.txt +++ b/requirements/multi-modal.txt @@ -5,6 +5,7 @@ pycocotools>=2.0.4 # rough-score was just recently updated from 0.0.4 to 0.0.7 # which introduced compatability issues that are being investigated rouge_score<=0.0.4 +sacrebleu taming-transformers-rom1504 timm tokenizers diff --git a/requirements/nlp.txt b/requirements/nlp.txt index 15f2f41a..9a4abd71 100644 --- a/requirements/nlp.txt +++ b/requirements/nlp.txt @@ -2,6 +2,10 @@ en_core_web_sm>=2.3.5 jieba>=0.42.1 megatron_util pai-easynlp +# protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. 
+protobuf>=3.19.0,<3.21.0 +pythainlp +pyvi # rough-score was just recently updated from 0.0.4 to 0.0.7 # which introduced compatability issues that are being investigated rouge_score<=0.0.4 diff --git a/requirements/science.txt b/requirements/science.txt new file mode 100644 index 00000000..72994f72 --- /dev/null +++ b/requirements/science.txt @@ -0,0 +1,6 @@ +iopath +lmdb +ml_collections +scipy +tensorboardX +tokenizers diff --git a/tests/export/test_export_sbert_sequence_classification.py b/tests/export/test_export_sbert_sequence_classification.py index 535b3f5d..0e4f8349 100644 --- a/tests/export/test_export_sbert_sequence_classification.py +++ b/tests/export/test_export_sbert_sequence_classification.py @@ -3,9 +3,10 @@ import os import shutil import tempfile import unittest +from collections import OrderedDict from modelscope.exporters import Exporter, TorchModelExporter -from modelscope.models.base import Model +from modelscope.models import Model from modelscope.utils.test_utils import test_level @@ -22,15 +23,47 @@ class TestExportSbertSequenceClassification(unittest.TestCase): shutil.rmtree(self.tmp_dir) super().tearDown() - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skip def test_export_sbert_sequence_classification(self): model = Model.from_pretrained(self.model_id) print( Exporter.from_model(model).export_onnx( - shape=(2, 256), outputs=self.tmp_dir)) + shape=(2, 256), output_dir=self.tmp_dir)) print( TorchModelExporter.from_model(model).export_torch_script( - shape=(2, 256), outputs=self.tmp_dir)) + shape=(2, 256), output_dir=self.tmp_dir)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_export_outer_module(self): + from transformers import BertForSequenceClassification, BertTokenizerFast + model = BertForSequenceClassification.from_pretrained( + 'bert-base-uncased') + tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') + dummy_inputs = tokenizer( + tokenizer.unk_token, + padding='max_length', + max_length=256, + return_tensors='pt') + dynamic_axis = {0: 'batch', 1: 'sequence'} + inputs = OrderedDict([ + ('input_ids', dynamic_axis), + ('attention_mask', dynamic_axis), + ('token_type_ids', dynamic_axis), + ]) + outputs = OrderedDict({'logits': {0: 'batch'}}) + output_files = TorchModelExporter().export_onnx( + model=model, + dummy_inputs=dummy_inputs, + inputs=inputs, + outputs=outputs, + output_dir='/tmp') + print(output_files) + output_files = TorchModelExporter().export_torch_script( + model=model, + dummy_inputs=dummy_inputs, + output_dir='/tmp', + strict=False) + print(output_files) if __name__ == '__main__': diff --git a/tests/hub/test_download_dataset.py b/tests/hub/test_download_dataset.py new file mode 100644 index 00000000..29b5d1ab --- /dev/null +++ b/tests/hub/test_download_dataset.py @@ -0,0 +1,709 @@ +import unittest + +from modelscope.msdatasets import MsDataset +from modelscope.utils.test_utils import test_level + + +class DownloadDatasetTest(unittest.TestCase): + + def setUp(self): + self.subset_count = 10 + + def download_subset(self, dataset, subset_name): + dataset = MsDataset.load(dataset, subset_name=subset_name) + if isinstance(dataset, MsDataset): + lens = len(dataset) + print(f'dataset {subset_name} len: {lens}') + self.assertTrue(lens > 0) + else: + assert isinstance(dataset, dict) + lens = {key: len(subset) for key, subset in dataset.items()} + print(f'dataset {subset_name} len: {lens}') + self.assertTrue(all([_len > 0 for _len in lens.values()])) + + 
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_glue(self): + subset = [ + 'cola', 'sst2', 'mrpc', 'qqp', 'stsb', 'mnli', 'mnli_mismatched', + 'mnli_matched', 'qnli', 'rte', 'wnli', 'ax' + ] + for subset_name in subset: + self.download_subset('glue', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_super_glue(self): + subset = [ + 'boolq', 'cb', 'copa', 'multirc', 'record', 'rte', 'wic', 'wsc', + 'wsc.fixed', 'axb', 'axg' + ] + for subset_name in subset: + self.download_subset('super_glue', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_nllb(self): + subset = [ + 'ace_Latn-ban_Latn', 'ace_Latn-bjn_Latn', 'ace_Latn-bug_Latn', + 'ace_Latn-ceb_Latn', 'ace_Latn-eng_Latn', 'ace_Latn-fij_Latn', + 'ace_Latn-ilo_Latn', 'ace_Latn-jav_Latn', 'ace_Latn-min_Latn', + 'ace_Latn-mri_Latn', 'ace_Latn-pag_Latn', 'ace_Latn-plt_Latn', + 'ace_Latn-smo_Latn', 'ace_Latn-sun_Latn', 'ace_Latn-war_Latn', + 'afr_Latn-aka_Latn', 'afr_Latn-amh_Ethi', 'afr_Latn-bam_Latn', + 'afr_Latn-bem_Latn', 'afr_Latn-cjk_Latn', 'afr_Latn-dik_Latn', + 'afr_Latn-dyu_Latn', 'afr_Latn-eng_Latn', 'afr_Latn-ewe_Latn', + 'afr_Latn-fon_Latn', 'afr_Latn-fra_Latn', 'afr_Latn-fuv_Latn', + 'afr_Latn-gaz_Latn', 'afr_Latn-hau_Latn', 'afr_Latn-ibo_Latn', + 'afr_Latn-kam_Latn', 'afr_Latn-kik_Latn', 'afr_Latn-kin_Latn', + 'afr_Latn-kmb_Latn', 'afr_Latn-knc_Arab', 'afr_Latn-knc_Latn', + 'afr_Latn-kon_Latn', 'afr_Latn-lin_Latn', 'afr_Latn-lua_Latn', + 'afr_Latn-lug_Latn', 'afr_Latn-luo_Latn', 'afr_Latn-nso_Latn', + 'afr_Latn-nus_Latn', 'afr_Latn-nya_Latn', 'afr_Latn-run_Latn', + 'afr_Latn-sna_Latn', 'afr_Latn-som_Latn', 'afr_Latn-sot_Latn', + 'afr_Latn-ssw_Latn', 'afr_Latn-swh_Latn', 'afr_Latn-tir_Ethi', + 'afr_Latn-tsn_Latn', 'afr_Latn-tso_Latn', 'afr_Latn-tum_Latn', + 'afr_Latn-twi_Latn', 'afr_Latn-umb_Latn', 'afr_Latn-wol_Latn', + 'afr_Latn-xho_Latn', 'afr_Latn-yor_Latn', 'afr_Latn-zul_Latn', + 'aka_Latn-amh_Ethi', 'aka_Latn-bam_Latn', 'aka_Latn-bem_Latn', + 'aka_Latn-cjk_Latn', 'aka_Latn-dik_Latn', 'aka_Latn-dyu_Latn', + 'aka_Latn-eng_Latn', 'aka_Latn-ewe_Latn', 'aka_Latn-fon_Latn', + 'aka_Latn-fra_Latn', 'aka_Latn-fuv_Latn', 'aka_Latn-gaz_Latn', + 'aka_Latn-hau_Latn', 'aka_Latn-ibo_Latn', 'aka_Latn-kam_Latn', + 'aka_Latn-kik_Latn', 'aka_Latn-kin_Latn', 'aka_Latn-kmb_Latn', + 'aka_Latn-knc_Arab', 'aka_Latn-knc_Latn', 'aka_Latn-kon_Latn', + 'aka_Latn-lin_Latn', 'aka_Latn-lua_Latn', 'aka_Latn-lug_Latn', + 'aka_Latn-luo_Latn', 'aka_Latn-nso_Latn', 'aka_Latn-nus_Latn', + 'aka_Latn-nya_Latn', 'aka_Latn-run_Latn', 'aka_Latn-sna_Latn', + 'aka_Latn-som_Latn', 'aka_Latn-sot_Latn', 'aka_Latn-ssw_Latn', + 'aka_Latn-swh_Latn', 'aka_Latn-tir_Ethi', 'aka_Latn-tsn_Latn', + 'aka_Latn-tso_Latn', 'aka_Latn-tum_Latn', 'aka_Latn-twi_Latn', + 'aka_Latn-umb_Latn', 'aka_Latn-wol_Latn', 'aka_Latn-xho_Latn', + 'aka_Latn-yor_Latn', 'aka_Latn-zul_Latn', 'amh_Ethi-bam_Latn', + 'amh_Ethi-bem_Latn', 'amh_Ethi-cjk_Latn', 'amh_Ethi-dik_Latn', + 'amh_Ethi-dyu_Latn', 'amh_Ethi-eng_Latn', 'amh_Ethi-ewe_Latn', + 'amh_Ethi-fon_Latn', 'amh_Ethi-fra_Latn', 'amh_Ethi-fuv_Latn', + 'amh_Ethi-gaz_Latn', 'amh_Ethi-hau_Latn', 'amh_Ethi-ibo_Latn', + 'amh_Ethi-kam_Latn', 'amh_Ethi-kik_Latn', 'amh_Ethi-kin_Latn', + 'amh_Ethi-kmb_Latn', 'amh_Ethi-knc_Arab', 'amh_Ethi-knc_Latn', + 'amh_Ethi-kon_Latn', 'amh_Ethi-lin_Latn', 'amh_Ethi-lua_Latn', + 'amh_Ethi-lug_Latn', 'amh_Ethi-luo_Latn', 'amh_Ethi-nso_Latn', + 
'amh_Ethi-nus_Latn', 'amh_Ethi-nya_Latn', 'amh_Ethi-run_Latn', + 'amh_Ethi-sna_Latn', 'amh_Ethi-som_Latn', 'amh_Ethi-sot_Latn', + 'amh_Ethi-ssw_Latn', 'amh_Ethi-swh_Latn', 'amh_Ethi-tir_Ethi', + 'amh_Ethi-tsn_Latn', 'amh_Ethi-tso_Latn', 'amh_Ethi-tum_Latn', + 'amh_Ethi-twi_Latn', 'amh_Ethi-umb_Latn', 'amh_Ethi-wol_Latn', + 'amh_Ethi-xho_Latn', 'amh_Ethi-yor_Latn', 'amh_Ethi-zul_Latn', + 'arb_Arab-ckb_Arab', 'arb_Arab-crh_Latn', 'arb_Arab-dik_Latn', + 'arb_Arab-diq_Latn', 'arb_Arab-fuv_Latn', 'arb_Arab-kmr_Latn', + 'arb_Arab-knc_Latn', 'arb_Arab-nus_Latn', 'arb_Arab-som_Latn', + 'arb_Arab-tat_Cyrl', 'arb_Arab-tzm_Tfng', 'arb_Arab-urd_Arab', + 'arb_Arab-wol_Latn', 'asm_Beng-awa_Deva', 'asm_Beng-ben_Beng', + 'asm_Beng-bho_Deva', 'asm_Beng-eng_Latn', 'asm_Beng-guj_Gujr', + 'asm_Beng-hin_Deva', 'asm_Beng-hne_Deva', 'asm_Beng-kan_Knda', + 'asm_Beng-kas_Arab', 'asm_Beng-kas_Deva', 'asm_Beng-mag_Deva', + 'asm_Beng-mai_Deva', 'asm_Beng-mal_Mlym', 'asm_Beng-mar_Deva', + 'asm_Beng-npi_Deva', 'asm_Beng-ory_Orya', 'asm_Beng-pan_Guru', + 'asm_Beng-san_Deva', 'asm_Beng-sat_Beng', 'asm_Beng-sin_Sinh', + 'asm_Beng-snd_Arab', 'asm_Beng-tam_Taml', 'asm_Beng-tel_Telu', + 'asm_Beng-urd_Arab', 'awa_Deva-ben_Beng', 'awa_Deva-bho_Deva', + 'awa_Deva-eng_Latn', 'awa_Deva-guj_Gujr', 'awa_Deva-hin_Deva', + 'awa_Deva-hne_Deva', 'awa_Deva-kan_Knda', 'awa_Deva-kas_Arab', + 'awa_Deva-kas_Deva', 'awa_Deva-mag_Deva', 'awa_Deva-mai_Deva', + 'awa_Deva-mal_Mlym', 'awa_Deva-mar_Deva', 'awa_Deva-npi_Deva', + 'awa_Deva-ory_Orya', 'awa_Deva-pan_Guru', 'awa_Deva-san_Deva', + 'awa_Deva-sat_Beng', 'awa_Deva-sin_Sinh', 'awa_Deva-snd_Arab', + 'awa_Deva-tam_Taml', 'awa_Deva-tel_Telu', 'awa_Deva-urd_Arab', + 'ayr_Latn-eng_Latn', 'ayr_Latn-spa_Latn', 'azb_Arab-eng_Latn', + 'azj_Latn-eng_Latn', 'azj_Latn-rus_Cyrl', 'bak_Cyrl-crh_Latn', + 'bak_Cyrl-eng_Latn', 'bak_Cyrl-kir_Cyrl', 'bak_Cyrl-rus_Cyrl', + 'bak_Cyrl-tat_Cyrl', 'bak_Cyrl-tuk_Latn', 'bak_Cyrl-uig_Arab', + 'bak_Cyrl-uzn_Latn', 'bam_Latn-bem_Latn', 'bam_Latn-cjk_Latn', + 'bam_Latn-dik_Latn', 'bam_Latn-dyu_Latn', 'bam_Latn-eng_Latn', + 'bam_Latn-ewe_Latn', 'bam_Latn-fon_Latn', 'bam_Latn-fra_Latn', + 'bam_Latn-fuv_Latn', 'bam_Latn-gaz_Latn', 'bam_Latn-hau_Latn', + 'bam_Latn-ibo_Latn', 'bam_Latn-kam_Latn', 'bam_Latn-kik_Latn', + 'bam_Latn-kin_Latn', 'bam_Latn-kmb_Latn', 'bam_Latn-knc_Arab', + 'bam_Latn-knc_Latn', 'bam_Latn-kon_Latn', 'bam_Latn-lin_Latn', + 'bam_Latn-lua_Latn', 'bam_Latn-lug_Latn', 'bam_Latn-luo_Latn', + 'bam_Latn-nso_Latn', 'bam_Latn-nus_Latn', 'bam_Latn-nya_Latn', + 'bam_Latn-run_Latn', 'bam_Latn-sna_Latn', 'bam_Latn-som_Latn', + 'bam_Latn-sot_Latn', 'bam_Latn-ssw_Latn', 'bam_Latn-swh_Latn', + 'bam_Latn-tir_Ethi', 'bam_Latn-tsn_Latn', 'bam_Latn-tso_Latn', + 'bam_Latn-tum_Latn', 'bam_Latn-twi_Latn', 'bam_Latn-umb_Latn', + 'bam_Latn-wol_Latn', 'bam_Latn-xho_Latn', 'bam_Latn-yor_Latn', + 'bam_Latn-zul_Latn', 'ban_Latn-bjn_Latn', 'ban_Latn-bug_Latn', + 'ban_Latn-ceb_Latn', 'ban_Latn-eng_Latn', 'ban_Latn-fij_Latn', + 'ban_Latn-ilo_Latn', 'ban_Latn-jav_Latn', 'ban_Latn-min_Latn', + 'ban_Latn-mri_Latn', 'ban_Latn-pag_Latn', 'ban_Latn-plt_Latn', + 'ban_Latn-smo_Latn', 'ban_Latn-sun_Latn', 'ban_Latn-war_Latn', + 'bel_Cyrl-eng_Latn', 'bel_Cyrl-rus_Cyrl', 'bem_Latn-cjk_Latn', + 'bem_Latn-dik_Latn', 'bem_Latn-dyu_Latn', 'bem_Latn-eng_Latn', + 'bem_Latn-ewe_Latn', 'bem_Latn-fon_Latn', 'bem_Latn-fra_Latn', + 'bem_Latn-fuv_Latn', 'bem_Latn-gaz_Latn', 'bem_Latn-hau_Latn', + 'bem_Latn-ibo_Latn', 'bem_Latn-kam_Latn', 'bem_Latn-kik_Latn', + 'bem_Latn-kin_Latn', 'bem_Latn-kmb_Latn', 
'bem_Latn-knc_Arab', + 'bem_Latn-knc_Latn', 'bem_Latn-kon_Latn', 'bem_Latn-lin_Latn', + 'bem_Latn-lua_Latn', 'bem_Latn-lug_Latn', 'bem_Latn-luo_Latn', + 'bem_Latn-nso_Latn', 'bem_Latn-nus_Latn', 'bem_Latn-nya_Latn', + 'bem_Latn-run_Latn', 'bem_Latn-sna_Latn', 'bem_Latn-som_Latn', + 'bem_Latn-sot_Latn', 'bem_Latn-ssw_Latn', 'bem_Latn-swh_Latn', + 'bem_Latn-tir_Ethi', 'bem_Latn-tsn_Latn', 'bem_Latn-tso_Latn', + 'bem_Latn-tum_Latn', 'bem_Latn-twi_Latn', 'bem_Latn-umb_Latn', + 'bem_Latn-wol_Latn', 'bem_Latn-xho_Latn', 'bem_Latn-yor_Latn', + 'bem_Latn-zul_Latn', 'ben_Beng-bho_Deva', 'ben_Beng-eng_Latn', + 'ben_Beng-guj_Gujr', 'ben_Beng-hin_Deva', 'ben_Beng-hne_Deva', + 'ben_Beng-kan_Knda', 'ben_Beng-kas_Arab', 'ben_Beng-kas_Deva', + 'ben_Beng-mag_Deva', 'ben_Beng-mai_Deva', 'ben_Beng-mal_Mlym', + 'ben_Beng-mar_Deva', 'ben_Beng-npi_Deva', 'ben_Beng-ory_Orya', + 'ben_Beng-pan_Guru', 'ben_Beng-pbt_Arab', 'ben_Beng-san_Deva', + 'ben_Beng-sat_Beng', 'ben_Beng-sin_Sinh', 'ben_Beng-snd_Arab', + 'ben_Beng-tam_Taml', 'ben_Beng-tel_Telu', 'ben_Beng-urd_Arab', + 'bho_Deva-eng_Latn', 'bho_Deva-guj_Gujr', 'bho_Deva-hin_Deva', + 'bho_Deva-hne_Deva', 'bho_Deva-kan_Knda', 'bho_Deva-kas_Arab', + 'bho_Deva-kas_Deva', 'bho_Deva-mag_Deva', 'bho_Deva-mai_Deva', + 'bho_Deva-mal_Mlym', 'bho_Deva-mar_Deva', 'bho_Deva-npi_Deva', + 'bho_Deva-ory_Orya', 'bho_Deva-pan_Guru', 'bho_Deva-san_Deva', + 'bho_Deva-sat_Beng', 'bho_Deva-sin_Sinh', 'bho_Deva-snd_Arab', + 'bho_Deva-tam_Taml', 'bho_Deva-tel_Telu', 'bho_Deva-urd_Arab', + 'bjn_Latn-bug_Latn', 'bjn_Latn-ceb_Latn', 'bjn_Latn-eng_Latn', + 'bjn_Latn-fij_Latn', 'bjn_Latn-ilo_Latn', 'bjn_Latn-ind_Latn', + 'bjn_Latn-jav_Latn', 'bjn_Latn-min_Latn', 'bjn_Latn-mri_Latn', + 'bjn_Latn-pag_Latn', 'bjn_Latn-plt_Latn', 'bjn_Latn-smo_Latn', + 'bjn_Latn-sun_Latn', 'bjn_Latn-war_Latn', 'bod_Tibt-eng_Latn', + 'bos_Latn-eng_Latn', 'bug_Latn-ceb_Latn', 'bug_Latn-eng_Latn', + 'bug_Latn-fij_Latn', 'bug_Latn-ilo_Latn', 'bug_Latn-jav_Latn', + 'bug_Latn-min_Latn', 'bug_Latn-mri_Latn', 'bug_Latn-pag_Latn', + 'bug_Latn-plt_Latn', 'bug_Latn-smo_Latn', 'bug_Latn-sun_Latn', + 'bug_Latn-war_Latn', 'ceb_Latn-eng_Latn', 'ceb_Latn-fij_Latn', + 'ceb_Latn-ilo_Latn', 'ceb_Latn-jav_Latn', 'ceb_Latn-min_Latn', + 'ceb_Latn-mri_Latn', 'ceb_Latn-pag_Latn', 'ceb_Latn-plt_Latn', + 'ceb_Latn-smo_Latn', 'ceb_Latn-sun_Latn', 'ceb_Latn-war_Latn', + 'cjk_Latn-dik_Latn', 'cjk_Latn-dyu_Latn', 'cjk_Latn-eng_Latn', + 'cjk_Latn-ewe_Latn', 'cjk_Latn-fon_Latn', 'cjk_Latn-fra_Latn', + 'cjk_Latn-fuv_Latn', 'cjk_Latn-gaz_Latn', 'cjk_Latn-hau_Latn', + 'cjk_Latn-ibo_Latn', 'cjk_Latn-kam_Latn', 'cjk_Latn-kik_Latn', + 'cjk_Latn-kin_Latn', 'cjk_Latn-kmb_Latn', 'cjk_Latn-knc_Arab', + 'cjk_Latn-knc_Latn', 'cjk_Latn-kon_Latn', 'cjk_Latn-lin_Latn', + 'cjk_Latn-lua_Latn', 'cjk_Latn-lug_Latn', 'cjk_Latn-luo_Latn', + 'cjk_Latn-nso_Latn', 'cjk_Latn-nus_Latn', 'cjk_Latn-nya_Latn', + 'cjk_Latn-por_Latn', 'cjk_Latn-run_Latn', 'cjk_Latn-sna_Latn', + 'cjk_Latn-som_Latn', 'cjk_Latn-sot_Latn', 'cjk_Latn-ssw_Latn', + 'cjk_Latn-swh_Latn', 'cjk_Latn-tir_Ethi', 'cjk_Latn-tsn_Latn', + 'cjk_Latn-tso_Latn', 'cjk_Latn-tum_Latn', 'cjk_Latn-twi_Latn', + 'cjk_Latn-umb_Latn', 'cjk_Latn-wol_Latn', 'cjk_Latn-xho_Latn', + 'cjk_Latn-yor_Latn', 'cjk_Latn-zul_Latn', 'ckb_Arab-diq_Latn', + 'ckb_Arab-eng_Latn', 'ckb_Arab-kmr_Latn', 'ckb_Arab-pbt_Arab', + 'ckb_Arab-prs_Arab', 'ckb_Arab-tgk_Cyrl', 'crh_Latn-eng_Latn', + 'crh_Latn-kir_Cyrl', 'crh_Latn-rus_Cyrl', 'crh_Latn-tat_Cyrl', + 'crh_Latn-tuk_Latn', 'crh_Latn-uig_Arab', 'crh_Latn-uzn_Latn', + 'cym_Latn-eng_Latn', 
'dik_Latn-dyu_Latn', 'dik_Latn-eng_Latn', + 'dik_Latn-ewe_Latn', 'dik_Latn-fon_Latn', 'dik_Latn-fra_Latn', + 'dik_Latn-fuv_Latn', 'dik_Latn-gaz_Latn', 'dik_Latn-hau_Latn', + 'dik_Latn-ibo_Latn', 'dik_Latn-kam_Latn', 'dik_Latn-kik_Latn', + 'dik_Latn-kin_Latn', 'dik_Latn-kmb_Latn', 'dik_Latn-knc_Arab', + 'dik_Latn-knc_Latn', 'dik_Latn-kon_Latn', 'dik_Latn-lin_Latn', + 'dik_Latn-lua_Latn', 'dik_Latn-lug_Latn', 'dik_Latn-luo_Latn', + 'dik_Latn-nso_Latn', 'dik_Latn-nus_Latn', 'dik_Latn-nya_Latn', + 'dik_Latn-run_Latn', 'dik_Latn-sna_Latn', 'dik_Latn-som_Latn', + 'dik_Latn-sot_Latn', 'dik_Latn-ssw_Latn', 'dik_Latn-swh_Latn', + 'dik_Latn-tir_Ethi', 'dik_Latn-tsn_Latn', 'dik_Latn-tso_Latn', + 'dik_Latn-tum_Latn', 'dik_Latn-twi_Latn', 'dik_Latn-umb_Latn', + 'dik_Latn-wol_Latn', 'dik_Latn-xho_Latn', 'dik_Latn-yor_Latn', + 'dik_Latn-zul_Latn', 'diq_Latn-eng_Latn', 'diq_Latn-kmr_Latn', + 'diq_Latn-pbt_Arab', 'diq_Latn-prs_Arab', 'diq_Latn-tgk_Cyrl', + 'dyu_Latn-eng_Latn', 'dyu_Latn-ewe_Latn', 'dyu_Latn-fon_Latn', + 'dyu_Latn-fra_Latn', 'dyu_Latn-fuv_Latn', 'dyu_Latn-gaz_Latn', + 'dyu_Latn-hau_Latn', 'dyu_Latn-ibo_Latn', 'dyu_Latn-kam_Latn', + 'dyu_Latn-kik_Latn', 'dyu_Latn-kin_Latn', 'dyu_Latn-kmb_Latn', + 'dyu_Latn-knc_Arab', 'dyu_Latn-knc_Latn', 'dyu_Latn-kon_Latn', + 'dyu_Latn-lin_Latn', 'dyu_Latn-lua_Latn', 'dyu_Latn-lug_Latn', + 'dyu_Latn-luo_Latn', 'dyu_Latn-nso_Latn', 'dyu_Latn-nus_Latn', + 'dyu_Latn-nya_Latn', 'dyu_Latn-run_Latn', 'dyu_Latn-sna_Latn', + 'dyu_Latn-som_Latn', 'dyu_Latn-sot_Latn', 'dyu_Latn-ssw_Latn', + 'dyu_Latn-swh_Latn', 'dyu_Latn-tir_Ethi', 'dyu_Latn-tsn_Latn', + 'dyu_Latn-tso_Latn', 'dyu_Latn-tum_Latn', 'dyu_Latn-twi_Latn', + 'dyu_Latn-umb_Latn', 'dyu_Latn-wol_Latn', 'dyu_Latn-xho_Latn', + 'dyu_Latn-yor_Latn', 'dyu_Latn-zul_Latn', 'dzo_Tibt-eng_Latn', + 'eng_Latn-als_Latn', 'eng_Latn-epo_Latn', 'eng_Latn-ewe_Latn', + 'eng_Latn-fao_Latn', 'eng_Latn-fij_Latn', 'eng_Latn-fon_Latn', + 'eng_Latn-fur_Latn', 'eng_Latn-fuv_Latn', 'eng_Latn-gaz_Latn', + 'eng_Latn-gla_Latn', 'eng_Latn-gle_Latn', 'eng_Latn-grn_Latn', + 'eng_Latn-guj_Gujr', 'eng_Latn-hat_Latn', 'eng_Latn-hau_Latn', + 'eng_Latn-hin_Deva', 'eng_Latn-hne_Deva', 'eng_Latn-hye_Armn', + 'eng_Latn-ibo_Latn', 'eng_Latn-ilo_Latn', 'eng_Latn-jav_Latn', + 'eng_Latn-kab_Latn', 'eng_Latn-kac_Latn', 'eng_Latn-kam_Latn', + 'eng_Latn-kan_Knda', 'eng_Latn-kas_Arab', 'eng_Latn-kas_Deva', + 'eng_Latn-kat_Geor', 'eng_Latn-kaz_Cyrl', 'eng_Latn-kbp_Latn', + 'eng_Latn-kea_Latn', 'eng_Latn-khk_Cyrl', 'eng_Latn-khm_Khmr', + 'eng_Latn-kik_Latn', 'eng_Latn-kin_Latn', 'eng_Latn-kir_Cyrl', + 'eng_Latn-kmb_Latn', 'eng_Latn-kmr_Latn', 'eng_Latn-knc_Arab', + 'eng_Latn-knc_Latn', 'eng_Latn-kon_Latn', 'eng_Latn-lao_Laoo', + 'eng_Latn-lij_Latn', 'eng_Latn-lim_Latn', 'eng_Latn-lin_Latn', + 'eng_Latn-lmo_Latn', 'eng_Latn-ltg_Latn', 'eng_Latn-ltz_Latn', + 'eng_Latn-lua_Latn', 'eng_Latn-lug_Latn', 'eng_Latn-luo_Latn', + 'eng_Latn-lus_Latn', 'eng_Latn-mag_Deva', 'eng_Latn-mai_Deva', + 'eng_Latn-mal_Mlym', 'eng_Latn-mar_Deva', 'eng_Latn-min_Latn', + 'eng_Latn-mlt_Latn', 'eng_Latn-mni_Beng', 'eng_Latn-mos_Latn', + 'eng_Latn-mri_Latn', 'eng_Latn-mya_Mymr', 'eng_Latn-npi_Deva', + 'eng_Latn-nso_Latn', 'eng_Latn-nus_Latn', 'eng_Latn-nya_Latn', + 'eng_Latn-ory_Orya', 'eng_Latn-pag_Latn', 'eng_Latn-pan_Guru', + 'eng_Latn-pap_Latn', 'eng_Latn-pbt_Arab', 'eng_Latn-plt_Latn', + 'eng_Latn-prs_Arab', 'eng_Latn-quy_Latn', 'eng_Latn-run_Latn', + 'eng_Latn-sag_Latn', 'eng_Latn-san_Deva', 'eng_Latn-sat_Beng', + 'eng_Latn-scn_Latn', 'eng_Latn-shn_Mymr', 'eng_Latn-sin_Sinh', + 
'eng_Latn-smo_Latn', 'eng_Latn-sna_Latn', 'eng_Latn-snd_Arab', + 'eng_Latn-som_Latn', 'eng_Latn-sot_Latn', 'eng_Latn-srd_Latn', + 'eng_Latn-ssw_Latn', 'eng_Latn-sun_Latn', 'eng_Latn-swh_Latn', + 'eng_Latn-szl_Latn', 'eng_Latn-tam_Taml', 'eng_Latn-taq_Latn', + 'eng_Latn-tat_Cyrl', 'eng_Latn-tel_Telu', 'eng_Latn-tgk_Cyrl', + 'eng_Latn-tgl_Latn', 'eng_Latn-tir_Ethi', 'eng_Latn-tpi_Latn', + 'eng_Latn-tsn_Latn', 'eng_Latn-tso_Latn', 'eng_Latn-tuk_Latn', + 'eng_Latn-tum_Latn', 'eng_Latn-twi_Latn', 'eng_Latn-tzm_Tfng', + 'eng_Latn-uig_Arab', 'eng_Latn-umb_Latn', 'eng_Latn-urd_Arab', + 'eng_Latn-uzn_Latn', 'eng_Latn-vec_Latn', 'eng_Latn-war_Latn', + 'eng_Latn-wol_Latn', 'eng_Latn-xho_Latn', 'eng_Latn-ydd_Hebr', + 'eng_Latn-yor_Latn', 'eng_Latn-zho_Hant', 'eng_Latn-zsm_Latn', + 'eng_Latn-zul_Latn', 'epo_Latn-fra_Latn', 'ewe_Latn-fon_Latn', + 'ewe_Latn-fra_Latn', 'ewe_Latn-fuv_Latn', 'ewe_Latn-gaz_Latn', + 'ewe_Latn-hau_Latn', 'ewe_Latn-ibo_Latn', 'ewe_Latn-kam_Latn', + 'ewe_Latn-kik_Latn', 'ewe_Latn-kin_Latn', 'ewe_Latn-kmb_Latn', + 'ewe_Latn-knc_Arab', 'ewe_Latn-knc_Latn', 'ewe_Latn-kon_Latn', + 'ewe_Latn-lin_Latn', 'ewe_Latn-lua_Latn', 'ewe_Latn-lug_Latn', + 'ewe_Latn-luo_Latn', 'ewe_Latn-nso_Latn', 'ewe_Latn-nus_Latn', + 'ewe_Latn-nya_Latn', 'ewe_Latn-run_Latn', 'ewe_Latn-sna_Latn', + 'ewe_Latn-som_Latn', 'ewe_Latn-sot_Latn', 'ewe_Latn-ssw_Latn', + 'ewe_Latn-swh_Latn', 'ewe_Latn-tir_Ethi', 'ewe_Latn-tsn_Latn', + 'ewe_Latn-tso_Latn', 'ewe_Latn-tum_Latn', 'ewe_Latn-twi_Latn', + 'ewe_Latn-umb_Latn', 'ewe_Latn-wol_Latn', 'ewe_Latn-xho_Latn', + 'ewe_Latn-yor_Latn', 'ewe_Latn-zul_Latn', 'fij_Latn-hin_Deva', + 'fij_Latn-ilo_Latn', 'fij_Latn-jav_Latn', 'fij_Latn-min_Latn', + 'fij_Latn-mri_Latn', 'fij_Latn-pag_Latn', 'fij_Latn-plt_Latn', + 'fij_Latn-smo_Latn', 'fij_Latn-sun_Latn', 'fij_Latn-war_Latn', + 'fon_Latn-fra_Latn', 'fon_Latn-fuv_Latn', 'fon_Latn-gaz_Latn', + 'fon_Latn-hau_Latn', 'fon_Latn-ibo_Latn', 'fon_Latn-kam_Latn', + 'fon_Latn-kik_Latn', 'fon_Latn-kin_Latn', 'fon_Latn-kmb_Latn', + 'fon_Latn-knc_Arab', 'fon_Latn-knc_Latn', 'fon_Latn-kon_Latn', + 'fon_Latn-lin_Latn', 'fon_Latn-lua_Latn', 'fon_Latn-lug_Latn', + 'fon_Latn-luo_Latn', 'fon_Latn-nso_Latn', 'fon_Latn-nus_Latn', + 'fon_Latn-nya_Latn', 'fon_Latn-run_Latn', 'fon_Latn-sna_Latn', + 'fon_Latn-som_Latn', 'fon_Latn-sot_Latn', 'fon_Latn-ssw_Latn', + 'fon_Latn-swh_Latn', 'fon_Latn-tir_Ethi', 'fon_Latn-tsn_Latn', + 'fon_Latn-tso_Latn', 'fon_Latn-tum_Latn', 'fon_Latn-twi_Latn', + 'fon_Latn-umb_Latn', 'fon_Latn-wol_Latn', 'fon_Latn-xho_Latn', + 'fon_Latn-yor_Latn', 'fon_Latn-zul_Latn', 'fra_Latn-fuv_Latn', + 'fra_Latn-gaz_Latn', 'fra_Latn-glg_Latn', 'fra_Latn-hat_Latn', + 'fra_Latn-hau_Latn', 'fra_Latn-ibo_Latn', 'fra_Latn-kab_Latn', + 'fra_Latn-kam_Latn', 'fra_Latn-kik_Latn', 'fra_Latn-kin_Latn', + 'fra_Latn-kmb_Latn', 'fra_Latn-knc_Arab', 'fra_Latn-knc_Latn', + 'fra_Latn-kon_Latn', 'fra_Latn-lin_Latn', 'fra_Latn-ltz_Latn', + 'fra_Latn-lua_Latn', 'fra_Latn-lug_Latn', 'fra_Latn-luo_Latn', + 'fra_Latn-nso_Latn', 'fra_Latn-nus_Latn', 'fra_Latn-nya_Latn', + 'fra_Latn-oci_Latn', 'fra_Latn-plt_Latn', 'fra_Latn-run_Latn', + 'fra_Latn-sag_Latn', 'fra_Latn-scn_Latn', 'fra_Latn-sna_Latn', + 'fra_Latn-som_Latn', 'fra_Latn-sot_Latn', 'fra_Latn-ssw_Latn', + 'fra_Latn-swh_Latn', 'fra_Latn-tir_Ethi', 'fra_Latn-tsn_Latn', + 'fra_Latn-tso_Latn', 'fra_Latn-tum_Latn', 'fra_Latn-twi_Latn', + 'fra_Latn-tzm_Tfng', 'fra_Latn-umb_Latn', 'fra_Latn-wol_Latn', + 'fra_Latn-xho_Latn', 'fra_Latn-yor_Latn', 'fra_Latn-zul_Latn', + 'fuv_Latn-gaz_Latn', 'fuv_Latn-hau_Latn', 
'fuv_Latn-ibo_Latn', + 'fuv_Latn-kam_Latn', 'fuv_Latn-kik_Latn', 'fuv_Latn-kin_Latn', + 'fuv_Latn-kmb_Latn', 'fuv_Latn-knc_Arab', 'fuv_Latn-knc_Latn', + 'fuv_Latn-kon_Latn', 'fuv_Latn-lin_Latn', 'fuv_Latn-lua_Latn', + 'fuv_Latn-lug_Latn', 'fuv_Latn-luo_Latn', 'fuv_Latn-nso_Latn', + 'fuv_Latn-nus_Latn', 'fuv_Latn-nya_Latn', 'fuv_Latn-run_Latn', + 'fuv_Latn-sna_Latn', 'fuv_Latn-som_Latn', 'fuv_Latn-sot_Latn', + 'fuv_Latn-ssw_Latn', 'fuv_Latn-swh_Latn', 'fuv_Latn-tir_Ethi', + 'fuv_Latn-tsn_Latn', 'fuv_Latn-tso_Latn', 'fuv_Latn-tum_Latn', + 'fuv_Latn-twi_Latn', 'fuv_Latn-umb_Latn', 'fuv_Latn-wol_Latn', + 'fuv_Latn-xho_Latn', 'fuv_Latn-yor_Latn', 'fuv_Latn-zul_Latn', + 'gaz_Latn-run_Latn', 'gaz_Latn-sna_Latn', 'gaz_Latn-som_Latn', + 'gaz_Latn-sot_Latn', 'gaz_Latn-ssw_Latn', 'gaz_Latn-swh_Latn', + 'gaz_Latn-tir_Ethi', 'gaz_Latn-tsn_Latn', 'gaz_Latn-tso_Latn', + 'gaz_Latn-tum_Latn', 'gaz_Latn-twi_Latn', 'gaz_Latn-umb_Latn', + 'gaz_Latn-wol_Latn', 'gaz_Latn-xho_Latn', 'gaz_Latn-yor_Latn', + 'gaz_Latn-zul_Latn', 'glg_Latn-por_Latn', 'grn_Latn-por_Latn', + 'guj_Gujr-hin_Deva', 'guj_Gujr-hne_Deva', 'guj_Gujr-kan_Knda', + 'guj_Gujr-kas_Arab', 'guj_Gujr-kas_Deva', 'guj_Gujr-mag_Deva', + 'guj_Gujr-mai_Deva', 'guj_Gujr-mal_Mlym', 'guj_Gujr-mar_Deva', + 'guj_Gujr-npi_Deva', 'guj_Gujr-ory_Orya', 'guj_Gujr-pan_Guru', + 'guj_Gujr-san_Deva', 'guj_Gujr-sat_Beng', 'guj_Gujr-sin_Sinh', + 'guj_Gujr-snd_Arab', 'guj_Gujr-tam_Taml', 'guj_Gujr-tel_Telu', + 'guj_Gujr-urd_Arab', 'hau_Latn-gaz_Latn', 'hau_Latn-ibo_Latn', + 'hau_Latn-kam_Latn', 'hau_Latn-kik_Latn', 'hau_Latn-kin_Latn', + 'hau_Latn-kmb_Latn', 'hau_Latn-knc_Arab', 'hau_Latn-knc_Latn', + 'hau_Latn-kon_Latn', 'hau_Latn-lin_Latn', 'hau_Latn-lua_Latn', + 'hau_Latn-lug_Latn', 'hau_Latn-luo_Latn', 'hau_Latn-nso_Latn', + 'hau_Latn-nus_Latn', 'hau_Latn-nya_Latn', 'hau_Latn-run_Latn', + 'hau_Latn-sna_Latn', 'hau_Latn-som_Latn', 'hau_Latn-sot_Latn', + 'hau_Latn-ssw_Latn', 'hau_Latn-swh_Latn', 'hau_Latn-tir_Ethi', + 'hau_Latn-tsn_Latn', 'hau_Latn-tso_Latn', 'hau_Latn-tum_Latn', + 'hau_Latn-twi_Latn', 'hau_Latn-umb_Latn', 'hau_Latn-wol_Latn', + 'hau_Latn-xho_Latn', 'hau_Latn-yor_Latn', 'hau_Latn-zul_Latn', + 'hin_Deva-hne_Deva', 'hin_Deva-kan_Knda', 'hin_Deva-kas_Arab', + 'hin_Deva-kas_Deva', 'hin_Deva-mag_Deva', 'hin_Deva-mai_Deva', + 'hin_Deva-mal_Mlym', 'hin_Deva-mar_Deva', 'hin_Deva-npi_Deva', + 'hin_Deva-ory_Orya', 'hin_Deva-pan_Guru', 'hin_Deva-pbt_Arab', + 'hin_Deva-san_Deva', 'hin_Deva-sat_Beng', 'hin_Deva-sin_Sinh', + 'hin_Deva-snd_Arab', 'hin_Deva-tam_Taml', 'hin_Deva-tel_Telu', + 'hin_Deva-urd_Arab', 'hne_Deva-kan_Knda', 'hne_Deva-kas_Arab', + 'hne_Deva-kas_Deva', 'hne_Deva-mag_Deva', 'hne_Deva-mai_Deva', + 'hne_Deva-mal_Mlym', 'hne_Deva-mar_Deva', 'hne_Deva-npi_Deva', + 'hne_Deva-ory_Orya', 'hne_Deva-pan_Guru', 'hne_Deva-san_Deva', + 'hne_Deva-sat_Beng', 'hne_Deva-sin_Sinh', 'hne_Deva-snd_Arab', + 'hne_Deva-tam_Taml', 'hne_Deva-tel_Telu', 'hne_Deva-urd_Arab', + 'hye_Armn-rus_Cyrl', 'ibo_Latn-gaz_Latn', 'ibo_Latn-kam_Latn', + 'ibo_Latn-kik_Latn', 'ibo_Latn-kin_Latn', 'ibo_Latn-kmb_Latn', + 'ibo_Latn-knc_Arab', 'ibo_Latn-knc_Latn', 'ibo_Latn-kon_Latn', + 'ibo_Latn-lin_Latn', 'ibo_Latn-lua_Latn', 'ibo_Latn-lug_Latn', + 'ibo_Latn-luo_Latn', 'ibo_Latn-nso_Latn', 'ibo_Latn-nus_Latn', + 'ibo_Latn-nya_Latn', 'ibo_Latn-run_Latn', 'ibo_Latn-sna_Latn', + 'ibo_Latn-som_Latn', 'ibo_Latn-sot_Latn', 'ibo_Latn-ssw_Latn', + 'ibo_Latn-swh_Latn', 'ibo_Latn-tir_Ethi', 'ibo_Latn-tsn_Latn', + 'ibo_Latn-tso_Latn', 'ibo_Latn-tum_Latn', 'ibo_Latn-twi_Latn', + 'ibo_Latn-umb_Latn', 
'ibo_Latn-wol_Latn', 'ibo_Latn-xho_Latn', + 'ibo_Latn-yor_Latn', 'ibo_Latn-zul_Latn', 'ilo_Latn-jav_Latn', + 'ilo_Latn-min_Latn', 'ilo_Latn-mri_Latn', 'ilo_Latn-pag_Latn', + 'ilo_Latn-plt_Latn', 'ilo_Latn-smo_Latn', 'ilo_Latn-sun_Latn', + 'ilo_Latn-war_Latn', 'ind_Latn-ace_Latn', 'ind_Latn-ban_Latn', + 'ind_Latn-jav_Latn', 'ind_Latn-khm_Khmr', 'ind_Latn-lao_Laoo', + 'ind_Latn-min_Latn', 'ind_Latn-mya_Mymr', 'ind_Latn-shn_Mymr', + 'ind_Latn-sun_Latn', 'jav_Latn-min_Latn', 'jav_Latn-mri_Latn', + 'jav_Latn-pag_Latn', 'jav_Latn-plt_Latn', 'jav_Latn-smo_Latn', + 'jav_Latn-sun_Latn', 'jav_Latn-war_Latn', 'kam_Latn-gaz_Latn', + 'kam_Latn-kik_Latn', 'kam_Latn-kin_Latn', 'kam_Latn-kmb_Latn', + 'kam_Latn-knc_Arab', 'kam_Latn-knc_Latn', 'kam_Latn-kon_Latn', + 'kam_Latn-lin_Latn', 'kam_Latn-lua_Latn', 'kam_Latn-lug_Latn', + 'kam_Latn-luo_Latn', 'kam_Latn-nso_Latn', 'kam_Latn-nus_Latn', + 'kam_Latn-nya_Latn', 'kam_Latn-run_Latn', 'kam_Latn-sna_Latn', + 'kam_Latn-som_Latn', 'kam_Latn-sot_Latn', 'kam_Latn-ssw_Latn', + 'kam_Latn-swh_Latn', 'kam_Latn-tir_Ethi', 'kam_Latn-tsn_Latn', + 'kam_Latn-tso_Latn', 'kam_Latn-tum_Latn', 'kam_Latn-twi_Latn', + 'kam_Latn-umb_Latn', 'kam_Latn-wol_Latn', 'kam_Latn-xho_Latn', + 'kam_Latn-yor_Latn', 'kam_Latn-zul_Latn', 'kan_Knda-kas_Arab', + 'kan_Knda-kas_Deva', 'kan_Knda-mag_Deva', 'kan_Knda-mai_Deva', + 'kan_Knda-mal_Mlym', 'kan_Knda-mar_Deva', 'kan_Knda-npi_Deva', + 'kan_Knda-ory_Orya', 'kan_Knda-pan_Guru', 'kan_Knda-san_Deva', + 'kan_Knda-sat_Beng', 'kan_Knda-sin_Sinh', 'kan_Knda-snd_Arab', + 'kan_Knda-tam_Taml', 'kan_Knda-tel_Telu', 'kan_Knda-urd_Arab', + 'kas_Arab-kas_Deva', 'kas_Arab-mag_Deva', 'kas_Arab-mai_Deva', + 'kas_Arab-mal_Mlym', 'kas_Arab-mar_Deva', 'kas_Arab-npi_Deva', + 'kas_Arab-ory_Orya', 'kas_Arab-pan_Guru', 'kas_Arab-san_Deva', + 'kas_Arab-sat_Beng', 'kas_Arab-sin_Sinh', 'kas_Arab-snd_Arab', + 'kas_Arab-tam_Taml', 'kas_Arab-tel_Telu', 'kas_Arab-urd_Arab', + 'kas_Deva-mag_Deva', 'kas_Deva-mai_Deva', 'kas_Deva-mal_Mlym', + 'kas_Deva-mar_Deva', 'kas_Deva-npi_Deva', 'kas_Deva-ory_Orya', + 'kas_Deva-pan_Guru', 'kas_Deva-san_Deva', 'kas_Deva-sat_Beng', + 'kas_Deva-sin_Sinh', 'kas_Deva-snd_Arab', 'kas_Deva-tam_Taml', + 'kas_Deva-tel_Telu', 'kas_Deva-urd_Arab', 'kat_Geor-rus_Cyrl', + 'kea_Latn-por_Latn', 'kik_Latn-gaz_Latn', 'kik_Latn-kin_Latn', + 'kik_Latn-kmb_Latn', 'kik_Latn-kon_Latn', 'kik_Latn-lin_Latn', + 'kik_Latn-lua_Latn', 'kik_Latn-lug_Latn', 'kik_Latn-luo_Latn', + 'kik_Latn-nso_Latn', 'kik_Latn-nus_Latn', 'kik_Latn-nya_Latn', + 'kik_Latn-run_Latn', 'kik_Latn-sna_Latn', 'kik_Latn-som_Latn', + 'kik_Latn-sot_Latn', 'kik_Latn-ssw_Latn', 'kik_Latn-swh_Latn', + 'kik_Latn-tir_Ethi', 'kik_Latn-tsn_Latn', 'kik_Latn-tso_Latn', + 'kik_Latn-tum_Latn', 'kik_Latn-twi_Latn', 'kik_Latn-umb_Latn', + 'kik_Latn-wol_Latn', 'kik_Latn-xho_Latn', 'kik_Latn-yor_Latn', + 'kik_Latn-zul_Latn', 'kin_Latn-gaz_Latn', 'kin_Latn-kmb_Latn', + 'kin_Latn-kon_Latn', 'kin_Latn-lin_Latn', 'kin_Latn-lua_Latn', + 'kin_Latn-lug_Latn', 'kin_Latn-luo_Latn', 'kin_Latn-nso_Latn', + 'kin_Latn-nus_Latn', 'kin_Latn-nya_Latn', 'kin_Latn-run_Latn', + 'kin_Latn-sna_Latn', 'kin_Latn-som_Latn', 'kin_Latn-sot_Latn', + 'kin_Latn-ssw_Latn', 'kin_Latn-swh_Latn', 'kin_Latn-tir_Ethi', + 'kin_Latn-tsn_Latn', 'kin_Latn-tso_Latn', 'kin_Latn-tum_Latn', + 'kin_Latn-twi_Latn', 'kin_Latn-umb_Latn', 'kin_Latn-wol_Latn', + 'kin_Latn-xho_Latn', 'kin_Latn-yor_Latn', 'kin_Latn-zul_Latn', + 'kir_Cyrl-rus_Cyrl', 'kir_Cyrl-tat_Cyrl', 'kir_Cyrl-tuk_Latn', + 'kir_Cyrl-uig_Arab', 'kir_Cyrl-uzn_Latn', 'kmb_Latn-gaz_Latn', + 
'kmb_Latn-kon_Latn', 'kmb_Latn-lin_Latn', 'kmb_Latn-lua_Latn', + 'kmb_Latn-lug_Latn', 'kmb_Latn-luo_Latn', 'kmb_Latn-nso_Latn', + 'kmb_Latn-nus_Latn', 'kmb_Latn-nya_Latn', 'kmb_Latn-por_Latn', + 'kmb_Latn-run_Latn', 'kmb_Latn-sna_Latn', 'kmb_Latn-som_Latn', + 'kmb_Latn-sot_Latn', 'kmb_Latn-ssw_Latn', 'kmb_Latn-swh_Latn', + 'kmb_Latn-tir_Ethi', 'kmb_Latn-tsn_Latn', 'kmb_Latn-tso_Latn', + 'kmb_Latn-tum_Latn', 'kmb_Latn-twi_Latn', 'kmb_Latn-umb_Latn', + 'kmb_Latn-wol_Latn', 'kmb_Latn-xho_Latn', 'kmb_Latn-yor_Latn', + 'kmb_Latn-zul_Latn', 'kmr_Latn-pbt_Arab', 'kmr_Latn-prs_Arab', + 'kmr_Latn-tgk_Cyrl', 'knc_Arab-gaz_Latn', 'knc_Arab-kik_Latn', + 'knc_Arab-kin_Latn', 'knc_Arab-kmb_Latn', 'knc_Arab-knc_Latn', + 'knc_Arab-kon_Latn', 'knc_Arab-lin_Latn', 'knc_Arab-lua_Latn', + 'knc_Arab-lug_Latn', 'knc_Arab-luo_Latn', 'knc_Arab-nso_Latn', + 'knc_Arab-nus_Latn', 'knc_Arab-nya_Latn', 'knc_Arab-run_Latn', + 'knc_Arab-sna_Latn', 'knc_Arab-som_Latn', 'knc_Arab-sot_Latn', + 'knc_Arab-ssw_Latn', 'knc_Arab-swh_Latn', 'knc_Arab-tir_Ethi', + 'knc_Arab-tsn_Latn', 'knc_Arab-tso_Latn', 'knc_Arab-tum_Latn', + 'knc_Arab-twi_Latn', 'knc_Arab-umb_Latn', 'knc_Arab-wol_Latn', + 'knc_Arab-xho_Latn', 'knc_Arab-yor_Latn', 'knc_Arab-zul_Latn', + 'knc_Latn-gaz_Latn', 'knc_Latn-kik_Latn', 'knc_Latn-kin_Latn', + 'knc_Latn-kmb_Latn', 'knc_Latn-kon_Latn', 'knc_Latn-lin_Latn', + 'knc_Latn-lua_Latn', 'knc_Latn-lug_Latn', 'knc_Latn-luo_Latn', + 'knc_Latn-nso_Latn', 'knc_Latn-nus_Latn', 'knc_Latn-nya_Latn', + 'knc_Latn-run_Latn', 'knc_Latn-sna_Latn', 'knc_Latn-som_Latn', + 'knc_Latn-sot_Latn', 'knc_Latn-ssw_Latn', 'knc_Latn-swh_Latn', + 'knc_Latn-tir_Ethi', 'knc_Latn-tsn_Latn', 'knc_Latn-tso_Latn', + 'knc_Latn-tum_Latn', 'knc_Latn-twi_Latn', 'knc_Latn-umb_Latn', + 'knc_Latn-wol_Latn', 'knc_Latn-xho_Latn', 'knc_Latn-yor_Latn', + 'knc_Latn-zul_Latn', 'kon_Latn-gaz_Latn', 'kon_Latn-lin_Latn', + 'kon_Latn-lua_Latn', 'kon_Latn-lug_Latn', 'kon_Latn-luo_Latn', + 'kon_Latn-nso_Latn', 'kon_Latn-nus_Latn', 'kon_Latn-nya_Latn', + 'kon_Latn-run_Latn', 'kon_Latn-sna_Latn', 'kon_Latn-som_Latn', + 'kon_Latn-sot_Latn', 'kon_Latn-ssw_Latn', 'kon_Latn-swh_Latn', + 'kon_Latn-tir_Ethi', 'kon_Latn-tsn_Latn', 'kon_Latn-tso_Latn', + 'kon_Latn-tum_Latn', 'kon_Latn-twi_Latn', 'kon_Latn-umb_Latn', + 'kon_Latn-wol_Latn', 'kon_Latn-xho_Latn', 'kon_Latn-yor_Latn', + 'kon_Latn-zul_Latn', 'lao_Laoo-rus_Cyrl', 'lin_Latn-gaz_Latn', + 'lin_Latn-lua_Latn', 'lin_Latn-lug_Latn', 'lin_Latn-luo_Latn', + 'lin_Latn-nso_Latn', 'lin_Latn-nus_Latn', 'lin_Latn-nya_Latn', + 'lin_Latn-run_Latn', 'lin_Latn-sna_Latn', 'lin_Latn-som_Latn', + 'lin_Latn-sot_Latn', 'lin_Latn-ssw_Latn', 'lin_Latn-swh_Latn', + 'lin_Latn-tir_Ethi', 'lin_Latn-tsn_Latn', 'lin_Latn-tso_Latn', + 'lin_Latn-tum_Latn', 'lin_Latn-twi_Latn', 'lin_Latn-umb_Latn', + 'lin_Latn-wol_Latn', 'lin_Latn-xho_Latn', 'lin_Latn-yor_Latn', + 'lin_Latn-zul_Latn', 'ltg_Latn-rus_Cyrl', 'lua_Latn-gaz_Latn', + 'lua_Latn-lug_Latn', 'lua_Latn-luo_Latn', 'lua_Latn-nso_Latn', + 'lua_Latn-nus_Latn', 'lua_Latn-nya_Latn', 'lua_Latn-run_Latn', + 'lua_Latn-sna_Latn', 'lua_Latn-som_Latn', 'lua_Latn-sot_Latn', + 'lua_Latn-ssw_Latn', 'lua_Latn-swh_Latn', 'lua_Latn-tir_Ethi', + 'lua_Latn-tsn_Latn', 'lua_Latn-tso_Latn', 'lua_Latn-tum_Latn', + 'lua_Latn-twi_Latn', 'lua_Latn-umb_Latn', 'lua_Latn-wol_Latn', + 'lua_Latn-xho_Latn', 'lua_Latn-yor_Latn', 'lua_Latn-zul_Latn', + 'lug_Latn-gaz_Latn', 'lug_Latn-luo_Latn', 'lug_Latn-nso_Latn', + 'lug_Latn-nus_Latn', 'lug_Latn-nya_Latn', 'lug_Latn-run_Latn', + 'lug_Latn-sna_Latn', 'lug_Latn-som_Latn', 
'lug_Latn-sot_Latn', + 'lug_Latn-ssw_Latn', 'lug_Latn-swh_Latn', 'lug_Latn-tir_Ethi', + 'lug_Latn-tsn_Latn', 'lug_Latn-tso_Latn', 'lug_Latn-tum_Latn', + 'lug_Latn-twi_Latn', 'lug_Latn-umb_Latn', 'lug_Latn-wol_Latn', + 'lug_Latn-xho_Latn', 'lug_Latn-yor_Latn', 'lug_Latn-zul_Latn', + 'luo_Latn-gaz_Latn', 'luo_Latn-nso_Latn', 'luo_Latn-nus_Latn', + 'luo_Latn-nya_Latn', 'luo_Latn-run_Latn', 'luo_Latn-sna_Latn', + 'luo_Latn-som_Latn', 'luo_Latn-sot_Latn', 'luo_Latn-ssw_Latn', + 'luo_Latn-swh_Latn', 'luo_Latn-tir_Ethi', 'luo_Latn-tsn_Latn', + 'luo_Latn-tso_Latn', 'luo_Latn-tum_Latn', 'luo_Latn-twi_Latn', + 'luo_Latn-umb_Latn', 'luo_Latn-wol_Latn', 'luo_Latn-xho_Latn', + 'luo_Latn-yor_Latn', 'luo_Latn-zul_Latn', 'mag_Deva-mai_Deva', + 'mag_Deva-mal_Mlym', 'mag_Deva-mar_Deva', 'mag_Deva-npi_Deva', + 'mag_Deva-ory_Orya', 'mag_Deva-pan_Guru', 'mag_Deva-san_Deva', + 'mag_Deva-sat_Beng', 'mag_Deva-sin_Sinh', 'mag_Deva-snd_Arab', + 'mag_Deva-tam_Taml', 'mag_Deva-tel_Telu', 'mag_Deva-urd_Arab', + 'mai_Deva-mal_Mlym', 'mai_Deva-mar_Deva', 'mai_Deva-npi_Deva', + 'mai_Deva-ory_Orya', 'mai_Deva-pan_Guru', 'mai_Deva-san_Deva', + 'mai_Deva-sat_Beng', 'mai_Deva-sin_Sinh', 'mai_Deva-snd_Arab', + 'mai_Deva-tam_Taml', 'mai_Deva-tel_Telu', 'mai_Deva-urd_Arab', + 'mal_Mlym-mar_Deva', 'mal_Mlym-npi_Deva', 'mal_Mlym-ory_Orya', + 'mal_Mlym-pan_Guru', 'mal_Mlym-san_Deva', 'mal_Mlym-sat_Beng', + 'mal_Mlym-sin_Sinh', 'mal_Mlym-snd_Arab', 'mal_Mlym-tam_Taml', + 'mal_Mlym-tel_Telu', 'mal_Mlym-urd_Arab', 'mar_Deva-npi_Deva', + 'mar_Deva-ory_Orya', 'mar_Deva-pan_Guru', 'mar_Deva-san_Deva', + 'mar_Deva-sat_Beng', 'mar_Deva-sin_Sinh', 'mar_Deva-snd_Arab', + 'mar_Deva-tam_Taml', 'mar_Deva-tel_Telu', 'mar_Deva-urd_Arab', + 'min_Latn-mri_Latn', 'min_Latn-pag_Latn', 'min_Latn-plt_Latn', + 'min_Latn-smo_Latn', 'min_Latn-sun_Latn', 'min_Latn-war_Latn', + 'mri_Latn-pag_Latn', 'mri_Latn-smo_Latn', 'mri_Latn-sun_Latn', + 'mri_Latn-war_Latn', 'npi_Deva-ory_Orya', 'npi_Deva-pan_Guru', + 'npi_Deva-san_Deva', 'npi_Deva-sat_Beng', 'npi_Deva-sin_Sinh', + 'npi_Deva-snd_Arab', 'npi_Deva-tam_Taml', 'npi_Deva-tel_Telu', + 'npi_Deva-urd_Arab', 'nso_Latn-gaz_Latn', 'nso_Latn-nus_Latn', + 'nso_Latn-nya_Latn', 'nso_Latn-run_Latn', 'nso_Latn-sna_Latn', + 'nso_Latn-som_Latn', 'nso_Latn-sot_Latn', 'nso_Latn-ssw_Latn', + 'nso_Latn-swh_Latn', 'nso_Latn-tir_Ethi', 'nso_Latn-tsn_Latn', + 'nso_Latn-tso_Latn', 'nso_Latn-tum_Latn', 'nso_Latn-twi_Latn', + 'nso_Latn-umb_Latn', 'nso_Latn-wol_Latn', 'nso_Latn-xho_Latn', + 'nso_Latn-yor_Latn', 'nso_Latn-zul_Latn', 'nus_Latn-gaz_Latn', + 'nus_Latn-nya_Latn', 'nus_Latn-run_Latn', 'nus_Latn-sna_Latn', + 'nus_Latn-som_Latn', 'nus_Latn-sot_Latn', 'nus_Latn-ssw_Latn', + 'nus_Latn-swh_Latn', 'nus_Latn-tir_Ethi', 'nus_Latn-tsn_Latn', + 'nus_Latn-tso_Latn', 'nus_Latn-tum_Latn', 'nus_Latn-twi_Latn', + 'nus_Latn-umb_Latn', 'nus_Latn-wol_Latn', 'nus_Latn-xho_Latn', + 'nus_Latn-yor_Latn', 'nus_Latn-zul_Latn', 'nya_Latn-gaz_Latn', + 'nya_Latn-run_Latn', 'nya_Latn-sna_Latn', 'nya_Latn-som_Latn', + 'nya_Latn-sot_Latn', 'nya_Latn-ssw_Latn', 'nya_Latn-swh_Latn', + 'nya_Latn-tir_Ethi', 'nya_Latn-tsn_Latn', 'nya_Latn-tso_Latn', + 'nya_Latn-tum_Latn', 'nya_Latn-twi_Latn', 'nya_Latn-umb_Latn', + 'nya_Latn-wol_Latn', 'nya_Latn-xho_Latn', 'nya_Latn-yor_Latn', + 'nya_Latn-zul_Latn', 'oci_Latn-por_Latn', 'ory_Orya-pan_Guru', + 'ory_Orya-san_Deva', 'ory_Orya-sat_Beng', 'ory_Orya-sin_Sinh', + 'ory_Orya-snd_Arab', 'ory_Orya-tam_Taml', 'ory_Orya-tel_Telu', + 'ory_Orya-urd_Arab', 'pag_Latn-smo_Latn', 'pag_Latn-sun_Latn', + 'pan_Guru-san_Deva', 
'pan_Guru-sat_Beng', 'pan_Guru-sin_Sinh', + 'pan_Guru-snd_Arab', 'pan_Guru-tam_Taml', 'pan_Guru-tel_Telu', + 'pan_Guru-urd_Arab', 'pbt_Arab-tam_Taml', 'pbt_Arab-tgk_Cyrl', + 'plt_Latn-mri_Latn', 'plt_Latn-pag_Latn', 'plt_Latn-smo_Latn', + 'plt_Latn-sun_Latn', 'plt_Latn-war_Latn', 'por_Latn-ayr_Latn', + 'por_Latn-quy_Latn', 'prs_Arab-pbt_Arab', 'prs_Arab-tgk_Cyrl', + 'quy_Latn-spa_Latn', 'run_Latn-sna_Latn', 'run_Latn-som_Latn', + 'run_Latn-sot_Latn', 'run_Latn-ssw_Latn', 'run_Latn-swh_Latn', + 'run_Latn-tir_Ethi', 'run_Latn-tsn_Latn', 'run_Latn-tso_Latn', + 'run_Latn-tum_Latn', 'run_Latn-twi_Latn', 'run_Latn-umb_Latn', + 'run_Latn-wol_Latn', 'run_Latn-xho_Latn', 'run_Latn-yor_Latn', + 'run_Latn-zul_Latn', 'rus_Cyrl-tat_Cyrl', 'rus_Cyrl-tgk_Cyrl', + 'san_Deva-sat_Beng', 'san_Deva-sin_Sinh', 'san_Deva-snd_Arab', + 'san_Deva-tam_Taml', 'san_Deva-tel_Telu', 'san_Deva-urd_Arab', + 'sat_Beng-sin_Sinh', 'sat_Beng-snd_Arab', 'sat_Beng-tam_Taml', + 'sat_Beng-tel_Telu', 'sat_Beng-urd_Arab', 'sin_Sinh-snd_Arab', + 'sin_Sinh-tam_Taml', 'sin_Sinh-tel_Telu', 'sin_Sinh-urd_Arab', + 'smo_Latn-sun_Latn', 'smo_Latn-war_Latn', 'sna_Latn-som_Latn', + 'sna_Latn-sot_Latn', 'sna_Latn-ssw_Latn', 'sna_Latn-swh_Latn', + 'sna_Latn-tir_Ethi', 'sna_Latn-tsn_Latn', 'sna_Latn-tso_Latn', + 'sna_Latn-tum_Latn', 'sna_Latn-twi_Latn', 'sna_Latn-umb_Latn', + 'sna_Latn-wol_Latn', 'sna_Latn-xho_Latn', 'sna_Latn-yor_Latn', + 'sna_Latn-zul_Latn', 'snd_Arab-tam_Taml', 'snd_Arab-tel_Telu', + 'snd_Arab-urd_Arab', 'som_Latn-sot_Latn', 'som_Latn-ssw_Latn', + 'som_Latn-swh_Latn', 'som_Latn-tir_Ethi', 'som_Latn-tsn_Latn', + 'som_Latn-tso_Latn', 'som_Latn-tum_Latn', 'som_Latn-twi_Latn', + 'som_Latn-umb_Latn', 'som_Latn-wol_Latn', 'som_Latn-xho_Latn', + 'som_Latn-yor_Latn', 'som_Latn-zul_Latn', 'sot_Latn-ssw_Latn', + 'sot_Latn-swh_Latn', 'sot_Latn-tir_Ethi', 'sot_Latn-tsn_Latn', + 'sot_Latn-tso_Latn', 'sot_Latn-tum_Latn', 'sot_Latn-twi_Latn', + 'sot_Latn-umb_Latn', 'sot_Latn-wol_Latn', 'sot_Latn-xho_Latn', + 'sot_Latn-yor_Latn', 'sot_Latn-zul_Latn', 'ssw_Latn-swh_Latn', + 'ssw_Latn-tir_Ethi', 'ssw_Latn-tsn_Latn', 'ssw_Latn-tso_Latn', + 'ssw_Latn-tum_Latn', 'ssw_Latn-twi_Latn', 'ssw_Latn-umb_Latn', + 'ssw_Latn-wol_Latn', 'ssw_Latn-xho_Latn', 'ssw_Latn-yor_Latn', + 'ssw_Latn-zul_Latn', 'sun_Latn-war_Latn', 'swh_Latn-tir_Ethi', + 'swh_Latn-tsn_Latn', 'swh_Latn-tso_Latn', 'swh_Latn-tum_Latn', + 'swh_Latn-twi_Latn', 'swh_Latn-umb_Latn', 'swh_Latn-wol_Latn', + 'swh_Latn-xho_Latn', 'swh_Latn-yor_Latn', 'swh_Latn-zul_Latn', + 'tam_Taml-tel_Telu', 'tam_Taml-urd_Arab', 'tat_Cyrl-tuk_Latn', + 'tat_Cyrl-uig_Arab', 'tat_Cyrl-uzn_Latn', 'tel_Telu-urd_Arab', + 'tir_Ethi-tsn_Latn', 'tir_Ethi-tso_Latn', 'tir_Ethi-tum_Latn', + 'tir_Ethi-twi_Latn', 'tir_Ethi-umb_Latn', 'tir_Ethi-wol_Latn', + 'tir_Ethi-xho_Latn', 'tir_Ethi-yor_Latn', 'tir_Ethi-zul_Latn', + 'tsn_Latn-tso_Latn', 'tsn_Latn-tum_Latn', 'tsn_Latn-twi_Latn', + 'tsn_Latn-umb_Latn', 'tsn_Latn-wol_Latn', 'tsn_Latn-xho_Latn', + 'tsn_Latn-yor_Latn', 'tsn_Latn-zul_Latn', 'tso_Latn-tum_Latn', + 'tso_Latn-twi_Latn', 'tso_Latn-umb_Latn', 'tso_Latn-wol_Latn', + 'tso_Latn-xho_Latn', 'tso_Latn-yor_Latn', 'tso_Latn-zul_Latn', + 'tuk_Latn-uig_Arab', 'tuk_Latn-uzn_Latn', 'tum_Latn-twi_Latn', + 'tum_Latn-umb_Latn', 'tum_Latn-wol_Latn', 'tum_Latn-xho_Latn', + 'tum_Latn-yor_Latn', 'tum_Latn-zul_Latn', 'twi_Latn-umb_Latn', + 'twi_Latn-wol_Latn', 'twi_Latn-xho_Latn', 'twi_Latn-yor_Latn', + 'twi_Latn-zul_Latn', 'uig_Arab-uzn_Latn', 'umb_Latn-wol_Latn', + 'umb_Latn-xho_Latn', 'umb_Latn-yor_Latn', 'umb_Latn-zul_Latn', + 
'wol_Latn-xho_Latn', 'wol_Latn-yor_Latn', 'wol_Latn-zul_Latn', + 'xho_Latn-yor_Latn', 'xho_Latn-zul_Latn', 'yor_Latn-zul_Latn' + ] + subset = subset[:self.subset_count] + for subset_name in subset: + self.download_subset('nllb', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_universal_dependencies(self): + subset = [ + 'af_afribooms', 'akk_pisandub', 'akk_riao', 'aqz_tudet', 'sq_tsa', + 'am_att', 'grc_perseus', 'grc_proiel', 'apu_ufpa', 'ar_nyuad', + 'ar_padt', 'ar_pud', 'hy_armtdp', 'aii_as', 'bm_crb', 'eu_bdt', + 'be_hse', 'bho_bhtb', 'br_keb', 'bg_btb', 'bxr_bdt', 'yue_hk', + 'ca_ancora', 'zh_cfl', 'zh_gsd', 'zh_gsdsimp', 'zh_hk', 'zh_pud', + 'ckt_hse', 'lzh_kyoto', 'cop_scriptorium', 'hr_set', 'cs_cac', + 'cs_cltt', 'cs_fictree', 'cs_pdt', 'cs_pud', 'da_ddt', 'nl_alpino', + 'nl_lassysmall', 'en_esl', 'en_ewt', 'en_gum', 'en_gumreddit', + 'en_lines', 'en_partut', 'en_pronouns', 'en_pud', 'myv_jr', + 'et_edt', 'et_ewt', 'fo_farpahc', 'fo_oft', 'fi_ftb', 'fi_ood', + 'fi_pud', 'fi_tdt', 'fr_fqb', 'fr_ftb', 'fr_gsd', 'fr_partut', + 'fr_pud', 'fr_sequoia', 'fr_spoken', 'gl_ctg', 'gl_treegal', + 'de_gsd', 'de_hdt', 'de_lit', 'de_pud', 'got_proiel', 'el_gdt', + 'he_htb', 'qhe_hiencs', 'hi_hdtb', 'hi_pud', 'hu_szeged', + 'is_icepahc', 'is_pud', 'id_csui', 'id_gsd', 'id_pud', 'ga_idt', + 'it_isdt', 'it_partut', 'it_postwita', 'it_pud', 'it_twittiro', + 'it_vit', 'ja_bccwj', 'ja_gsd', 'ja_modern', 'ja_pud', 'krl_kkpp', + 'kk_ktb', 'kfm_aha', 'koi_uh', 'kpv_ikdp', 'kpv_lattice', 'ko_gsd', + 'ko_kaist', 'ko_pud', 'kmr_mg', 'la_ittb', 'la_llct', 'la_perseus', + 'la_proiel', 'lv_lvtb', 'lt_alksnis', 'lt_hse', 'olo_kkpp', + 'mt_mudt', 'gv_cadhan', 'mr_ufal', 'gun_dooley', 'gun_thomas', + 'mdf_jr', 'myu_tudet', 'pcm_nsc', 'nyq_aha', 'sme_giella', + 'no_bokmaal', 'no_nynorsk', 'no_nynorsklia', 'cu_proiel', + 'fro_srcmf', 'orv_rnc', 'orv_torot', 'otk_tonqq', 'fa_perdt', + 'fa_seraji', 'pl_lfg', 'pl_pdb', 'pl_pud', 'pt_bosque', 'pt_gsd', + 'pt_pud', 'ro_nonstandard', 'ro_rrt', 'ro_simonero', 'ru_gsd', + 'ru_pud', 'ru_syntagrus', 'ru_taiga', 'sa_ufal', 'sa_vedic', + 'gd_arcosg', 'sr_set', 'sms_giellagas', 'sk_snk', 'sl_ssj', + 'sl_sst', 'soj_aha', 'ajp_madar', 'es_ancora', 'es_gsd', 'es_pud', + 'swl_sslc', 'sv_lines', 'sv_pud', 'sv_talbanken', 'gsw_uzh', + 'tl_trg', 'tl_ugnayan', 'ta_mwtt', 'ta_ttb', 'te_mtg', 'th_pud', + 'tpn_tudet', 'qtd_sagt', 'tr_boun', 'tr_gb', 'tr_imst', 'tr_pud', + 'uk_iu', 'hsb_ufal', 'ur_udtb', 'ug_udt', 'vi_vtb', 'wbp_ufal', + 'cy_ccg', 'wo_wtb', 'yo_ytb' + ] + subset = subset[:self.subset_count] + for subset_name in subset: + self.download_subset('universal_dependencies', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_imdb(self): + dataset = MsDataset.load('imdb') + if isinstance(dataset, MsDataset): + lens = len(dataset) + print(f'dataset imdb len: {lens}') + self.assertTrue(lens > 0) + else: + assert isinstance(dataset, dict) + lens = {key: len(subset) for key, subset in dataset.items()} + print(f'dataset imdb len: {lens}') + self.assertTrue(all([_len > 0 for _len in lens.values()])) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_clue(self): + subset = [ + 'afqmc', 'tnews', 'iflytek', 'cmnli', 'cluewsc2020', 'csl', + 'cmrc2018', 'drcd', 'chid', 'c3', 'ocnli', 'diagnostics' + ] + for subset_name in subset: + self.download_subset('clue', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip 
test in current test level') + def test_download_wikitext(self): + subset = [ + 'wikitext-103-v1', 'wikitext-2-v1', 'wikitext-103-raw-v1', + 'wikitext-2-raw-v1' + ] + for subset_name in subset: + self.download_subset('wikitext', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_xnli(self): + subset = [ + 'XNLI', 'tydiqa', 'SQuAD', 'PAN-X.af', 'PAN-X.ar', 'PAN-X.bg', + 'PAN-X.bn', 'PAN-X.de', 'PAN-X.el', 'PAN-X.en', 'PAN-X.es', + 'PAN-X.et', 'PAN-X.eu', 'PAN-X.fa', 'PAN-X.fi', 'PAN-X.fr', + 'PAN-X.he', 'PAN-X.hi', 'PAN-X.hu', 'PAN-X.id', 'PAN-X.it', + 'PAN-X.ja', 'PAN-X.jv', 'PAN-X.ka', 'PAN-X.kk', 'PAN-X.ko', + 'PAN-X.ml', 'PAN-X.mr', 'PAN-X.ms', 'PAN-X.my', 'PAN-X.nl', + 'PAN-X.pt', 'PAN-X.ru', 'PAN-X.sw', 'PAN-X.ta', 'PAN-X.te', + 'PAN-X.th', 'PAN-X.tl', 'PAN-X.tr', 'PAN-X.ur', 'PAN-X.vi', + 'PAN-X.yo', 'PAN-X.zh', 'MLQA.ar.ar', 'MLQA.ar.de', 'MLQA.ar.vi', + 'MLQA.ar.zh', 'MLQA.ar.en', 'MLQA.ar.es', 'MLQA.ar.hi', + 'MLQA.de.ar', 'MLQA.de.de', 'MLQA.de.vi', 'MLQA.de.zh', + 'MLQA.de.en', 'MLQA.de.es', 'MLQA.de.hi', 'MLQA.vi.ar', + 'MLQA.vi.de', 'MLQA.vi.vi', 'MLQA.vi.zh', 'MLQA.vi.en', + 'MLQA.vi.es', 'MLQA.vi.hi', 'MLQA.zh.ar', 'MLQA.zh.de', + 'MLQA.zh.vi', 'MLQA.zh.zh', 'MLQA.zh.en', 'MLQA.zh.es', + 'MLQA.zh.hi', 'MLQA.en.ar', 'MLQA.en.de', 'MLQA.en.vi', + 'MLQA.en.zh', 'MLQA.en.en', 'MLQA.en.es', 'MLQA.en.hi', + 'MLQA.es.ar', 'MLQA.es.de', 'MLQA.es.vi', 'MLQA.es.zh', + 'MLQA.es.en', 'MLQA.es.es', 'MLQA.es.hi', 'MLQA.hi.ar', + 'MLQA.hi.de', 'MLQA.hi.vi', 'MLQA.hi.zh', 'MLQA.hi.en', + 'MLQA.hi.es', 'MLQA.hi.hi', 'XQuAD.ar', 'XQuAD.de', 'XQuAD.vi', + 'XQuAD.zh', 'XQuAD.en', 'XQuAD.es', 'XQuAD.hi', 'XQuAD.el', + 'XQuAD.ru', 'XQuAD.th', 'XQuAD.tr', 'bucc18.de', 'bucc18.fr', + 'bucc18.zh', 'bucc18.ru', 'PAWS-X.de', 'PAWS-X.en', 'PAWS-X.es', + 'PAWS-X.fr', 'PAWS-X.ja', 'PAWS-X.ko', 'PAWS-X.zh', 'tatoeba.afr', + 'tatoeba.ara', 'tatoeba.ben', 'tatoeba.bul', 'tatoeba.deu', + 'tatoeba.cmn', 'tatoeba.ell', 'tatoeba.est', 'tatoeba.eus', + 'tatoeba.fin', 'tatoeba.fra', 'tatoeba.heb', 'tatoeba.hin', + 'tatoeba.hun', 'tatoeba.ind', 'tatoeba.ita', 'tatoeba.jav', + 'tatoeba.jpn', 'tatoeba.kat', 'tatoeba.kaz', 'tatoeba.kor', + 'tatoeba.mal', 'tatoeba.mar', 'tatoeba.nld', 'tatoeba.pes', + 'tatoeba.por', 'tatoeba.rus', 'tatoeba.spa', 'tatoeba.swh', + 'tatoeba.tam', 'tatoeba.tel', 'tatoeba.tgl', 'tatoeba.tha', + 'tatoeba.tur', 'tatoeba.urd', 'tatoeba.vie', 'udpos.Afrikaans', + 'udpos.Arabic', 'udpos.Basque', 'udpos.Bulgarian', 'udpos.Dutch', + 'udpos.English', 'udpos.Estonian', 'udpos.Finnish', 'udpos.French', + 'udpos.German', 'udpos.Greek', 'udpos.Hebrew', 'udpos.Hindi', + 'udpos.Hungarian', 'udpos.Indonesian', 'udpos.Italian', + 'udpos.Japanese', 'udpos.Kazakh', 'udpos.Korean', 'udpos.Chinese', + 'udpos.Marathi', 'udpos.Persian', 'udpos.Portuguese', + 'udpos.Russian', 'udpos.Spanish', 'udpos.Tagalog', 'udpos.Tamil', + 'udpos.Telugu', 'udpos.Thai', 'udpos.Turkish', 'udpos.Urdu', + 'udpos.Vietnamese', 'udpos.Yoruba' + ] + subset = subset[:self.subset_count] + for subset_name in subset: + self.download_subset('xtreme', subset_name) diff --git a/tests/hub/test_hub_operation.py b/tests/hub/test_hub_operation.py index c96db986..828b97f8 100644 --- a/tests/hub/test_hub_operation.py +++ b/tests/hub/test_hub_operation.py @@ -25,10 +25,10 @@ class HubOperationTest(unittest.TestCase): def setUp(self): self.api = HubApi() - # note this is temporary before official account management is ready self.api.login(TEST_ACCESS_TOKEN1) - self.model_name = 
uuid.uuid4().hex + self.model_name = 'op-%s' % (uuid.uuid4().hex) self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.revision = 'v0.1_test_revision' self.api.create_model( model_id=self.model_id, visibility=ModelVisibility.PUBLIC, @@ -46,6 +46,7 @@ class HubOperationTest(unittest.TestCase): os.system("echo 'testtest'>%s" % os.path.join(self.model_dir, download_model_file_name)) repo.push('add model') + repo.tag_and_push(self.revision, 'Test revision') def test_model_repo_creation(self): # change to proper model names before use @@ -61,7 +62,9 @@ class HubOperationTest(unittest.TestCase): def test_download_single_file(self): self.prepare_case() downloaded_file = model_file_download( - model_id=self.model_id, file_path=download_model_file_name) + model_id=self.model_id, + file_path=download_model_file_name, + revision=self.revision) assert os.path.exists(downloaded_file) mdtime1 = os.path.getmtime(downloaded_file) # download again @@ -78,17 +81,16 @@ class HubOperationTest(unittest.TestCase): assert os.path.exists(downloaded_file_path) mdtime1 = os.path.getmtime(downloaded_file_path) # download again - snapshot_path = snapshot_download(model_id=self.model_id) + snapshot_path = snapshot_download( + model_id=self.model_id, revision=self.revision) mdtime2 = os.path.getmtime(downloaded_file_path) assert mdtime1 == mdtime2 - model_file_download( - model_id=self.model_id, - file_path=download_model_file_name) # not add counter def test_download_public_without_login(self): self.prepare_case() rmtree(ModelScopeConfig.path_credential) - snapshot_path = snapshot_download(model_id=self.model_id) + snapshot_path = snapshot_download( + model_id=self.model_id, revision=self.revision) downloaded_file_path = os.path.join(snapshot_path, download_model_file_name) assert os.path.exists(downloaded_file_path) @@ -96,26 +98,38 @@ class HubOperationTest(unittest.TestCase): downloaded_file = model_file_download( model_id=self.model_id, file_path=download_model_file_name, + revision=self.revision, cache_dir=temporary_dir) assert os.path.exists(downloaded_file) self.api.login(TEST_ACCESS_TOKEN1) def test_snapshot_delete_download_cache_file(self): self.prepare_case() - snapshot_path = snapshot_download(model_id=self.model_id) + snapshot_path = snapshot_download( + model_id=self.model_id, revision=self.revision) downloaded_file_path = os.path.join(snapshot_path, download_model_file_name) assert os.path.exists(downloaded_file_path) os.remove(downloaded_file_path) # download again in cache file_download_path = model_file_download( - model_id=self.model_id, file_path=ModelFile.README) + model_id=self.model_id, + file_path=ModelFile.README, + revision=self.revision) assert os.path.exists(file_download_path) # deleted file need download again file_download_path = model_file_download( - model_id=self.model_id, file_path=download_model_file_name) + model_id=self.model_id, + file_path=download_model_file_name, + revision=self.revision) assert os.path.exists(file_download_path) + def test_snapshot_download_default_revision(self): + pass # TODO + + def test_file_download_default_revision(self): + pass # TODO + def get_model_download_times(self): url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads' cookies = ModelScopeConfig.get_cookies() @@ -127,7 +141,7 @@ class HubOperationTest(unittest.TestCase): return None def test_list_model(self): - data = self.api.list_model(TEST_MODEL_ORG) + data = self.api.list_models(TEST_MODEL_ORG) assert len(data['Models']) >= 1 diff --git 
a/tests/hub/test_hub_private_files.py b/tests/hub/test_hub_private_files.py index d19a7c64..73c4cca3 100644 --- a/tests/hub/test_hub_private_files.py +++ b/tests/hub/test_hub_private_files.py @@ -17,23 +17,34 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, delete_credential) +download_model_file_name = 'test.bin' + class HubPrivateFileDownloadTest(unittest.TestCase): def setUp(self): self.old_cwd = os.getcwd() self.api = HubApi() - # note this is temporary before official account management is ready self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) - self.model_name = uuid.uuid4().hex + self.model_name = 'pf-%s' % (uuid.uuid4().hex) self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.revision = 'v0.1_test_revision' self.api.create_model( model_id=self.model_id, - visibility=ModelVisibility.PRIVATE, # 1-private, 5-public + visibility=ModelVisibility.PRIVATE, license=Licenses.APACHE_V2, chinese_name=TEST_MODEL_CHINESE_NAME, ) + def prepare_case(self): + temporary_dir = tempfile.mkdtemp() + self.model_dir = os.path.join(temporary_dir, self.model_name) + repo = Repository(self.model_dir, clone_from=self.model_id) + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name)) + repo.push('add model') + repo.tag_and_push(self.revision, 'Test revision') + def tearDown(self): # credential may deleted or switch login name, we need re-login here # to ensure the temporary model is deleted. @@ -42,49 +53,67 @@ class HubPrivateFileDownloadTest(unittest.TestCase): self.api.delete_model(model_id=self.model_id) def test_snapshot_download_private_model(self): - snapshot_path = snapshot_download(self.model_id) + self.prepare_case() + snapshot_path = snapshot_download(self.model_id, self.revision) assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) def test_snapshot_download_private_model_no_permission(self): + self.prepare_case() self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) with self.assertRaises(HTTPError): - snapshot_download(self.model_id) + snapshot_download(self.model_id, self.revision) def test_snapshot_download_private_model_without_login(self): + self.prepare_case() delete_credential() with self.assertRaises(HTTPError): - snapshot_download(self.model_id) + snapshot_download(self.model_id, self.revision) def test_download_file_private_model(self): - file_path = model_file_download(self.model_id, ModelFile.README) + self.prepare_case() + file_path = model_file_download(self.model_id, ModelFile.README, + self.revision) assert os.path.exists(file_path) def test_download_file_private_model_no_permission(self): + self.prepare_case() self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) with self.assertRaises(HTTPError): - model_file_download(self.model_id, ModelFile.README) + model_file_download(self.model_id, ModelFile.README, self.revision) def test_download_file_private_model_without_login(self): + self.prepare_case() delete_credential() with self.assertRaises(HTTPError): - model_file_download(self.model_id, ModelFile.README) + model_file_download(self.model_id, ModelFile.README, self.revision) def test_snapshot_download_local_only(self): + self.prepare_case() with self.assertRaises(ValueError): - snapshot_download(self.model_id, local_files_only=True) - snapshot_path = snapshot_download(self.model_id) + snapshot_download( + self.model_id, self.revision, local_files_only=True) + snapshot_path = snapshot_download(self.model_id, self.revision) assert 
os.path.exists(os.path.join(snapshot_path, ModelFile.README)) - snapshot_path = snapshot_download(self.model_id, local_files_only=True) + snapshot_path = snapshot_download( + self.model_id, self.revision, local_files_only=True) assert os.path.exists(snapshot_path) def test_file_download_local_only(self): + self.prepare_case() with self.assertRaises(ValueError): model_file_download( - self.model_id, ModelFile.README, local_files_only=True) - file_path = model_file_download(self.model_id, ModelFile.README) + self.model_id, + ModelFile.README, + self.revision, + local_files_only=True) + file_path = model_file_download(self.model_id, ModelFile.README, + self.revision) assert os.path.exists(file_path) file_path = model_file_download( - self.model_id, ModelFile.README, local_files_only=True) + self.model_id, + ModelFile.README, + revision=self.revision, + local_files_only=True) assert os.path.exists(file_path) diff --git a/tests/hub/test_hub_private_repository.py b/tests/hub/test_hub_private_repository.py index dab2b891..271a715c 100644 --- a/tests/hub/test_hub_private_repository.py +++ b/tests/hub/test_hub_private_repository.py @@ -21,13 +21,12 @@ class HubPrivateRepositoryTest(unittest.TestCase): def setUp(self): self.old_cwd = os.getcwd() self.api = HubApi() - # note this is temporary before official account management is ready self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) - self.model_name = uuid.uuid4().hex + self.model_name = 'pr-%s' % (uuid.uuid4().hex) self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) self.api.create_model( model_id=self.model_id, - visibility=ModelVisibility.PRIVATE, # 1-private, 5-public + visibility=ModelVisibility.PRIVATE, license=Licenses.APACHE_V2, chinese_name=TEST_MODEL_CHINESE_NAME, ) diff --git a/tests/hub/test_hub_repository.py b/tests/hub/test_hub_repository.py index 9dfe8efd..850d5840 100644 --- a/tests/hub/test_hub_repository.py +++ b/tests/hub/test_hub_repository.py @@ -22,6 +22,7 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, logger = get_logger() logger.setLevel('DEBUG') DEFAULT_GIT_PATH = 'git' +download_model_file_name = 'test.bin' class HubRepositoryTest(unittest.TestCase): @@ -29,13 +30,13 @@ class HubRepositoryTest(unittest.TestCase): def setUp(self): self.old_cwd = os.getcwd() self.api = HubApi() - # note this is temporary before official account management is ready self.api.login(TEST_ACCESS_TOKEN1) - self.model_name = uuid.uuid4().hex + self.model_name = 'repo-%s' % (uuid.uuid4().hex) self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.revision = 'v0.1_test_revision' self.api.create_model( model_id=self.model_id, - visibility=ModelVisibility.PUBLIC, # 1-private, 5-public + visibility=ModelVisibility.PUBLIC, license=Licenses.APACHE_V2, chinese_name=TEST_MODEL_CHINESE_NAME, ) @@ -67,9 +68,10 @@ class HubRepositoryTest(unittest.TestCase): os.system("echo 'lfs'>%s" % os.path.join(self.model_dir, lfs_file1)) os.system("echo 'lfs2'>%s" % os.path.join(self.model_dir, lfs_file2)) repo.push('test') - add1 = model_file_download(self.model_id, 'add1.py') + repo.tag_and_push(self.revision, 'Test revision') + add1 = model_file_download(self.model_id, 'add1.py', self.revision) assert os.path.exists(add1) - add2 = model_file_download(self.model_id, 'add2.py') + add2 = model_file_download(self.model_id, 'add2.py', self.revision) assert os.path.exists(add2) # check lfs files. 
git_wrapper = GitCommandWrapper() diff --git a/tests/hub/test_hub_revision.py b/tests/hub/test_hub_revision.py new file mode 100644 index 00000000..13ec1c9a --- /dev/null +++ b/tests/hub/test_hub_revision.py @@ -0,0 +1,145 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import tempfile +import unittest +import uuid +from datetime import datetime + +from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.errors import NotExistError, NoValidRevisionError +from modelscope.hub.file_download import model_file_download +from modelscope.hub.repository import Repository +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, + TEST_MODEL_ORG) + +logger = get_logger() +logger.setLevel('DEBUG') +download_model_file_name = 'test.bin' +download_model_file_name2 = 'test2.bin' + + +class HubRevisionTest(unittest.TestCase): + + def setUp(self): + self.api = HubApi() + self.api.login(TEST_ACCESS_TOKEN1) + self.model_name = 'rv-%s' % (uuid.uuid4().hex) + self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.revision = 'v0.1_test_revision' + self.revision2 = 'v0.2_test_revision' + self.api.create_model( + model_id=self.model_id, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2, + chinese_name=TEST_MODEL_CHINESE_NAME, + ) + + def tearDown(self): + self.api.delete_model(model_id=self.model_id) + + def prepare_repo_data(self): + temporary_dir = tempfile.mkdtemp() + self.model_dir = os.path.join(temporary_dir, self.model_name) + self.repo = Repository(self.model_dir, clone_from=self.model_id) + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name)) + self.repo.push('add model') + self.repo.tag_and_push(self.revision, 'Test revision') + + def test_no_tag(self): + with self.assertRaises(NoValidRevisionError): + snapshot_download(self.model_id, None) + + with self.assertRaises(NoValidRevisionError): + model_file_download(self.model_id, ModelFile.README) + + def test_with_only_one_tag(self): + self.prepare_repo_data() + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, ModelFile.README, cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + + def add_new_file_and_tag(self): + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name2)) + self.repo.push('add new file') + self.repo.tag_and_push(self.revision2, 'Test revision') + + def test_snapshot_download_different_revision(self): + self.prepare_repo_data() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time stamp: %s' % t1) + snapshot_path = snapshot_download(self.model_id, self.revision) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + self.add_new_file_and_tag() + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, + revision=self.revision, + cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert not os.path.exists( + os.path.join(snapshot_path, 
download_model_file_name2)) + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, + revision=self.revision2, + cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + + def test_file_download_different_revision(self): + self.prepare_repo_data() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time stamp: %s' % t1) + file_path = model_file_download(self.model_id, + download_model_file_name, + self.revision) + assert os.path.exists(file_path) + self.add_new_file_and_tag() + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + revision=self.revision, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + with self.assertRaises(NotExistError): + model_file_download( + self.model_id, + download_model_file_name2, + revision=self.revision, + cache_dir=temp_cache_dir) + + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + revision=self.revision2, + cache_dir=temp_cache_dir) + print('Downloaded file path: %s' % file_path) + assert os.path.exists(file_path) + file_path = model_file_download( + self.model_id, + download_model_file_name2, + revision=self.revision2, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_revision_release_mode.py b/tests/hub/test_hub_revision_release_mode.py new file mode 100644 index 00000000..729a1861 --- /dev/null +++ b/tests/hub/test_hub_revision_release_mode.py @@ -0,0 +1,190 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import tempfile +import time +import unittest +import uuid +from datetime import datetime +from unittest import mock + +from modelscope import version +from modelscope.hub.api import HubApi +from modelscope.hub.constants import (MODELSCOPE_SDK_DEBUG, Licenses, + ModelVisibility) +from modelscope.hub.errors import NotExistError +from modelscope.hub.file_download import model_file_download +from modelscope.hub.repository import Repository +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.utils.logger import get_logger +from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, + TEST_MODEL_ORG) + +logger = get_logger() +logger.setLevel('DEBUG') +download_model_file_name = 'test.bin' +download_model_file_name2 = 'test2.bin' + + +class HubRevisionTest(unittest.TestCase): + + def setUp(self): + self.api = HubApi() + self.api.login(TEST_ACCESS_TOKEN1) + self.model_name = 'rvr-%s' % (uuid.uuid4().hex) + self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.revision = 'v0.1_test_revision' + self.revision2 = 'v0.2_test_revision' + self.api.create_model( + model_id=self.model_id, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2, + chinese_name=TEST_MODEL_CHINESE_NAME, + ) + names_to_remove = {MODELSCOPE_SDK_DEBUG} + self.modified_environ = { + k: v + for k, v in os.environ.items() if k not in names_to_remove + } + + def tearDown(self): + self.api.delete_model(model_id=self.model_id) + + def prepare_repo_data(self): + temporary_dir = tempfile.mkdtemp() + self.model_dir = os.path.join(temporary_dir, self.model_name) + self.repo = Repository(self.model_dir, clone_from=self.model_id) + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name)) + self.repo.push('add model') + + def prepare_repo_data_and_tag(self): + self.prepare_repo_data() + self.repo.tag_and_push(self.revision, 'Test revision') + + def add_new_file_and_tag_to_repo(self): + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name2)) + self.repo.push('add new file') + self.repo.tag_and_push(self.revision2, 'Test revision') + + def add_new_file_and_branch_to_repo(self, branch_name): + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name2)) + self.repo.push('add new file', remote_branch=branch_name) + + def test_dev_mode_default_master(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data() # no tag, default get master + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + + def test_dev_mode_specify_branch(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data() # no tag, default get master + branch_name = 'test' + self.add_new_file_and_branch_to_repo(branch_name) + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, + revision=branch_name, + cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + 
self.model_id, + download_model_file_name, + revision=branch_name, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + + def test_snapshot_download_revision(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data_and_tag() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time: %s' % t1) + time.sleep(10) + self.add_new_file_and_tag_to_repo() + t2 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('Second time: %s' % t2) + # back up __release_datetime__ so it can be restored after the test + release_datetime_backup = version.__release_datetime__ + logger.info('Origin __release_datetime__: %s' + % version.__release_datetime__) + try: + logger.info('Setting __release_datetime__ to: %s' % t1) + version.__release_datetime__ = t1 + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert not os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + version.__release_datetime__ = t2 + logger.info('Setting __release_datetime__ to: %s' % t2) + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + finally: + version.__release_datetime__ = release_datetime_backup + + def test_file_download_revision(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data_and_tag() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time stamp: %s' % t1) + time.sleep(10) + self.add_new_file_and_tag_to_repo() + t2 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('Second time: %s' % t2) + release_datetime_backup = version.__release_datetime__ + logger.info('Origin __release_datetime__: %s' + % version.__release_datetime__) + try: + version.__release_datetime__ = t1 + logger.info('Setting __release_datetime__ to: %s' % t1) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + with self.assertRaises(NotExistError): + model_file_download( + self.model_id, + download_model_file_name2, + cache_dir=temp_cache_dir) + version.__release_datetime__ = t2 + logger.info('Setting __release_datetime__ to: %s' % t2) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + cache_dir=temp_cache_dir) + print('Downloaded file path: %s' % file_path) + assert os.path.exists(file_path) + file_path = model_file_download( + self.model_id, + download_model_file_name2, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + finally: + version.__release_datetime__ = release_datetime_backup + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_upload.py b/tests/hub/test_hub_upload.py new file mode 100644 index 00000000..e1f61467 --- /dev/null +++ b/tests/hub/test_hub_upload.py @@ -0,0 +1,145 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import shutil +import tempfile +import unittest +import uuid + +from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.errors import HTTPError, NotLoginException +from modelscope.hub.repository import Repository +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level +from .test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG, delete_credential + +logger = get_logger() + + +class HubUploadTest(unittest.TestCase): + + def setUp(self): + logger.info('SetUp') + self.api = HubApi() + self.user = TEST_MODEL_ORG + logger.info(self.user) + self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload', + uuid.uuid4().hex) + logger.info('create %s' % self.create_model_name) + temporary_dir = tempfile.mkdtemp() + self.work_dir = temporary_dir + self.model_dir = os.path.join(temporary_dir, self.create_model_name) + self.finetune_path = os.path.join(self.work_dir, 'finetune_path') + self.repo_path = os.path.join(self.work_dir, 'repo_path') + os.mkdir(self.finetune_path) + os.system("echo '{}'>%s" + % os.path.join(self.finetune_path, ModelFile.CONFIGURATION)) + + def tearDown(self): + logger.info('TearDown') + shutil.rmtree(self.model_dir, ignore_errors=True) + try: + self.api.delete_model(model_id=self.create_model_name) + except Exception: + pass + + def test_upload_exists_repo_master(self): + logger.info('basic test for upload!') + self.api.login(TEST_ACCESS_TOKEN1) + self.api.create_model( + model_id=self.create_model_name, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2) + os.system("echo '111'>%s" + % os.path.join(self.finetune_path, 'add1.py')) + self.api.push_model( + model_id=self.create_model_name, model_dir=self.finetune_path) + Repository(model_dir=self.repo_path, clone_from=self.create_model_name) + assert os.path.exists(os.path.join(self.repo_path, 'add1.py')) + shutil.rmtree(self.repo_path, ignore_errors=True) + os.system("echo '222'>%s" + % os.path.join(self.finetune_path, 'add2.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_revision/version1') + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_revision/version1') + assert os.path.exists(os.path.join(self.repo_path, 'add2.py')) + shutil.rmtree(self.repo_path, ignore_errors=True) + os.system("echo '333'>%s" + % os.path.join(self.finetune_path, 'add3.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_revision/version2', + commit_message='add add3.py') + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_revision/version2') + assert os.path.exists(os.path.join(self.repo_path, 'add2.py')) + assert os.path.exists(os.path.join(self.repo_path, 'add3.py')) + shutil.rmtree(self.repo_path, ignore_errors=True) + add4_path = os.path.join(self.finetune_path, 'temp') + os.mkdir(add4_path) + os.system("echo '444'>%s" % os.path.join(add4_path, 'add4.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_revision/version1') + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_revision/version1') + assert os.path.exists(os.path.join(self.repo_path, 'temp', 'add4.py')) + shutil.rmtree(self.repo_path, ignore_errors=True) + + @unittest.skipUnless(test_level() >= 0, 'skip 
test in current test level') + def test_upload_non_exists_repo(self): + logger.info('test upload non exists repo!') + self.api.login(TEST_ACCESS_TOKEN1) + os.system("echo '111'>%s" + % os.path.join(self.finetune_path, 'add1.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_model_new_revision', + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2) + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_model_new_revision') + assert os.path.exists(os.path.join(self.repo_path, 'add1.py')) + shutil.rmtree(self.repo_path, ignore_errors=True) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_upload_without_token(self): + logger.info('test upload without login!') + self.api.login(TEST_ACCESS_TOKEN1) + delete_credential() + with self.assertRaises(NotLoginException): + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_upload_invalid_repo(self): + logger.info('test upload to invalid repo!') + self.api.login(TEST_ACCESS_TOKEN1) + with self.assertRaises(HTTPError): + self.api.push_model( + model_id='%s/%s' % ('speech_tts', 'invalid_model_test'), + model_dir=self.finetune_path, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/models/test_deberta_v2_backbone.py b/tests/models/test_deberta_v2_backbone.py new file mode 100644 index 00000000..706b18f8 --- /dev/null +++ b/tests/models/test_deberta_v2_backbone.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import unittest + +from modelscope.models import Model +from modelscope.models.nlp.deberta_v2 import (DebertaV2ForMaskedLM, + DebertaV2Model) +from modelscope.utils.constant import Tasks + + +class DebertaV2BackboneTest(unittest.TestCase): + + def test_load_model(self): + model = Model.from_pretrained( + 'damo/nlp_debertav2_fill-mask_chinese-lite') + self.assertTrue(model.__class__ == DebertaV2ForMaskedLM) + model = Model.from_pretrained( + 'damo/nlp_debertav2_fill-mask_chinese-lite', task=Tasks.backbone) + self.assertTrue(model.__class__ == DebertaV2Model) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/msdatasets/test_dataset_upload.py b/tests/msdatasets/test_dataset_upload.py index 1179414d..3d35d480 100644 --- a/tests/msdatasets/test_dataset_upload.py +++ b/tests/msdatasets/test_dataset_upload.py @@ -6,9 +6,13 @@ import unittest import zipfile from modelscope.msdatasets import MsDataset -from modelscope.utils.constant import ModelFile +from modelscope.msdatasets.utils.dataset_utils import list_dataset_objects +from modelscope.utils import logger as logging +from modelscope.utils.constant import DEFAULT_DATASET_REVISION, ModelFile from modelscope.utils.test_utils import test_level +logger = logging.get_logger(__name__) + KEY_EXTRACTED = 'extracted' @@ -39,7 +43,8 @@ class DatasetUploadTest(unittest.TestCase): def tearDown(self): os.chdir(self.old_dir) shutil.rmtree(self.temp_dir, ignore_errors=True) - print('The test dir successfully removed!') + logger.info( + f'Temporary directory {self.temp_dir} successfully removed!') @staticmethod def get_raw_downloaded_file_path(extracted_path): @@ -68,6 +73,40 @@ class DatasetUploadTest(unittest.TestCase): dataset_name=self.dataset_name, namespace=self.namespace) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ds_upload_dir(self): + ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') + config_train = ms_ds_train._hf_ds.config_kwargs + extracted_path_train = config_train.get('split_config').get('train') + + MsDataset.upload( + object_name='train', + local_file_path=os.path.join(extracted_path_train, + 'Pets/images/train'), + dataset_name=self.dataset_name, + namespace=self.namespace) + MsDataset.upload( + object_name='val', + local_file_path=os.path.join(extracted_path_train, + 'Pets/images/val'), + dataset_name=self.dataset_name, + namespace=self.namespace) + + objects = list_dataset_objects( + hub_api=self.api, + max_limit=-1, + is_recursive=True, + dataset_name=self.dataset_name, + namespace=self.namespace, + version=DEFAULT_DATASET_REVISION) + + logger.info(f'{len(objects)} objects have been uploaded: {objects}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ds_download_dir(self): + test_ds = MsDataset.load(self.dataset_name, self.namespace) + assert test_ds.config_kwargs['split_config'].values() + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_ds_clone_meta(self): MsDataset.clone_meta( diff --git a/tests/msdatasets/test_ms_dataset.py b/tests/msdatasets/test_ms_dataset.py index 91a3b5c5..dff411f6 100644 --- a/tests/msdatasets/test_ms_dataset.py +++ b/tests/msdatasets/test_ms_dataset.py @@ -52,7 +52,8 @@ class MsDatasetTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_ms_csv_basic(self): ms_ds_train = MsDataset.load( - 'afqmc_small', namespace='userxiaoming', split='train') + 'clue', subset_name='afqmc', + 
split='train').to_hf_dataset().select(range(5)) print(next(iter(ms_ds_train))) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @@ -70,7 +71,7 @@ class MsDatasetTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @require_torch def test_to_torch_dataset_text(self): - model_id = 'damo/bert-base-sst2' + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' nlp_model = Model.from_pretrained(model_id) preprocessor = SequenceClassificationPreprocessor( nlp_model.model_dir, @@ -92,7 +93,7 @@ class MsDatasetTest(unittest.TestCase): def test_to_tf_dataset_text(self): import tensorflow as tf tf.compat.v1.enable_eager_execution() - model_id = 'damo/bert-base-sst2' + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' nlp_model = Model.from_pretrained(model_id) preprocessor = SequenceClassificationPreprocessor( nlp_model.model_dir, diff --git a/tests/outputs/__init__.py b/tests/outputs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/outputs/test_model_outputs.py b/tests/outputs/test_model_outputs.py new file mode 100644 index 00000000..31271869 --- /dev/null +++ b/tests/outputs/test_model_outputs.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import torch + +from modelscope.outputs import TextClassificationModelOutput +from modelscope.utils.test_utils import test_level + + +class TestModelOutput(unittest.TestCase): + + def setUp(self): + pass + + def tearDown(self): + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_model_outputs(self): + outputs = TextClassificationModelOutput(logits=torch.Tensor([1])) + self.assertEqual(outputs['logits'], torch.Tensor([1])) + self.assertEqual(outputs[0], torch.Tensor([1])) + self.assertEqual(outputs.logits, torch.Tensor([1])) + logits, loss = outputs + self.assertEqual(logits, torch.Tensor([1])) + self.assertTrue(loss is None) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/nlp/test_faq.py b/tests/pipelines/nlp/test_faq.py new file mode 100644 index 00000000..8bac55d4 --- /dev/null +++ b/tests/pipelines/nlp/test_faq.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SbertForFaqRanking, SbertForFaqRetrieval +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import FaqPipeline +from modelscope.preprocessors import FaqPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class FaqTest(unittest.TestCase): + model_id = '/Users/tanfan/Desktop/Workdir/Gitlab/maas/MaaS-lib/.faq_test_model' + param = { + 'query_set': ['明天星期几', '今天星期六', '今天星期六'], + 'support_set': [{ + 'text': '今天星期六', + 'label': 'label0' + }, { + 'text': '明天星期几', + 'label': 'label1' + }] + } + + # @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + # def test_run_with_direct_file_download(self): + # cache_path = self.model_id # snapshot_download(self.model_id) + # preprocessor = FaqPreprocessor(cache_path) + # model = SbertForFaq(cache_path) + # pipeline_ins = FaqPipeline(model, preprocessor=preprocessor) + # + # result = pipeline_ins(self.param) + # print(result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = FaqPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.faq, model=model, preprocessor=preprocessor) + result = pipeline_ins(self.param) + print(result) + + # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + # def test_run_with_model_name(self): + # pipeline_ins = pipeline(task=Tasks.faq, model=self.model_id) + # result = pipeline_ins(self.param) + # print(result) + + # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + # def test_run_with_default_model(self): + # pipeline_ins = pipeline(task=Tasks.faq) + # print(pipeline_ins(self.param)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_automatic_post_editing.py b/tests/pipelines/test_automatic_post_editing.py new file mode 100644 index 00000000..da09851c --- /dev/null +++ b/tests/pipelines/test_automatic_post_editing.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class AutomaticPostEditingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.translation + self.model_id = 'damo/nlp_automatic_post_editing_for_translation_en2de' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_en2de(self): + inputs = 'Simultaneously, the Legion took part to the pacification of Algeria, plagued by various tribal ' \ + 'rebellions and razzias.\005Gleichzeitig nahm die Legion an der Befriedung Algeriens teil, die von ' \ + 'verschiedenen Stammesaufständen und Rasias heimgesucht wurde.' 
+ pipeline_ins = pipeline(self.task, model=self.model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_automatic_speech_recognition.py b/tests/pipelines/test_automatic_speech_recognition.py index 303fb6b9..b6532868 100644 --- a/tests/pipelines/test_automatic_speech_recognition.py +++ b/tests/pipelines/test_automatic_speech_recognition.py @@ -45,6 +45,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, 'checking_item': OutputKeys.TEXT, 'example': 'wav_example' }, + 'test_run_with_url_pytorch': { + 'checking_item': OutputKeys.TEXT, + 'example': 'wav_example' + }, 'test_run_with_url_tf': { 'checking_item': OutputKeys.TEXT, 'example': 'wav_example' @@ -74,6 +78,147 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, } } + all_models_info = [ + { + 'model_id': + 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': 'damo/speech_paraformer_asr_nat-aishell1-pytorch', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': 'damo/speech_paraformer_asr_nat-aishell2-pytorch', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1', + 'wav_path': 'data/test/audios/asr_example_8K.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_8K.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_8K.wav' + }, + { + 'model_id': + 'damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_cn_en.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_cn_en.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_cn_dialect.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_cn_dialect.wav' + }, + { + 'model_id': + 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_8K.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_en.wav' + }, + { 
+ 'model_id': + 'damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_en.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_ru.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_ru.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_es.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_es.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_ko.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_ko.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_ja.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_ja.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_id.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_id.wav' + }, + ] + def setUp(self) -> None: self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1' @@ -90,7 +235,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, def run_pipeline(self, model_id: str, audio_in: Union[str, bytes], - sr: int = 16000) -> Dict[str, Any]: + sr: int = None) -> Dict[str, Any]: inference_16k_pipline = pipeline( task=Tasks.auto_speech_recognition, model=model_id) @@ -136,33 +281,26 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, return audio, fs @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_with_wav_pytorch(self): - """run with single waveform file + def test_run_with_pcm(self): + """run with wav data """ - logger.info('Run ASR test with waveform file (pytorch)...') + logger.info('Run ASR test with wav data (tensorflow)...') - wav_file_path = os.path.join(os.getcwd(), WAV_FILE) + audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) rec_result = self.run_pipeline( - model_id=self.am_pytorch_model_id, audio_in=wav_file_path) - self.check_result('test_run_with_wav_pytorch', rec_result) - - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_with_pcm_pytorch(self): - """run with wav data - """ + model_id=self.am_tf_model_id, audio_in=audio, sr=sr) + self.check_result('test_run_with_pcm_tf', rec_result) logger.info('Run ASR test with wav data (pytorch)...') - audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) - rec_result = self.run_pipeline( model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr) self.check_result('test_run_with_pcm_pytorch', rec_result) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_with_wav_tf(self): + def test_run_with_wav(self): """run with 
single waveform file """ @@ -174,21 +312,14 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, model_id=self.am_tf_model_id, audio_in=wav_file_path) self.check_result('test_run_with_wav_tf', rec_result) - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_with_pcm_tf(self): - """run with wav data - """ - - logger.info('Run ASR test with wav data (tensorflow)...') - - audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) + logger.info('Run ASR test with waveform file (pytorch)...') rec_result = self.run_pipeline( - model_id=self.am_tf_model_id, audio_in=audio, sr=sr) - self.check_result('test_run_with_pcm_tf', rec_result) + model_id=self.am_pytorch_model_id, audio_in=wav_file_path) + self.check_result('test_run_with_wav_pytorch', rec_result) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_with_url_tf(self): + def test_run_with_url(self): """run with single url file """ @@ -198,6 +329,12 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, model_id=self.am_tf_model_id, audio_in=URL_FILE) self.check_result('test_run_with_url_tf', rec_result) + logger.info('Run ASR test with url file (pytorch)...') + + rec_result = self.run_pipeline( + model_id=self.am_pytorch_model_id, audio_in=URL_FILE) + self.check_result('test_run_with_url_pytorch', rec_result) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_wav_dataset_pytorch(self): """run with datasets, and audio format is waveform @@ -217,7 +354,6 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, data.text # hypothesis text """ - logger.info('Run ASR test with waveform dataset (pytorch)...') logger.info('Downloading waveform testsets file ...') dataset_path = download_and_untar( @@ -225,40 +361,38 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, LITTLE_TESTSETS_URL, self.workspace) dataset_path = os.path.join(dataset_path, 'wav', 'test') + logger.info('Run ASR test with waveform dataset (tensorflow)...') + + rec_result = self.run_pipeline( + model_id=self.am_tf_model_id, audio_in=dataset_path) + self.check_result('test_run_with_wav_dataset_tf', rec_result) + + logger.info('Run ASR test with waveform dataset (pytorch)...') + rec_result = self.run_pipeline( model_id=self.am_pytorch_model_id, audio_in=dataset_path) self.check_result('test_run_with_wav_dataset_pytorch', rec_result) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') - def test_run_with_wav_dataset_tf(self): - """run with datasets, and audio format is waveform - datasets directory: - - wav - test # testsets - xx.wav - ... - dev # devsets - yy.wav - ... - train # trainsets - zz.wav - ... 
- transcript - data.text # hypothesis text + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_all_models(self): + """run with all models """ - logger.info('Run ASR test with waveform dataset (tensorflow)...') - logger.info('Downloading waveform testsets file ...') - - dataset_path = download_and_untar( - os.path.join(self.workspace, LITTLE_TESTSETS_FILE), - LITTLE_TESTSETS_URL, self.workspace) - dataset_path = os.path.join(dataset_path, 'wav', 'test') - - rec_result = self.run_pipeline( - model_id=self.am_tf_model_id, audio_in=dataset_path) - self.check_result('test_run_with_wav_dataset_tf', rec_result) + logger.info('Run ASR test with all models') + + for item in self.all_models_info: + model_id = item['model_id'] + wav_path = item['wav_path'] + rec_result = self.run_pipeline( + model_id=model_id, audio_in=wav_path) + if rec_result.__contains__(OutputKeys.TEXT): + logger.info(ColorCodes.MAGENTA + str(item['model_id']) + ' ' + + ColorCodes.YELLOW + + str(rec_result[OutputKeys.TEXT]) + + ColorCodes.END) + else: + logger.info(ColorCodes.MAGENTA + str(rec_result) + + ColorCodes.END) @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): diff --git a/tests/pipelines/test_body_3d_keypoints.py b/tests/pipelines/test_body_3d_keypoints.py index bde04f8e..6e671d2e 100644 --- a/tests/pipelines/test_body_3d_keypoints.py +++ b/tests/pipelines/test_body_3d_keypoints.py @@ -20,18 +20,15 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck): self.task = Tasks.body_3d_keypoints def pipeline_inference(self, pipeline: Pipeline, pipeline_input): - output = pipeline(pipeline_input) - poses = np.array(output[OutputKeys.POSES]) + output = pipeline(pipeline_input, output_video='./result.mp4') + poses = np.array(output[OutputKeys.KEYPOINTS]) print(f'result 3d points shape {poses.shape}') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_modelhub_with_video_file(self): body_3d_keypoints = pipeline( Tasks.body_3d_keypoints, model=self.model_id) - pipeline_input = { - 'input_video': self.test_video, - 'output_video_path': './result.mp4' - } + pipeline_input = self.test_video self.pipeline_inference( body_3d_keypoints, pipeline_input=pipeline_input) @@ -42,10 +39,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck): if not cap.isOpened(): raise Exception('modelscope error: %s cannot be decoded by OpenCV.' % (self.test_video)) - pipeline_input = { - 'input_video': cap, - 'output_video_path': './result.mp4' - } + pipeline_input = self.test_video self.pipeline_inference( body_3d_keypoints, pipeline_input=pipeline_input) diff --git a/tests/pipelines/test_card_detection.py b/tests/pipelines/test_card_detection.py new file mode 100644 index 00000000..d913f494 --- /dev/null +++ b/tests/pipelines/test_card_detection.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os.path as osp +import unittest + +import cv2 + +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_card_detection_result +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class CardDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.card_detection + self.model_id = 'damo/cv_resnet_carddetection_scrfd34gkps' + + def show_result(self, img_path, detection_result): + img_list = draw_card_detection_result(img_path, detection_result) + for i, img in enumerate(img_list): + if i == 0: + cv2.imwrite('result.jpg', img_list[0]) + print( + f'Found {len(img_list)-1} cards, output written to {osp.abspath("result.jpg")}' + ) + else: + cv2.imwrite(f'card_{i}.jpg', img_list[i]) + save_path = osp.abspath(f'card_{i}.jpg') + print(f'detect card_{i}: {save_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_dataset(self): + input_location = ['data/test/images/card_detection.jpg'] + + dataset = MsDataset.load(input_location, target='image') + card_detection = pipeline(Tasks.card_detection, model=self.model_id) + # note that for dataset output, the inference-output is a Generator that can be iterated. + result = card_detection(dataset) + result = next(result) + self.show_result(input_location[0], result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + card_detection = pipeline(Tasks.card_detection, model=self.model_id) + img_path = 'data/test/images/card_detection.jpg' + + result = card_detection(img_path) + self.show_result(img_path, result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + card_detection = pipeline(Tasks.card_detection) + img_path = 'data/test/images/card_detection.jpg' + result = card_detection(img_path) + self.show_result(img_path, result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_conversational_text_to_sql.py b/tests/pipelines/test_conversational_text_to_sql.py index 80c72337..17fffcaf 100644 --- a/tests/pipelines/test_conversational_text_to_sql.py +++ b/tests/pipelines/test_conversational_text_to_sql.py @@ -9,14 +9,15 @@ from modelscope.pipelines.nlp import ConversationalTextToSqlPipeline from modelscope.preprocessors import ConversationalTextToSqlPreprocessor from modelscope.utils.constant import Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck -from modelscope.utils.nlp.nlp_utils import text2sql_tracking_and_print_results +from modelscope.utils.nlp.space_T_en.utils import \ + text2sql_tracking_and_print_results from modelscope.utils.test_utils import test_level class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: - self.task = Tasks.conversational_text_to_sql + self.task = Tasks.table_question_answering self.model_id = 'damo/nlp_star_conversational-text-to-sql' model_id = 'damo/nlp_star_conversational-text-to-sql' @@ -66,11 +67,6 @@ class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): pipelines = [pipeline(task=self.task, model=self.model_id)] 
text2sql_tracking_and_print_results(self.test_case, pipelines) - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_run_with_default_model(self): - pipelines = [pipeline(task=self.task)] - text2sql_tracking_and_print_results(self.test_case, pipelines) - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_demo_compatibility(self): self.compatibility_check() diff --git a/tests/pipelines/test_csanmt_translation.py b/tests/pipelines/test_csanmt_translation.py index f7ec81cd..83827813 100644 --- a/tests/pipelines/test_csanmt_translation.py +++ b/tests/pipelines/test_csanmt_translation.py @@ -26,6 +26,20 @@ class TranslationTest(unittest.TestCase, DemoCompatibilityCheck): pipeline_ins = pipeline(self.task, model=model_id) print(pipeline_ins(input=inputs)) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_en2fr(self): + model_id = 'damo/nlp_csanmt_translation_en2fr' + inputs = 'When I was in my 20s, I saw my very first psychotherapy client.' + pipeline_ins = pipeline(self.task, model=model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_fr2en(self): + model_id = 'damo/nlp_csanmt_translation_fr2en' + inputs = "Quand j'avais la vingtaine, j'ai vu mes tout premiers clients comme psychothérapeute." + pipeline_ins = pipeline(self.task, model=model_id) + print(pipeline_ins(input=inputs)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。' diff --git a/tests/pipelines/test_dialog_intent_prediction.py b/tests/pipelines/test_dialog_intent_prediction.py index 5894297f..2ee46388 100644 --- a/tests/pipelines/test_dialog_intent_prediction.py +++ b/tests/pipelines/test_dialog_intent_prediction.py @@ -25,7 +25,7 @@ class DialogIntentPredictionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_by_direct_model_download(self): - cache_path = snapshot_download(self.model_id, revision='update') + cache_path = snapshot_download(self.model_id) preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) model = SpaceForDialogIntent( model_dir=cache_path, @@ -46,7 +46,7 @@ class DialogIntentPredictionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_from_modelhub(self): - model = Model.from_pretrained(self.model_id, revision='update') + model = Model.from_pretrained(self.model_id) preprocessor = DialogIntentPredictionPreprocessor( model_dir=model.model_dir) @@ -64,10 +64,7 @@ class DialogIntentPredictionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): - pipelines = [ - pipeline( - task=self.task, model=self.model_id, model_revision='update') - ] + pipelines = [pipeline(task=self.task, model=self.model_id)] for my_pipeline, item in list(zip(pipelines, self.test_case)): print(my_pipeline(item)) diff --git a/tests/pipelines/test_dialog_modeling.py b/tests/pipelines/test_dialog_modeling.py index 19d6ed2f..6b6259ce 100644 --- a/tests/pipelines/test_dialog_modeling.py +++ b/tests/pipelines/test_dialog_modeling.py @@ -115,8 +115,7 @@ class DialogModelingTest(unittest.TestCase, 
DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_by_direct_model_download(self): - cache_path = snapshot_download( - self.model_id, revision='task_oriented_conversation') + cache_path = snapshot_download(self.model_id) preprocessor = DialogModelingPreprocessor(model_dir=cache_path) model = SpaceForDialogModeling( @@ -130,8 +129,7 @@ class DialogModelingTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_from_modelhub(self): - model = Model.from_pretrained( - self.model_id, revision='task_oriented_conversation') + model = Model.from_pretrained(self.model_id) preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir) pipelines = [ @@ -142,20 +140,12 @@ class DialogModelingTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): - pipelines = [ - pipeline( - task=self.task, - model=self.model_id, - model_revision='task_oriented_conversation') - ] + pipelines = [pipeline(task=self.task, model=self.model_id)] self.generate_and_print_dialog_response(pipelines) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): - pipelines = [ - pipeline( - task=self.task, model_revision='task_oriented_conversation') - ] + pipelines = [pipeline(task=self.task)] self.generate_and_print_dialog_response(pipelines) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') diff --git a/tests/pipelines/test_dialog_state_tracking.py b/tests/pipelines/test_dialog_state_tracking.py index 81bdd9be..6cdd5ee7 100644 --- a/tests/pipelines/test_dialog_state_tracking.py +++ b/tests/pipelines/test_dialog_state_tracking.py @@ -3,13 +3,14 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model -from modelscope.models.nlp import SpaceForDialogStateTracking +from modelscope.models.nlp import SpaceForDST from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import DialogStateTrackingPipeline from modelscope.preprocessors import DialogStateTrackingPreprocessor from modelscope.utils.constant import Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck -from modelscope.utils.nlp.nlp_utils import tracking_and_print_dialog_states +from modelscope.utils.nlp.space.utils_dst import \ + tracking_and_print_dialog_states from modelscope.utils.test_utils import test_level @@ -85,9 +86,9 @@ class DialogStateTrackingTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_by_direct_model_download(self): - cache_path = snapshot_download(self.model_id, revision='update') + cache_path = snapshot_download(self.model_id) - model = SpaceForDialogStateTracking(cache_path) + model = SpaceForDST.from_pretrained(cache_path) preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) pipelines = [ DialogStateTrackingPipeline( @@ -101,7 +102,7 @@ class DialogStateTrackingTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_from_modelhub(self): - model = Model.from_pretrained(self.model_id, revision='update') + model = Model.from_pretrained(self.model_id) preprocessor = DialogStateTrackingPreprocessor( model_dir=model.model_dir) @@ -115,10 
+116,7 @@ class DialogStateTrackingTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): - pipelines = [ - pipeline( - task=self.task, model=self.model_id, model_revision='update') - ] + pipelines = [pipeline(task=self.task, model=self.model_id)] tracking_and_print_dialog_states(self.test_case, pipelines) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') diff --git a/tests/pipelines/test_domain_classification.py b/tests/pipelines/test_domain_classification.py new file mode 100644 index 00000000..8e5bfa7f --- /dev/null +++ b/tests/pipelines/test_domain_classification.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class DomainClassificationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.text_classification + self.model_id = 'damo/nlp_domain_classification_chinese' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_zh_domain(self): + inputs = '通过这种方式产生的离子吸收大地水分之后,可以通过潮解作用,将活性电解离子有效释放到周围土壤中,使接地极成为一个离子发生装置,' \ + '从而改善周边土质使之达到接地要求。' + pipeline_ins = pipeline(self.task, model=self.model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_zh_style(self): + model_id = 'damo/nlp_style_classification_chinese' + inputs = '通过这种方式产生的离子吸收大地水分之后,可以通过潮解作用,将活性电解离子有效释放到周围土壤中,使接地极成为一个离子发生装置,' \ + '从而改善周边土质使之达到接地要求。' + pipeline_ins = pipeline(self.task, model=model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_en_style(self): + model_id = 'damo/nlp_style_classification_english' + inputs = 'High Power 11.1V 5200mAh Lipo Battery For RC Car Robot Airplanes ' \ + 'Helicopter RC Drone Parts 3s Lithium battery 11.1v Battery' + pipeline_ins = pipeline(self.task, model=model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_face_2d_keypoints.py b/tests/pipelines/test_face_2d_keypoints.py index 667ecddc..a5e347e8 100644 --- a/tests/pipelines/test_face_2d_keypoints.py +++ b/tests/pipelines/test_face_2d_keypoints.py @@ -18,7 +18,7 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): face_2d_keypoints_align = pipeline( task=Tasks.face_2d_keypoints, model=model_id) - output = face_2d_keypoints_align(img_path)[0] + output = face_2d_keypoints_align(img_path) output_keypoints = output[OutputKeys.KEYPOINTS] output_pose = output[OutputKeys.POSES] diff --git a/tests/pipelines/test_face_detection.py b/tests/pipelines/test_face_detection.py index f89e9a94..db513a80 100644 --- a/tests/pipelines/test_face_detection.py +++ b/tests/pipelines/test_face_detection.py @@ -25,7 +25,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_dataset(self): - input_location = ['data/test/images/face_detection.png'] + 
input_location = ['data/test/images/face_detection2.jpeg'] dataset = MsDataset.load(input_location, target='image') face_detection = pipeline(Tasks.face_detection, model=self.model_id) @@ -37,7 +37,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_modelhub(self): face_detection = pipeline(Tasks.face_detection, model=self.model_id) - img_path = 'data/test/images/face_detection.png' + img_path = 'data/test/images/face_detection2.jpeg' result = face_detection(img_path) self.show_result(img_path, result) @@ -45,7 +45,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_modelhub_default_model(self): face_detection = pipeline(Tasks.face_detection) - img_path = 'data/test/images/face_detection.png' + img_path = 'data/test/images/face_detection2.jpeg' result = face_detection(img_path) self.show_result(img_path, result) diff --git a/tests/pipelines/test_faq_question_answering.py b/tests/pipelines/test_faq_question_answering.py index 7eea0ddf..2f66f516 100644 --- a/tests/pipelines/test_faq_question_answering.py +++ b/tests/pipelines/test_faq_question_answering.py @@ -47,9 +47,9 @@ class FaqQuestionAnsweringTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_direct_file_download(self): cache_path = snapshot_download(self.model_id) - preprocessor = FaqQuestionAnsweringPreprocessor(cache_path) - model = SbertForFaqQuestionAnswering(cache_path) - model.load_checkpoint(cache_path) + preprocessor = FaqQuestionAnsweringPreprocessor.from_pretrained( + cache_path) + model = SbertForFaqQuestionAnswering.from_pretrained(cache_path) pipeline_ins = FaqQuestionAnsweringPipeline( model, preprocessor=preprocessor) result = pipeline_ins(self.param) diff --git a/tests/pipelines/test_fill_mask.py b/tests/pipelines/test_fill_mask.py index 0e5e242b..35202b88 100644 --- a/tests/pipelines/test_fill_mask.py +++ b/tests/pipelines/test_fill_mask.py @@ -5,8 +5,7 @@ from regex import R from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model -from modelscope.models.nlp import (BertForMaskedLM, StructBertForMaskedLM, - VecoForMaskedLM) +from modelscope.models.nlp import SbertForMaskedLM, VecoForMaskedLM from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import FillMaskPipeline from modelscope.preprocessors import NLPPreprocessor @@ -55,7 +54,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): model_dir = snapshot_download(self.model_id_sbert[language]) preprocessor = NLPPreprocessor( model_dir, first_sequence='sentence', second_sequence=None) - model = StructBertForMaskedLM.from_pretrained(model_dir) + model = SbertForMaskedLM.from_pretrained(model_dir) pipeline1 = FillMaskPipeline(model, preprocessor) pipeline2 = pipeline( Tasks.fill_mask, model=model, preprocessor=preprocessor) @@ -84,7 +83,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): # bert language = 'zh' - model_dir = snapshot_download(self.model_id_bert, revision='beta') + model_dir = snapshot_download(self.model_id_bert) preprocessor = NLPPreprocessor( model_dir, first_sequence='sentence', second_sequence=None) model = Model.from_pretrained(model_dir) @@ -130,18 +129,6 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): f'\nori_text: {ori_text}\ninput: 
{test_input}\npipeline: ' f'{pipeline_ins(test_input)}\n') - # bert - language = 'zh' - model = Model.from_pretrained(self.model_id_bert, revision='beta') - preprocessor = NLPPreprocessor( - model.model_dir, first_sequence='sentence', second_sequence=None) - pipeline_ins = pipeline( - Tasks.fill_mask, model=model, preprocessor=preprocessor) - pipeline_ins.model, f'fill_mask_bert_{language}' - print( - f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' - f'{pipeline_ins(self.test_inputs[language])}\n') - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): # veco @@ -162,10 +149,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): # Bert language = 'zh' - pipeline_ins = pipeline( - task=Tasks.fill_mask, - model=self.model_id_bert, - model_revision='beta') + pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_bert) print( f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' f'{pipeline_ins(self.test_inputs[language])}\n') diff --git a/tests/pipelines/test_gpt3_text_generation.py b/tests/pipelines/test_gpt3_text_generation.py new file mode 100644 index 00000000..674e95bb --- /dev/null +++ b/tests/pipelines/test_gpt3_text_generation.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TextGPT3GenerationTest(unittest.TestCase): + + def setUp(self) -> None: + # please make sure this local path exists. + self.model_id_1_3B = 'damo/nlp_gpt3_text-generation_1.3B' + self.model_id_2_7B = 'damo/nlp_gpt3_text-generation_2.7B' + self.model_id_13B = 'damo/nlp_gpt3_text-generation_13B' + self.model_dir_13B = snapshot_download(self.model_id_13B) + self.input = '好的' + + @unittest.skip('distributed gpt3 1.3B, skipped') + def test_gpt3_1_3B(self): + pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B) + print(pipe(self.input)) + + @unittest.skip('distributed gpt3 2.7B, skipped') + def test_gpt3_2_7B(self): + pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B) + print(pipe(self.input)) + + @unittest.skip('distributed gpt3 13B, skipped') + def test_gpt3_13B(self): + """ The model can be downloaded from the link on + TODO: add gpt3 checkpoint link + After downloading, you should have a gpt3 model structure like this: + nlp_gpt3_text-generation_13B + |_ config.json + |_ configuration.json + |_ tokenizer.json + |_ model <-- an empty directory + + Model binaries shall be downloaded separately to populate the model directory, so that + the model directory would contain the following binaries: + |_ model + |_ mp_rank_00_model_states.pt + |_ mp_rank_01_model_states.pt + |_ mp_rank_02_model_states.pt + |_ mp_rank_03_model_states.pt + |_ mp_rank_04_model_states.pt + |_ mp_rank_05_model_states.pt + |_ mp_rank_06_model_states.pt + |_ mp_rank_07_model_states.pt + """ + pipe = pipeline(Tasks.text_generation, model=self.model_dir_13B) + print(pipe(self.input)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_hand_2d_keypoints.py b/tests/pipelines/test_hand_2d_keypoints.py index 86cd2d06..43b569d0 100644 --- a/tests/pipelines/test_hand_2d_keypoints.py +++ b/tests/pipelines/test_hand_2d_keypoints.py @@ -15,10 +15,8 @@ class 
Hand2DKeypointsPipelineTest(unittest.TestCase): model_id = 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody' hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints, model=model_id) - outputs = hand_keypoint(img_path) - self.assertEqual(len(outputs), 1) + results = hand_keypoint(img_path) - results = outputs[0] self.assertIn(OutputKeys.KEYPOINTS, results.keys()) self.assertIn(OutputKeys.BOXES, results.keys()) self.assertEqual(results[OutputKeys.KEYPOINTS].shape[1], 21) @@ -30,10 +28,7 @@ class Hand2DKeypointsPipelineTest(unittest.TestCase): img_path = 'data/test/images/hand_keypoints.jpg' hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints) - outputs = hand_keypoint(img_path) - self.assertEqual(len(outputs), 1) - - results = outputs[0] + results = hand_keypoint(img_path) self.assertIn(OutputKeys.KEYPOINTS, results.keys()) self.assertIn(OutputKeys.BOXES, results.keys()) self.assertEqual(results[OutputKeys.KEYPOINTS].shape[1], 21) diff --git a/tests/pipelines/test_human_wholebody_keypoint.py b/tests/pipelines/test_human_wholebody_keypoint.py new file mode 100644 index 00000000..7c5946cc --- /dev/null +++ b/tests/pipelines/test_human_wholebody_keypoint.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_human_wholebody_keypoint(self): + img_path = 'data/test/images/keypoints_detect/img_test_wholebody.jpg' + model_id = 'damo/cv_hrnetw48_human-wholebody-keypoint_image' + + human_wholebody_keypoint_pipeline = pipeline( + task=Tasks.human_wholebody_keypoint, model=model_id) + output = human_wholebody_keypoint_pipeline(img_path) + + output_keypoints = output[OutputKeys.KEYPOINTS] + output_pose = output[OutputKeys.BOXES] + + human_wholebody_keypoint_pipeline.predict_op.show_result( + img_path, + output_keypoints, + output_pose, + scale=1, + save_path='human_wholebody_keypoint_ret.jpg') + + for keypoint in output_keypoints: + self.assertEqual(keypoint.shape[0], 133) + for box in output_pose: + self.assertEqual(box.shape[0], 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_body_reshaping.py b/tests/pipelines/test_image_body_reshaping.py new file mode 100644 index 00000000..e1955e94 --- /dev/null +++ b/tests/pipelines/test_image_body_reshaping.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os.path as osp +import unittest + +import cv2 + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageBodyReshapingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_body_reshaping + self.model_id = 'damo/cv_flow-based-body-reshaping_damo' + self.test_image = 'data/test/images/image_body_reshaping.jpg' + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + if result is not None: + cv2.imwrite('result_body_reshaping.png', + result[OutputKeys.OUTPUT_IMG]) + print( + f'Output written to {osp.abspath("result_body_reshaping.png")}' + ) + else: + raise Exception('Testing failed: invalid output') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + model_dir = snapshot_download(self.model_id) + image_body_reshaping = pipeline( + Tasks.image_body_reshaping, model=model_dir) + self.pipeline_inference(image_body_reshaping, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + image_body_reshaping = pipeline( + Tasks.image_body_reshaping, model=self.model_id) + self.pipeline_inference(image_body_reshaping, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + image_body_reshaping = pipeline(Tasks.image_body_reshaping) + self.pipeline_inference(image_body_reshaping, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_color_enhance.py b/tests/pipelines/test_image_color_enhance.py index 9b72999e..7c3ae8c0 100644 --- a/tests/pipelines/test_image_color_enhance.py +++ b/tests/pipelines/test_image_color_enhance.py @@ -21,8 +21,7 @@ class ImageColorEnhanceTest(unittest.TestCase, DemoCompatibilityCheck): def pipeline_inference(self, pipeline: Pipeline, input_location: str): result = pipeline(input_location) if result is not None: - cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG][:, :, - [2, 1, 0]]) + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) print(f'Output written to {osp.abspath("result.png")}') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') diff --git a/tests/pipelines/test_image_denoise.py b/tests/pipelines/test_image_denoise.py index bf8cfd0f..d95dd343 100644 --- a/tests/pipelines/test_image_denoise.py +++ b/tests/pipelines/test_image_denoise.py @@ -2,8 +2,6 @@ import unittest -from PIL import Image - from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model from modelscope.outputs import OutputKeys @@ -20,16 +18,16 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck): self.task = Tasks.image_denoising self.model_id = 'damo/cv_nafnet_image-denoise_sidd' - demo_image_path = 'data/test/images/noisy-demo-1.png' + demo_image_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/noisy-demo-0.png' @unittest.skipUnless(test_level()
>= 2, 'skip test in current test level') def test_run_by_direct_model_download(self): cache_path = snapshot_download(self.model_id) pipeline = ImageDenoisePipeline(cache_path) + pipeline.group_key = self.task denoise_img = pipeline( - input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] - denoise_img = Image.fromarray(denoise_img) - w, h = denoise_img.size + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] print('pipeline: the shape of output_img is {}x{}'.format(h, w)) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -37,9 +35,8 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck): model = Model.from_pretrained(self.model_id) pipeline_ins = pipeline(task=Tasks.image_denoising, model=model) denoise_img = pipeline_ins( - input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] - denoise_img = Image.fromarray(denoise_img) - w, h = denoise_img.size + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] print('pipeline: the shape of output_img is {}x{}'.format(h, w)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @@ -47,18 +44,16 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck): pipeline_ins = pipeline( task=Tasks.image_denoising, model=self.model_id) denoise_img = pipeline_ins( - input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] - denoise_img = Image.fromarray(denoise_img) - w, h = denoise_img.size + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] print('pipeline: the shape of output_img is {}x{}'.format(h, w)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): pipeline_ins = pipeline(task=Tasks.image_denoising) denoise_img = pipeline_ins( - input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] - denoise_img = Image.fromarray(denoise_img) - w, h = denoise_img.size + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] print('pipeline: the shape of output_img is {}x{}'.format(h, w)) @unittest.skip('demo compatibility test is only enabled on a needed-basis') diff --git a/tests/pipelines/test_image_inpainting.py b/tests/pipelines/test_image_inpainting.py new file mode 100644 index 00000000..a8b704b7 --- /dev/null +++ b/tests/pipelines/test_image_inpainting.py @@ -0,0 +1,75 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +import cv2 +import torch +from PIL import Image + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class ImageInpaintingTest(unittest.TestCase): + + def setUp(self) -> None: + self.input_location = 'data/test/images/image_inpainting/image_inpainting.png' + self.input_mask_location = 'data/test/images/image_inpainting/image_inpainting_mask.png' + self.model_id = 'damo/cv_fft_inpainting_lama' + self.input = { + 'img': self.input_location, + 'mask': self.input_mask_location + } + + def save_result(self, result): + vis_img = result[OutputKeys.OUTPUT_IMG] + cv2.imwrite('result.png', vis_img) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_inpainting(self): + inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) + result = inpainting(self.input) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') + def test_inpainting_with_refinement(self): + # for high-resolution input images, setting refine=True usually gives better results + inpainting = pipeline( + Tasks.image_inpainting, model=self.model_id, refine=True) + result = inpainting(self.input) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_inpainting_with_image(self): + inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) + img = Image.open(self.input_location).convert('RGB') + mask = Image.open(self.input_mask_location).convert('RGB') + result = inpainting({'img': img, 'mask': mask}) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_inpainting_with_default_task(self): + inpainting = pipeline(Tasks.image_inpainting) + result = inpainting(self.input) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_style_transfer.py b/tests/pipelines/test_image_style_transfer.py index a02d5308..5f37f204 100644 --- a/tests/pipelines/test_image_style_transfer.py +++ b/tests/pipelines/test_image_style_transfer.py @@ -25,8 +25,9 @@ class ImageStyleTransferTest(unittest.TestCase, DemoCompatibilityCheck): Tasks.image_style_transfer, model=snapshot_path) result = image_style_transfer( - 'data/test/images/style_transfer_content.jpg', - style='data/test/images/style_transfer_style.jpg') + dict( + content='data/test/images/style_transfer_content.jpg', + style='data/test/images/style_transfer_style.jpg'))
cv2.imwrite('result_styletransfer2.png', result[OutputKeys.OUTPUT_IMG]) print('style_transfer.test_run_modelhub done') @@ -45,8 +47,9 @@ class ImageStyleTransferTest(unittest.TestCase, DemoCompatibilityCheck): image_style_transfer = pipeline(Tasks.image_style_transfer) result = image_style_transfer( - 'data/test/images/style_transfer_content.jpg', - style='data/test/images/style_transfer_style.jpg') + dict( + content='data/test/images/style_transfer_content.jpg', + style='data/test/images/style_transfer_style.jpg')) cv2.imwrite('result_styletransfer3.png', result[OutputKeys.OUTPUT_IMG]) print('style_transfer.test_run_modelhub_default_model done') diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py index 91f9f566..f31d212b 100644 --- a/tests/pipelines/test_key_word_spotting.py +++ b/tests/pipelines/test_key_word_spotting.py @@ -245,7 +245,7 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_wav_by_customized_keywords(self): - keywords = [{'keyword': '播放音乐'}] + keywords = '播放音乐' kws_result = self.run_pipeline( model_id=self.model_id, diff --git a/tests/pipelines/test_key_word_spotting_farfield.py b/tests/pipelines/test_key_word_spotting_farfield.py index f8c167de..69d6a953 100644 --- a/tests/pipelines/test_key_word_spotting_farfield.py +++ b/tests/pipelines/test_key_word_spotting_farfield.py @@ -9,9 +9,8 @@ from modelscope.utils.test_utils import test_level TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav' TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav' -TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \ - 'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \ - '?Revision=master&FilePath=examples/3ch_nihaomiya.wav' +TEST_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ + 'test/audios/3ch_nihaomiya.wav' class KWSFarfieldTest(unittest.TestCase): @@ -22,18 +21,14 @@ class KWSFarfieldTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_normal(self): kws = pipeline(Tasks.keyword_spotting, model=self.model_id) - inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)} - result = kws(inputs) + result = kws(os.path.join(os.getcwd(), TEST_SPEECH_FILE)) self.assertEqual(len(result['kws_list']), 5) print(result['kws_list'][-1]) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_mono(self): kws = pipeline(Tasks.keyword_spotting, model=self.model_id) - inputs = { - 'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE_MONO) - } - result = kws(inputs) + result = kws(os.path.join(os.getcwd(), TEST_SPEECH_FILE_MONO)) self.assertEqual(len(result['kws_list']), 5) print(result['kws_list'][-1]) @@ -44,17 +39,6 @@ class KWSFarfieldTest(unittest.TestCase): self.assertEqual(len(result['kws_list']), 5) print(result['kws_list'][-1]) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') - def test_output(self): - kws = pipeline(Tasks.keyword_spotting, model=self.model_id) - inputs = { - 'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE), - 'output_file': 'output.wav' - } - result = kws(inputs) - self.assertEqual(len(result['kws_list']), 5) - print(result['kws_list'][-1]) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_input_bytes(self): with open(os.path.join(os.getcwd(), TEST_SPEECH_FILE), 'rb') as f: diff --git 
a/tests/pipelines/test_mplug_tasks.py b/tests/pipelines/test_mplug_tasks.py index a3ace62d..21439ce2 100644 --- a/tests/pipelines/test_mplug_tasks.py +++ b/tests/pipelines/test_mplug_tasks.py @@ -13,10 +13,6 @@ from modelscope.utils.test_utils import test_level class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck): - def setUp(self) -> None: - self.task = 'visual-question-answering' - self.model_id = 'damo/mplug_visual-question-answering_coco_large_en' - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_image_captioning_with_model(self): model = Model.from_pretrained( @@ -26,7 +22,7 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck): model=model, ) image = Image.open('data/test/images/image_mplug_vqa.jpg') - result = pipeline_caption({'image': image}) + result = pipeline_caption(image) print(result[OutputKeys.CAPTION]) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -35,7 +31,7 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck): Tasks.image_captioning, model='damo/mplug_image-captioning_coco_base_en') image = Image.open('data/test/images/image_mplug_vqa.jpg') - result = pipeline_caption({'image': image}) + result = pipeline_caption(image) print(result[OutputKeys.CAPTION]) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @@ -80,6 +76,25 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck): result = pipeline_retrieval(input) print(result) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_image_captioning_zh_base_with_name(self): + pipeline_caption = pipeline( + Tasks.image_captioning, + model='damo/mplug_image-captioning_coco_base_zh') + image = Image.open('data/test/images/image_mplug_vqa.jpg') + result = pipeline_caption(image) + print(result[OutputKeys.CAPTION]) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_visual_question_answering_zh_base_with_name(self): + model = 'damo/mplug_visual-question-answering_coco_base_zh' + pipeline_vqa = pipeline(Tasks.visual_question_answering, model=model) + image = Image.open('data/test/images/image_mplug_vqa.jpg') + text = '这个女人在做什么?' 
+ input = {'image': image, 'text': text} + result = pipeline_vqa(input) + print(result) + @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): self.compatibility_check() diff --git a/tests/pipelines/test_multi_modal_embedding.py b/tests/pipelines/test_multi_modal_embedding.py index 23954c27..ee9cdb1f 100644 --- a/tests/pipelines/test_multi_modal_embedding.py +++ b/tests/pipelines/test_multi_modal_embedding.py @@ -19,14 +19,11 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): self.model_id = 'damo/multi-modal_clip-vit-base-patch16_zh' test_input = {'text': '皮卡丘'} - model_version = 'dev' @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run(self): pipeline_multi_modal_embedding = pipeline( - Tasks.multi_modal_embedding, - model=self.model_id, - model_revision=self.model_version) + Tasks.multi_modal_embedding, model=self.model_id) text_embedding = pipeline_multi_modal_embedding( self.test_input)[OutputKeys.TEXT_EMBEDDING] print('l1-norm: {}'.format( @@ -36,8 +33,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_from_modelhub(self): - model = Model.from_pretrained( - self.model_id, revision=self.model_version) + model = Model.from_pretrained(self.model_id) pipeline_multi_modal_embedding = pipeline( task=Tasks.multi_modal_embedding, model=model) text_embedding = pipeline_multi_modal_embedding( @@ -50,8 +46,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_default_model(self): pipeline_multi_modal_embedding = pipeline( - task=Tasks.multi_modal_embedding, - model_revision=self.model_version) + task=Tasks.multi_modal_embedding) text_embedding = pipeline_multi_modal_embedding( self.test_input)[OutputKeys.TEXT_EMBEDDING] print('l1-norm: {}'.format( diff --git a/tests/pipelines/test_multilingual_named_entity_recognition.py b/tests/pipelines/test_multilingual_named_entity_recognition.py new file mode 100644 index 00000000..6f72c83c --- /dev/null +++ b/tests/pipelines/test_multilingual_named_entity_recognition.py @@ -0,0 +1,102 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import (LSTMCRFForNamedEntityRecognition, + TransformerCRFForNamedEntityRecognition) +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import (NamedEntityRecognitionThaiPipeline, + NamedEntityRecognitionVietPipeline) +from modelscope.preprocessors import NERPreprocessorThai, NERPreprocessorViet +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class MultilingualNamedEntityRecognitionTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.named_entity_recognition + self.model_id = 'damo/nlp_xlmr_named-entity-recognition_thai-ecommerce-title' + + thai_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_thai-ecommerce-title' + thai_sentence = 'เครื่องชั่งดิจิตอลแบบตั้งพื้น150kg.' 
+ + viet_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_viet-ecommerce-title' + viet_sentence = 'Nón vành dễ thương cho bé gái' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_tcrf_by_direct_model_download_thai(self): + cache_path = snapshot_download(self.thai_tcrf_model_id) + tokenizer = NERPreprocessorThai(cache_path) + model = TransformerCRFForNamedEntityRecognition( + cache_path, tokenizer=tokenizer) + pipeline1 = NamedEntityRecognitionThaiPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(f'thai_sentence: {self.thai_sentence}\n' + f'pipeline1:{pipeline1(input=self.thai_sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.thai_sentence)}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_tcrf_with_model_from_modelhub_thai(self): + model = Model.from_pretrained(self.thai_tcrf_model_id) + tokenizer = NERPreprocessorThai(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.thai_sentence)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_tcrf_with_model_name_thai(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.thai_tcrf_model_id) + print(pipeline_ins(input=self.thai_sentence)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_tcrf_by_direct_model_download_viet(self): + cache_path = snapshot_download(self.viet_tcrf_model_id) + tokenizer = NERPreprocessorViet(cache_path) + model = TransformerCRFForNamedEntityRecognition( + cache_path, tokenizer=tokenizer) + pipeline1 = NamedEntityRecognitionVietPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(f'viet_sentence: {self.viet_sentence}\n' + f'pipeline1:{pipeline1(input=self.viet_sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.viet_sentence)}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_tcrf_with_model_from_modelhub_viet(self): + model = Model.from_pretrained(self.viet_tcrf_model_id) + tokenizer = NERPreprocessorViet(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.viet_sentence)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_tcrf_with_model_name_viet(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.viet_tcrf_model_id) + print(pipeline_ins(input=self.viet_sentence)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_multilingual_word_segmentation.py b/tests/pipelines/test_multilingual_word_segmentation.py new file mode 100644 index 00000000..25b4b241 --- /dev/null +++ b/tests/pipelines/test_multilingual_word_segmentation.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import TransformerCRFForWordSegmentation +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import WordSegmentationThaiPipeline +from modelscope.preprocessors import WordSegmentationPreprocessorThai +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.regress_test_utils import MsRegressTool +from modelscope.utils.test_utils import test_level + + +class WordSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.word_segmentation + self.model_id = 'damo/nlp_xlmr_word-segmentation_thai' + + sentence = 'รถคันเก่าก็ยังเก็บเอาไว้ยังไม่ได้ขาย' + regress_tool = MsRegressTool(baseline=False) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = WordSegmentationPreprocessorThai(cache_path) + model = TransformerCRFForWordSegmentation.from_pretrained(cache_path) + pipeline1 = WordSegmentationThaiPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.word_segmentation, model=model, preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence)}') + print(f'pipeline2: {pipeline2(input=self.sentence)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = WordSegmentationPreprocessorThai(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.word_segmentation, model=model, preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.word_segmentation, model=self.model_id) + print(pipeline_ins(input=self.sentence)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_nli.py b/tests/pipelines/test_nli.py index db4b9912..5f2dcb25 100644 --- a/tests/pipelines/test_nli.py +++ b/tests/pipelines/test_nli.py @@ -5,7 +5,7 @@ from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model from modelscope.models.nlp import SbertForSequenceClassification from modelscope.pipelines import pipeline -from modelscope.pipelines.nlp import SequenceClassificationPipeline +from modelscope.pipelines.nlp import TextClassificationPipeline from modelscope.preprocessors import SequenceClassificationPreprocessor from modelscope.utils.constant import Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck @@ -27,9 +27,8 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck): def test_run_with_direct_file_download(self): cache_path = snapshot_download(self.model_id) tokenizer = SequenceClassificationPreprocessor(cache_path) - model = SbertForSequenceClassification.from_pretrained(cache_path) - pipeline1 = SequenceClassificationPipeline( - model, preprocessor=tokenizer) + model = Model.from_pretrained(cache_path) + pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer) pipeline2 = pipeline(Tasks.nli, 
model=model, preprocessor=tokenizer) print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' f'pipeline1:{pipeline1(input=(self.sentence1, self.sentence2))}') diff --git a/tests/pipelines/test_object_detection.py b/tests/pipelines/test_object_detection.py index 2a74eb41..64766c77 100644 --- a/tests/pipelines/test_object_detection.py +++ b/tests/pipelines/test_object_detection.py @@ -19,20 +19,14 @@ class ObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): model_id = 'damo/cv_vit_object-detection_coco' object_detect = pipeline(Tasks.image_object_detection, model=model_id) result = object_detect(input_location) - if result: - print(result) - else: - raise ValueError('process error') + print(result) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_object_detection_with_default_task(self): input_location = 'data/test/images/image_detection.jpg' object_detect = pipeline(Tasks.image_object_detection) result = object_detect(input_location) - if result: - print(result) - else: - raise ValueError('process error') + print(result) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_human_detection(self): @@ -40,25 +34,31 @@ class ObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): model_id = 'damo/cv_resnet18_human-detection' human_detect = pipeline(Tasks.human_detection, model=model_id) result = human_detect(input_location) - if result: - print(result) - else: - raise ValueError('process error') + print(result) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_human_detection_with_default_task(self): input_location = 'data/test/images/image_detection.jpg' human_detect = pipeline(Tasks.human_detection) result = human_detect(input_location) - if result: - print(result) - else: - raise ValueError('process error') + print(result) @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): self.compatibility_check() + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_object_detection_auto_pipeline(self): + model_id = 'damo/cv_yolox_image-object-detection-auto' + test_image = 'data/test/images/auto_demo.jpg' + + image_object_detection_auto = pipeline( + Tasks.image_object_detection, model=model_id) + + result = image_object_detection_auto(test_image) + image_object_detection_auto.show_result(test_image, result, + 'auto_demo_ret.jpg') + if __name__ == '__main__': unittest.main() diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py index e6638dfa..57dcb0c3 100644 --- a/tests/pipelines/test_ofa_tasks.py +++ b/tests/pipelines/test_ofa_tasks.py @@ -34,7 +34,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): model=model, ) image = 'data/test/images/image_captioning.png' - result = img_captioning({'image': image}) + result = img_captioning(image) print(result[OutputKeys.CAPTION]) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -42,18 +42,24 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): img_captioning = pipeline( Tasks.image_captioning, model='damo/ofa_image-caption_coco_large_en') - result = img_captioning( - {'image': 'data/test/images/image_captioning.png'}) + result = img_captioning('data/test/images/image_captioning.png') print(result[OutputKeys.CAPTION]) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_ocr_recognize_with_name(self): + 
ocr_recognize = pipeline( + Tasks.ocr_recognition, + model='damo/ofa_ocr-recognition_scene_base_zh') + result = ocr_recognize('data/test/images/image_ocr_recognition.jpg') + print(result[OutputKeys.TEXT]) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_image_classification_with_model(self): model = Model.from_pretrained( 'damo/ofa_image-classification_imagenet_large_en') ofa_pipe = pipeline(Tasks.image_classification, model=model) image = 'data/test/images/image_classification.png' - input = {'image': image} - result = ofa_pipe(input) + result = ofa_pipe(image) print(result) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -62,15 +68,14 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): Tasks.image_classification, model='damo/ofa_image-classification_imagenet_large_en') image = 'data/test/images/image_classification.png' - input = {'image': image} - result = ofa_pipe(input) + result = ofa_pipe(image) print(result) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_summarization_with_model(self): model = Model.from_pretrained( 'damo/ofa_summarization_gigaword_large_en') - ofa_pipe = pipeline(Tasks.summarization, model=model) + ofa_pipe = pipeline(Tasks.text_summarization, model=model) text = 'five-time world champion michelle kwan withdrew' + \ 'from the #### us figure skating championships on wednesday ,' + \ ' but will petition us skating officials for the chance to ' + \ @@ -82,7 +87,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_summarization_with_name(self): ofa_pipe = pipeline( - Tasks.summarization, + Tasks.text_summarization, model='damo/ofa_summarization_gigaword_large_en') text = 'five-time world champion michelle kwan withdrew' + \ 'from the #### us figure skating championships on wednesday ,' + \ @@ -99,8 +104,8 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): ofa_pipe = pipeline(Tasks.text_classification, model=model) text = 'One of our number will carry out your instructions minutely.' text2 = 'A member of my team will execute your orders with immense precision.' - input = {'text': text, 'text2': text2} - result = ofa_pipe(input) + result = ofa_pipe((text, text2)) + result = ofa_pipe({'text': text, 'text2': text2}) print(result) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -110,8 +115,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): model='damo/ofa_text-classification_mnli_large_en') text = 'One of our number will carry out your instructions minutely.' text2 = 'A member of my team will execute your orders with immense precision.' 
- input = {'text': text, 'text2': text2} - result = ofa_pipe(input) + result = ofa_pipe((text, text2)) print(result) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') diff --git a/tests/pipelines/test_part_of_speech.py b/tests/pipelines/test_part_of_speech.py index 25f4491c..038a90f0 100644 --- a/tests/pipelines/test_part_of_speech.py +++ b/tests/pipelines/test_part_of_speech.py @@ -13,7 +13,7 @@ from modelscope.utils.test_utils import test_level class PartOfSpeechTest(unittest.TestCase): - model_id = 'damo/nlp_structbert_part-of-speech_chinese-base' + model_id = 'damo/nlp_structbert_part-of-speech_chinese-lite' sentence = '今天天气不错,适合出去游玩' @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @@ -23,7 +23,7 @@ class PartOfSpeechTest(unittest.TestCase): model = TokenClassificationModel.from_pretrained(cache_path) pipeline1 = TokenClassificationPipeline(model, preprocessor=tokenizer) pipeline2 = pipeline( - Tasks.token_classification, model=model, preprocessor=tokenizer) + Tasks.part_of_speech, model=model, preprocessor=tokenizer) print(f'sentence: {self.sentence}\n' f'pipeline1:{pipeline1(input=self.sentence)}') print() @@ -34,20 +34,17 @@ class PartOfSpeechTest(unittest.TestCase): model = Model.from_pretrained(self.model_id) tokenizer = TokenClassificationPreprocessor(model.model_dir) pipeline_ins = pipeline( - task=Tasks.token_classification, - model=model, - preprocessor=tokenizer) + task=Tasks.part_of_speech, model=model, preprocessor=tokenizer) print(pipeline_ins(input=self.sentence)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_name(self): - pipeline_ins = pipeline( - task=Tasks.token_classification, model=self.model_id) + pipeline_ins = pipeline(task=Tasks.part_of_speech, model=self.model_id) print(pipeline_ins(input=self.sentence)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): - pipeline_ins = pipeline(task=Tasks.token_classification) + pipeline_ins = pipeline(task=Tasks.part_of_speech) print(pipeline_ins(input=self.sentence)) diff --git a/tests/pipelines/test_passage_ranking.py b/tests/pipelines/test_passage_ranking.py deleted file mode 100644 index 5faa365e..00000000 --- a/tests/pipelines/test_passage_ranking.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-import shutil -import unittest - -from modelscope.hub.snapshot_download import snapshot_download -from modelscope.models import Model -from modelscope.models.nlp import PassageRanking -from modelscope.pipelines import pipeline -from modelscope.pipelines.nlp import PassageRankingPipeline -from modelscope.preprocessors import PassageRankingPreprocessor -from modelscope.utils.constant import Tasks -from modelscope.utils.test_utils import test_level - - -class PassageRankingTest(unittest.TestCase): - model_id = 'damo/nlp_corom_passage-ranking_english-base' - inputs = { - 'source_sentence': ["how long it take to get a master's degree"], - 'sentences_to_compare': [ - "On average, students take about 18 to 24 months to complete a master's degree.", - 'On the other hand, some students prefer to go at a slower pace and choose to take ' - 'several years to complete their studies.', - 'It can take anywhere from two semesters' - ] - } - - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_run_by_direct_model_download(self): - cache_path = snapshot_download(self.model_id) - tokenizer = PassageRankingPreprocessor(cache_path) - model = PassageRanking.from_pretrained(cache_path) - pipeline1 = PassageRankingPipeline(model, preprocessor=tokenizer) - pipeline2 = pipeline( - Tasks.passage_ranking, model=model, preprocessor=tokenizer) - print(f'sentence: {self.inputs}\n' - f'pipeline1:{pipeline1(input=self.inputs)}') - print() - print(f'pipeline2: {pipeline2(input=self.inputs)}') - - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_with_model_from_modelhub(self): - model = Model.from_pretrained(self.model_id) - tokenizer = PassageRankingPreprocessor(model.model_dir) - pipeline_ins = pipeline( - task=Tasks.passage_ranking, model=model, preprocessor=tokenizer) - print(pipeline_ins(input=self.inputs)) - - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') - def test_run_with_model_name(self): - pipeline_ins = pipeline( - task=Tasks.passage_ranking, model=self.model_id) - print(pipeline_ins(input=self.inputs)) - - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_run_with_default_model(self): - pipeline_ins = pipeline(task=Tasks.passage_ranking) - print(pipeline_ins(input=self.inputs)) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/pipelines/test_realtime_video_object_detection.py b/tests/pipelines/test_realtime_video_object_detection.py new file mode 100644 index 00000000..d65313a3 --- /dev/null +++ b/tests/pipelines/test_realtime_video_object_detection.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +import cv2 +import numpy as np + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import show_video_object_detection_result +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class RealtimeVideoObjectDetectionTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_cspnet_video-object-detection_streamyolo' + self.test_video = 'data/test/videos/test_realtime_vod.mp4' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + realtime_video_object_detection = pipeline( + Tasks.video_object_detection, model=self.model_id) + result = realtime_video_object_detection(self.test_video) + if result: + logger.info('Video output to test_vod_results.avi') + show_video_object_detection_result(self.test_video, + result[OutputKeys.BOXES], + result[OutputKeys.LABELS], + 'test_vod_results.avi') + else: + raise ValueError('process error') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_referring_video_object_segmentation.py b/tests/pipelines/test_referring_video_object_segmentation.py new file mode 100644 index 00000000..3e81d9c3 --- /dev/null +++ b/tests/pipelines/test_referring_video_object_segmentation.py @@ -0,0 +1,56 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ReferringVideoObjectSegmentationTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.referring_video_object_segmentation + self.model_id = 'damo/cv_swin-t_referring_video-object-segmentation' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_referring_video_object_segmentation(self): + input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4' + text_queries = [ + 'guy in black performing tricks on a bike', + 'a black bike used to perform tricks' + ] + start_pt, end_pt = 4, 14 + input_tuple = (input_location, text_queries, start_pt, end_pt) + pp = pipeline( + Tasks.referring_video_object_segmentation, model=self.model_id) + result = pp(input_tuple) + if result: + print(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_referring_video_object_segmentation_with_default_task(self): + input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4' + text_queries = [ + 'guy in black performing tricks on a bike', + 'a black bike used to perform tricks' + ] + start_pt, end_pt = 4, 14 + input_tuple = (input_location, text_queries, start_pt, end_pt) + pp = pipeline(Tasks.referring_video_object_segmentation) + result = pp(input_tuple) + if result: + print(result) + else: + raise ValueError('process error') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_relation_extraction.py b/tests/pipelines/test_relation_extraction.py index 57d98f66..561eaf21 100644 --- a/tests/pipelines/test_relation_extraction.py +++ b/tests/pipelines/test_relation_extraction.py @@ -15,7 +15,7 @@ from modelscope.utils.test_utils import test_level class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: - self.task = Tasks.information_extraction + self.task = Tasks.relation_extraction self.model_id = 'damo/nlp_bert_relation-extraction_chinese-base' sentence = '高捷,祖籍江苏,本科毕业于东南大学' @@ -28,7 +28,7 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): pipeline1 = InformationExtractionPipeline( model, preprocessor=tokenizer) pipeline2 = pipeline( - Tasks.information_extraction, model=model, preprocessor=tokenizer) + Tasks.relation_extraction, model=model, preprocessor=tokenizer) print(f'sentence: {self.sentence}\n' f'pipeline1:{pipeline1(input=self.sentence)}') print() @@ -39,7 +39,7 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): model = Model.from_pretrained(self.model_id) tokenizer = RelationExtractionPreprocessor(model.model_dir) pipeline_ins = pipeline( - task=Tasks.information_extraction, + task=Tasks.relation_extraction, model=model, preprocessor=tokenizer) print(pipeline_ins(input=self.sentence)) @@ -47,12 +47,12 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_name(self): pipeline_ins = pipeline( - task=Tasks.information_extraction, model=self.model_id) + 
task=Tasks.relation_extraction, model=self.model_id) print(pipeline_ins(input=self.sentence)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): - pipeline_ins = pipeline(task=Tasks.information_extraction) + pipeline_ins = pipeline(task=Tasks.relation_extraction) print(pipeline_ins(input=self.sentence)) @unittest.skip('demo compatibility test is only enabled on a needed-basis') diff --git a/tests/pipelines/test_sentence_embedding.py b/tests/pipelines/test_sentence_embedding.py index 739dd7ab..e96724a8 100644 --- a/tests/pipelines/test_sentence_embedding.py +++ b/tests/pipelines/test_sentence_embedding.py @@ -4,7 +4,7 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model -from modelscope.models.nlp import SentenceEmbedding +from modelscope.models.nlp import BertForSentenceEmbedding from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import SentenceEmbeddingPipeline from modelscope.preprocessors import SentenceEmbeddingPreprocessor @@ -40,7 +40,7 @@ class SentenceEmbeddingTest(unittest.TestCase): def test_run_by_direct_model_download(self): cache_path = snapshot_download(self.model_id) tokenizer = SentenceEmbeddingPreprocessor(cache_path) - model = SentenceEmbedding.from_pretrained(cache_path) + model = BertForSentenceEmbedding.from_pretrained(cache_path) pipeline1 = SentenceEmbeddingPipeline(model, preprocessor=tokenizer) pipeline2 = pipeline( Tasks.sentence_embedding, model=model, preprocessor=tokenizer) diff --git a/tests/pipelines/test_sentence_similarity.py b/tests/pipelines/test_sentence_similarity.py index 288d38c7..76db0a8f 100644 --- a/tests/pipelines/test_sentence_similarity.py +++ b/tests/pipelines/test_sentence_similarity.py @@ -5,7 +5,7 @@ from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model from modelscope.models.nlp import SbertForSequenceClassification from modelscope.pipelines import pipeline -from modelscope.pipelines.nlp import SequenceClassificationPipeline +from modelscope.pipelines.nlp import TextClassificationPipeline from modelscope.preprocessors import SequenceClassificationPreprocessor from modelscope.utils.constant import Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck @@ -28,8 +28,7 @@ class SentenceSimilarityTest(unittest.TestCase, DemoCompatibilityCheck): cache_path = snapshot_download(self.model_id) tokenizer = SequenceClassificationPreprocessor(cache_path) model = SbertForSequenceClassification.from_pretrained(cache_path) - pipeline1 = SequenceClassificationPipeline( - model, preprocessor=tokenizer) + pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer) pipeline2 = pipeline( Tasks.sentence_similarity, model=model, preprocessor=tokenizer) print('test1') diff --git a/tests/pipelines/test_sentiment_classification.py b/tests/pipelines/test_sentiment_classification.py index d0b1b40f..5c8d4e93 100644 --- a/tests/pipelines/test_sentiment_classification.py +++ b/tests/pipelines/test_sentiment_classification.py @@ -6,7 +6,7 @@ from modelscope.models import Model from modelscope.models.nlp.task_models.sequence_classification import \ SequenceClassificationModel from modelscope.pipelines import pipeline -from modelscope.pipelines.nlp import SequenceClassificationPipeline +from modelscope.pipelines.nlp import TextClassificationPipeline from modelscope.preprocessors import SequenceClassificationPreprocessor from modelscope.utils.constant import 
Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck @@ -24,12 +24,11 @@ class SentimentClassificationTaskModelTest(unittest.TestCase, @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_direct_file_download(self): - cache_path = snapshot_download(self.model_id, revision='beta') + cache_path = snapshot_download(self.model_id) tokenizer = SequenceClassificationPreprocessor(cache_path) model = SequenceClassificationModel.from_pretrained( - self.model_id, num_labels=2, revision='beta') - pipeline1 = SequenceClassificationPipeline( - model, preprocessor=tokenizer) + self.model_id, num_labels=2) + pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer) pipeline2 = pipeline( Tasks.text_classification, model=model, preprocessor=tokenizer) print(f'sentence1: {self.sentence1}\n' @@ -39,7 +38,7 @@ class SentimentClassificationTaskModelTest(unittest.TestCase, @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_model_from_modelhub(self): - model = Model.from_pretrained(self.model_id, revision='beta') + model = Model.from_pretrained(self.model_id) tokenizer = SequenceClassificationPreprocessor(model.model_dir) pipeline_ins = pipeline( task=Tasks.text_classification, @@ -52,17 +51,14 @@ class SentimentClassificationTaskModelTest(unittest.TestCase, @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): pipeline_ins = pipeline( - task=Tasks.text_classification, - model=self.model_id, - model_revision='beta') + task=Tasks.text_classification, model=self.model_id) print(pipeline_ins(input=self.sentence1)) self.assertTrue( isinstance(pipeline_ins.model, SequenceClassificationModel)) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_default_model(self): - pipeline_ins = pipeline( - task=Tasks.text_classification, model_revision='beta') + pipeline_ins = pipeline(task=Tasks.text_classification) print(pipeline_ins(input=self.sentence1)) self.assertTrue( isinstance(pipeline_ins.model, SequenceClassificationModel)) diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py index e5f97c02..2916d31a 100644 --- a/tests/pipelines/test_speech_signal_process.py +++ b/tests/pipelines/test_speech_signal_process.py @@ -11,17 +11,14 @@ from modelscope.utils.test_utils import test_level NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav' FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav' -NEAREND_MIC_URL = 'https://modelscope.cn/api/v1/models/damo/' \ - 'speech_dfsmn_aec_psm_16k/repo?Revision=master' \ - '&FilePath=examples/nearend_mic.wav' -FAREND_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \ - 'speech_dfsmn_aec_psm_16k/repo?Revision=master' \ - '&FilePath=examples/farend_speech.wav' +NEAREND_MIC_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ + 'test/audios/nearend_mic.wav' +FAREND_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ + 'test/audios/farend_speech.wav' NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav' -NOISE_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \ - 'speech_frcrn_ans_cirm_16k/repo?Revision=master' \ - '&FilePath=examples/speech_with_noise.wav' +NOISE_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ + 'test/audios/speech_with_noise.wav' class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck): diff --git 
a/tests/pipelines/test_table_question_answering.py b/tests/pipelines/test_table_question_answering.py index 7ea28725..825d8f23 100644 --- a/tests/pipelines/test_table_question_answering.py +++ b/tests/pipelines/test_table_question_answering.py @@ -1,57 +1,173 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os import unittest +from threading import Thread +from typing import List +import json from transformers import BertTokenizer from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model +from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline from modelscope.preprocessors import TableQuestionAnsweringPreprocessor -from modelscope.preprocessors.star3.fields.database import Database +from modelscope.preprocessors.nlp.space_T_cn.fields.database import Database from modelscope.utils.constant import ModelFile, Tasks -from modelscope.utils.nlp.nlp_utils import tableqa_tracking_and_print_results from modelscope.utils.test_utils import test_level +def tableqa_tracking_and_print_results_with_history( + pipelines: List[TableQuestionAnsweringPipeline]): + test_case = { + 'utterance': [ + '有哪些风险类型?', + '风险类型有多少种?', + '珠江流域的小(2)型水库的库容总量是多少?', + '那平均值是多少?', + '那水库的名称呢?', + '换成中型的呢?', + '枣庄营业厅的电话', + '那地址呢?', + '枣庄营业厅的电话和地址', + ] + } + for p in pipelines: + historical_queries = None + for question in test_case['utterance']: + output_dict = p({ + 'question': question, + 'history_sql': historical_queries + })[OutputKeys.OUTPUT] + print('question', question) + print('sql text:', output_dict[OutputKeys.SQL_STRING]) + print('sql query:', output_dict[OutputKeys.SQL_QUERY]) + print('query result:', output_dict[OutputKeys.QUERT_RESULT]) + print('json dumps', json.dumps(output_dict, ensure_ascii=False)) + print() + historical_queries = output_dict[OutputKeys.HISTORY] + + +def tableqa_tracking_and_print_results_without_history( + pipelines: List[TableQuestionAnsweringPipeline]): + test_case = { + 'utterance': [ + '有哪些风险类型?', + '风险类型有多少种?', + '珠江流域的小(2)型水库的库容总量是多少?', + '枣庄营业厅的电话', + '枣庄营业厅的电话和地址', + ] + } + for p in pipelines: + for question in test_case['utterance']: + output_dict = p({'question': question})[OutputKeys.OUTPUT] + print('question', question) + print('sql text:', output_dict[OutputKeys.SQL_STRING]) + print('sql query:', output_dict[OutputKeys.SQL_QUERY]) + print('query result:', output_dict[OutputKeys.QUERT_RESULT]) + print('json dumps', json.dumps(output_dict, ensure_ascii=False)) + print() + + +def tableqa_tracking_and_print_results_with_tableid( + pipelines: List[TableQuestionAnsweringPipeline]): + test_case = { + 'utterance': [ + ['有哪些风险类型?', 'fund'], + ['风险类型有多少种?', 'reservoir'], + ['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'], + ['那平均值是多少?', 'reservoir'], + ['那水库的名称呢?', 'reservoir'], + ['换成中型的呢?', 'reservoir'], + ['枣庄营业厅的电话', 'business'], + ['那地址呢?', 'business'], + ['枣庄营业厅的电话和地址', 'business'], + ], + } + for p in pipelines: + historical_queries = None + for question, table_id in test_case['utterance']: + output_dict = p({ + 'question': question, + 'table_id': table_id, + 'history_sql': historical_queries + })[OutputKeys.OUTPUT] + print('question', question) + print('sql text:', output_dict[OutputKeys.SQL_STRING]) + print('sql query:', output_dict[OutputKeys.SQL_QUERY]) + print('query result:', output_dict[OutputKeys.QUERT_RESULT]) + print('json dumps', json.dumps(output_dict, ensure_ascii=False)) + print() + historical_queries = 
output_dict[OutputKeys.HISTORY] + + class TableQuestionAnswering(unittest.TestCase): def setUp(self) -> None: self.task = Tasks.table_question_answering self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn' - model_id = 'damo/nlp_convai_text2sql_pretrain_cn' - test_case = { - 'utterance': - ['长江流域的小(2)型水库的库容总量是多少?', '那平均值是多少?', '那水库的名称呢?', '换成中型的呢?'] - } - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_by_direct_model_download(self): cache_path = snapshot_download(self.model_id) preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path) pipelines = [ - TableQuestionAnsweringPipeline( - model=cache_path, preprocessor=preprocessor) + pipeline( + Tasks.table_question_answering, + model=cache_path, + preprocessor=preprocessor) ] - tableqa_tracking_and_print_results(self.test_case, pipelines) + tableqa_tracking_and_print_results_with_history(pipelines) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download_with_multithreads(self): + cache_path = snapshot_download(self.model_id) + pl = pipeline(Tasks.table_question_answering, model=cache_path) + + def print_func(pl, i): + result = pl({ + 'question': '上个月收益从低到高排前七的基金的名称和风险等级是什么', + 'table_id': 'fund', + 'history_sql': None + }) + print(i, result[OutputKeys.OUTPUT][OutputKeys.SQL_QUERY], + result[OutputKeys.OUTPUT][OutputKeys.QUERT_RESULT], + json.dumps(result)) + + procs = [] + for i in range(5): + proc = Thread(target=print_func, args=(pl, i)) + procs.append(proc) + proc.start() + for proc in procs: + proc.join() @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id) + self.tokenizer = BertTokenizer( + os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) + db = Database( + tokenizer=self.tokenizer, + table_file_path=[ + os.path.join(model.model_dir, 'databases', fname) + for fname in os.listdir( + os.path.join(model.model_dir, 'databases')) + ], + syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), + is_use_sqlite=False) preprocessor = TableQuestionAnsweringPreprocessor( - model_dir=model.model_dir) + model_dir=model.model_dir, db=db) pipelines = [ - TableQuestionAnsweringPipeline( - model=model, preprocessor=preprocessor) + pipeline( + Tasks.table_question_answering, + model=model, + preprocessor=preprocessor, + db=db) ] - tableqa_tracking_and_print_results(self.test_case, pipelines) - - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_with_model_from_task(self): - pipelines = [pipeline(Tasks.table_question_answering, self.model_id)] - tableqa_tracking_and_print_results(self.test_case, pipelines) + tableqa_tracking_and_print_results_with_tableid(pipelines) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_model_from_modelhub_with_other_classes(self): @@ -60,15 +176,23 @@ class TableQuestionAnswering(unittest.TestCase): os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) db = Database( tokenizer=self.tokenizer, - table_file_path=os.path.join(model.model_dir, 'table.json'), - syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt')) + table_file_path=[ + os.path.join(model.model_dir, 'databases', fname) + for fname in os.listdir( + os.path.join(model.model_dir, 'databases')) + ], + syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), + is_use_sqlite=True) preprocessor = TableQuestionAnsweringPreprocessor( 
model_dir=model.model_dir, db=db) pipelines = [ - TableQuestionAnsweringPipeline( - model=model, preprocessor=preprocessor, db=db) + pipeline( + Tasks.table_question_answering, + model=model, + preprocessor=preprocessor, + db=db) ] - tableqa_tracking_and_print_results(self.test_case, pipelines) + tableqa_tracking_and_print_results_without_history(pipelines) if __name__ == '__main__': diff --git a/tests/pipelines/test_text2text_generation.py b/tests/pipelines/test_text2text_generation.py index a39562f5..d90263c4 100644 --- a/tests/pipelines/test_text2text_generation.py +++ b/tests/pipelines/test_text2text_generation.py @@ -15,42 +15,44 @@ from modelscope.utils.test_utils import test_level class Text2TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: - self.model_id = 'damo/t5-cn-base-test' - self.input = '中国的首都位于。' + self.model_id_generate = 'damo/t5-cn-base-test' + self.input_generate = '中国的首都位于。' + self.model_id_translate = 'damo/t5-translate-base-test' + self.input_translate = 'My name is Wolfgang and I live in Berlin' - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_T5(self): - cache_path = snapshot_download(self.model_id) - model = T5ForConditionalGeneration(cache_path) + cache_path = snapshot_download(self.model_id_generate) + model = T5ForConditionalGeneration.from_pretrained(cache_path) preprocessor = Text2TextGenerationPreprocessor(cache_path) pipeline1 = Text2TextGenerationPipeline(model, preprocessor) pipeline2 = pipeline( Tasks.text2text_generation, model=model, preprocessor=preprocessor) print( - f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}' + f'pipeline1: {pipeline1(self.input_generate)}\npipeline2: {pipeline2(self.input_generate)}' ) - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_pipeline_with_model_instance(self): - model = Model.from_pretrained(self.model_id) + model = Model.from_pretrained(self.model_id_translate) preprocessor = Text2TextGenerationPreprocessor(model.model_dir) pipeline_ins = pipeline( task=Tasks.text2text_generation, model=model, preprocessor=preprocessor) - print(pipeline_ins(self.input)) + print(pipeline_ins(self.input_translate)) - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_pipeline_with_model_id(self): pipeline_ins = pipeline( - task=Tasks.text2text_generation, model=self.model_id) - print(pipeline_ins(self.input)) + task=Tasks.text2text_generation, model=self.model_id_translate) + print(pipeline_ins(self.input_translate)) @unittest.skip( 'only for test cases, there is no default official model yet') def test_run_pipeline_without_model_id(self): pipeline_ins = pipeline(task=Tasks.text2text_generation) - print(pipeline_ins(self.input)) + print(pipeline_ins(self.input_generate)) @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py index 39dbac99..5b38e116 100644 --- a/tests/pipelines/test_text_classification.py +++ b/tests/pipelines/test_text_classification.py @@ -4,7 +4,7 @@ import unittest from modelscope.models import Model from modelscope.msdatasets import MsDataset from 
modelscope.pipelines import pipeline -from modelscope.pipelines.nlp import SequenceClassificationPipeline +from modelscope.pipelines.nlp import TextClassificationPipeline from modelscope.preprocessors import SequenceClassificationPreprocessor from modelscope.utils.constant import Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck @@ -18,7 +18,7 @@ class SequenceClassificationTest(unittest.TestCase, DemoCompatibilityCheck): self.model_id = 'damo/bert-base-sst2' self.task = Tasks.text_classification - def predict(self, pipeline_ins: SequenceClassificationPipeline): + def predict(self, pipeline_ins: TextClassificationPipeline): from easynlp.appzoo import load_dataset set = load_dataset('glue', 'sst2') diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index 66f9c9da..ffb30090 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -15,12 +15,17 @@ from modelscope.utils.test_utils import test_level class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: - self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base' + self.palm_model_id_zh_base = 'damo/nlp_palm2.0_text-generation_chinese-base' + self.palm_model_id_zh_large = 'damo/nlp_palm2.0_text-generation_chinese-large' + self.palm_model_id_zh_commodity = 'damo/nlp_palm2.0_text-generation_commodity_chinese-base' + self.palm_model_id_zh_weather = 'damo/nlp_palm2.0_text-generation_weather_chinese-base' self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base' self.palm_input_zh = """ 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方: 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代 """ + self.palm_input_commodity = '垃圾桶,双层,可拆卸,加高,加高双层,把手,垃圾桶,内附,万向轮' + self.palm_input_weather = "今日天气类型='浮尘'&空气质量等级='重度污染'&紫外线强度指数='中等'" self.palm_input_en = """ The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started her career at a legal chambers when the disgraced Labour peer was a top QC there . 
Alison Saunders , @@ -51,8 +56,8 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): print(pipeline_ins(input)) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_palm_zh_with_model_name(self): - self.run_pipeline_with_model_id(self.palm_model_id_zh, + def test_palm_zh_base_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_base, self.palm_input_zh) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -71,10 +76,40 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): self.gpt3_input) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_palm_zh_with_model_instance(self): - self.run_pipeline_with_model_instance(self.palm_model_id_zh, + def test_palm_zh_large_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_large, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_commodity_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_commodity, + self.palm_input_commodity) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_weather_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_weather, + self.palm_input_weather) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_base_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_base, self.palm_input_zh) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_large_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_large, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_commodity_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_commodity, + self.palm_input_commodity) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_weather_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_weather, + self.palm_input_weather) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_palm_en_with_model_instance(self): self.run_pipeline_with_model_instance(self.palm_model_id_en, @@ -92,8 +127,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_palm(self): - for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), - (self.palm_model_id_en, self.palm_input_en)): + for model_id, input in ((self.palm_model_id_zh_base, + self.palm_input_zh), (self.palm_model_id_en, + self.palm_input_en)): cache_path = snapshot_download(model_id) model = PalmForTextGeneration.from_pretrained(cache_path) preprocessor = TextGenerationPreprocessor( @@ -129,6 +165,25 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): pipeline_ins = pipeline(task=Tasks.text_generation) print(pipeline_ins(self.palm_input_zh)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_bloom(self): + pipe = pipeline( + task=Tasks.text_generation, model='langboat/bloom-1b4-zh') + print(pipe('中国的首都是')) + + @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub") + def test_gpt_neo(self): + pipe = pipeline( 
+ task=Tasks.text_generation, model='langboat/mengzi-gpt-neo-base') + print( + pipe( + '我是', + do_sample=True, + top_k=5, + top_p=1, + max_length=20, + repetition_penalty=0.5)) + @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): self.compatibility_check() diff --git a/tests/pipelines/test_text_ranking.py b/tests/pipelines/test_text_ranking.py new file mode 100644 index 00000000..0b43e8b4 --- /dev/null +++ b/tests/pipelines/test_text_ranking.py @@ -0,0 +1,67 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import shutil +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import BertForTextRanking +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TextRankingPipeline +from modelscope.preprocessors import TextRankingPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TextRankingTest(unittest.TestCase): + models = [ + 'damo/nlp_corom_passage-ranking_english-base', + 'damo/nlp_rom_passage-ranking_chinese-base' + ] + + inputs = { + 'source_sentence': ["how long it take to get a master's degree"], + 'sentences_to_compare': [ + "On average, students take about 18 to 24 months to complete a master's degree.", + 'On the other hand, some students prefer to go at a slower pace and choose to take ' + 'several years to complete their studies.', + 'It can take anywhere from two semesters' + ] + } + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + for model_id in self.models: + cache_path = snapshot_download(model_id) + tokenizer = TextRankingPreprocessor(cache_path) + model = BertForTextRanking.from_pretrained(cache_path) + pipeline1 = TextRankingPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.text_ranking, model=model, preprocessor=tokenizer) + print(f'sentence: {self.inputs}\n' + f'pipeline1:{pipeline1(input=self.inputs)}') + print() + print(f'pipeline2: {pipeline2(input=self.inputs)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + for model_id in self.models: + model = Model.from_pretrained(model_id) + tokenizer = TextRankingPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.text_ranking, model=model, preprocessor=tokenizer) + print(pipeline_ins(input=self.inputs)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + for model_id in self.models: + pipeline_ins = pipeline(task=Tasks.text_ranking, model=model_id) + print(pipeline_ins(input=self.inputs)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.text_ranking) + print(pipeline_ins(input=self.inputs)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text_to_speech.py b/tests/pipelines/test_text_to_speech.py index f659e59b..50807e23 100644 --- a/tests/pipelines/test_text_to_speech.py +++ b/tests/pipelines/test_text_to_speech.py @@ -9,7 +9,6 @@ import unittest import torch from scipy.io.wavfile import write -from modelscope.models import Model from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks @@ -27,21 +26,48 @@ 
class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase, def setUp(self) -> None: self.task = Tasks.text_to_speech - self.model_id = 'damo/speech_sambert-hifigan_tts_zhitian_emo_zh-cn_16k' + self.zhcn_text = '今天北京天气怎么样' + self.en_text = 'How is the weather in Beijing?' + self.zhcn_voices = [ + 'zhitian_emo', 'zhizhe_emo', 'zhiyan_emo', 'zhibei_emo', 'zhcn' + ] + self.zhcn_models = [ + 'damo/speech_sambert-hifigan_tts_zhitian_emo_zh-cn_16k', + 'damo/speech_sambert-hifigan_tts_zhizhe_emo_zh-cn_16k', + 'damo/speech_sambert-hifigan_tts_zhiyan_emo_zh-cn_16k', + 'damo/speech_sambert-hifigan_tts_zhibei_emo_zh-cn_16k', + 'damo/speech_sambert-hifigan_tts_zh-cn_16k' + ] + self.en_voices = ['luca', 'luna', 'andy', 'annie', 'engb', 'enus'] + self.en_models = [ + 'damo/speech_sambert-hifigan_tts_luca_en-gb_16k', + 'damo/speech_sambert-hifigan_tts_luna_en-gb_16k', + 'damo/speech_sambert-hifigan_tts_andy_en-us_16k', + 'damo/speech_sambert-hifigan_tts_annie_en-us_16k', + 'damo/speech_sambert-hifigan_tts_en-gb_16k', + 'damo/speech_sambert-hifigan_tts_en-us_16k' + ] @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_pipeline(self): - text = '今天北京天气怎么样?' - voice = 'zhitian_emo' - - model = Model.from_pretrained( - model_name_or_path=self.model_id, revision='pytorch_am') - sambert_hifigan_tts = pipeline(task=self.task, model=model) - self.assertTrue(sambert_hifigan_tts is not None) - output = sambert_hifigan_tts(input=text, voice=voice) - self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM]) - pcm = output[OutputKeys.OUTPUT_PCM] - write('output.wav', 16000, pcm) + for i in range(len(self.zhcn_voices)): + logger.info('test %s' % self.zhcn_voices[i]) + sambert_hifigan_tts = pipeline( + task=self.task, model=self.zhcn_models[i]) + self.assertTrue(sambert_hifigan_tts is not None) + output = sambert_hifigan_tts(input=self.zhcn_text) + self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM]) + pcm = output[OutputKeys.OUTPUT_PCM] + write('output_%s.wav' % self.zhcn_voices[i], 16000, pcm) + for i in range(len(self.en_voices)): + logger.info('test %s' % self.en_voices[i]) + sambert_hifigan_tts = pipeline( + task=self.task, model=self.en_models[i]) + self.assertTrue(sambert_hifigan_tts is not None) + output = sambert_hifigan_tts(input=self.en_text) + self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM]) + pcm = output[OutputKeys.OUTPUT_PCM] + write('output_%s.wav' % self.en_voices[i], 16000, pcm) @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): diff --git a/tests/pipelines/test_tinynas_detection.py b/tests/pipelines/test_tinynas_detection.py index 63db9145..43e1842d 100644 --- a/tests/pipelines/test_tinynas_detection.py +++ b/tests/pipelines/test_tinynas_detection.py @@ -4,22 +4,45 @@ import unittest from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck from modelscope.utils.test_utils import test_level -class TinynasObjectDetectionTest(unittest.TestCase): +class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_object_detection + self.model_id = 'damo/cv_tinynas_object-detection_damoyolo' @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run(self): + def test_run_airdet(self): tinynas_object_detection = pipeline( Tasks.image_object_detection, model='damo/cv_tinynas_detection') result = tinynas_object_detection( 
'data/test/images/image_detection.jpg') print(result) + @unittest.skip('will be enabled after damoyolo officially released') + def test_run_damoyolo(self): + tinynas_object_detection = pipeline( + Tasks.image_object_detection, + model='damo/cv_tinynas_object-detection_damoyolo') + result = tinynas_object_detection( + 'data/test/images/image_detection.jpg') + print(result) + @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): - self.test_demo() + self.compatibility_check() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_object_detection_auto_pipeline(self): + test_image = 'data/test/images/image_detection.jpg' + tinynas_object_detection = pipeline( + Tasks.image_object_detection, model='damo/cv_tinynas_detection') + result = tinynas_object_detection(test_image) + tinynas_object_detection.show_result(test_image, result, + 'demo_ret.jpg') if __name__ == '__main__': diff --git a/tests/pipelines/test_translation_quality_estimation.py b/tests/pipelines/test_translation_quality_estimation.py new file mode 100644 index 00000000..315fa72b --- /dev/null +++ b/tests/pipelines/test_translation_quality_estimation.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TranslationQualityEstimationTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.sentence_similarity + self.model_id = 'damo/nlp_translation_quality_estimation_multilingual' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_en2zh(self): + inputs = { + 'source_text': 'Love is a losing game', + 'target_text': '宝贝,人和人一场游戏' + } + pipeline_ins = pipeline(self.task, model=self.model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_unifold.py b/tests/pipelines/test_unifold.py new file mode 100644 index 00000000..df35dc5e --- /dev/null +++ b/tests/pipelines/test_unifold.py @@ -0,0 +1,34 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.protein_structure + self.model_id = 'DPTech/uni-fold-monomer' + self.model_id_multimer = 'DPTech/uni-fold-multimer' + + self.protein = 'MGLPKKALKESQLQFLTAGTAVSDSSHQTYKVSFIENGVIKNAFYKKLDPKNHYPELLAKISVAVSLFKRIFQGRRSAEERLVFDD' + self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \ + 'NIAALKNHIDKIKPIAMQIYKKYSKNIP' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + model_dir = snapshot_download(self.model_id) + mono_pipeline_ins = pipeline(task=self.task, model=model_dir) + _ = mono_pipeline_ins(self.protein) + + model_dir1 = snapshot_download(self.model_id_multimer) + multi_pipeline_ins = pipeline(task=self.task, model=model_dir1) + _ = multi_pipeline_ins(self.protein_multimer) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_video_multi_modal_embedding.py b/tests/pipelines/test_video_multi_modal_embedding.py index f4aa4d24..afe5940d 100644 --- a/tests/pipelines/test_video_multi_modal_embedding.py +++ b/tests/pipelines/test_video_multi_modal_embedding.py @@ -17,8 +17,8 @@ class VideoMultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): self.task = Tasks.video_multi_modal_embedding self.model_id = 'damo/multi_modal_clip_vtretrival_msrvtt_53' - video_path = 'data/test/videos/multi_modal_test_video_9770.mp4' - caption = ('a person is connecting something to system', None, None) + video_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/multi_modal_test_video_9770.mp4' + caption = 'a person is connecting something to system' _input = {'video': video_path, 'text': caption} @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') diff --git a/tests/pipelines/test_video_summarization.py b/tests/pipelines/test_video_summarization.py index 6dcc31e9..1f965c53 100644 --- a/tests/pipelines/test_video_summarization.py +++ b/tests/pipelines/test_video_summarization.py @@ -3,7 +3,6 @@ import unittest from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks -from modelscope.utils.cv.image_utils import show_video_summarization_result from modelscope.utils.demo_utils import DemoCompatibilityCheck from modelscope.utils.test_utils import test_level @@ -22,8 +21,6 @@ class VideoSummarizationTest(unittest.TestCase, DemoCompatibilityCheck): result = summarization_pipeline(video_path) print(f'video summarization output: \n{result}.') - show_video_summarization_result(video_path, result, - './summarization_result.avi') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_modelhub_default_model(self): diff --git a/tests/pipelines/test_zero_shot_classification.py b/tests/pipelines/test_zero_shot_classification.py index da1854c9..6a98132a 100644 --- a/tests/pipelines/test_zero_shot_classification.py +++ b/tests/pipelines/test_zero_shot_classification.py @@ -21,6 +21,7 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): sentence = '全新突破 解放军运20版空中加油机曝光' labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'] 
+ labels_str = '文化, 体育, 娱乐, 财经, 家居, 汽车, 教育, 科技, 军事' template = '这篇文章的标题是{}' regress_tool = MsRegressTool(baseline=False) @@ -40,6 +41,10 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): f'sentence: {self.sentence}\n' f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}' ) + print( + f'sentence: {self.sentence}\n' + f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels_str,hypothesis_template=self.template)}' + ) print( f'sentence: {self.sentence}\n' f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}' diff --git a/tests/run_config.yaml b/tests/run_config.yaml index 4bbdb92f..d51e2606 100644 --- a/tests/run_config.yaml +++ b/tests/run_config.yaml @@ -11,6 +11,7 @@ isolated: # test cases that may require excessive anmount of GPU memory, which - test_segformer.py - test_segmentation_pipeline.py - test_movie_scene_segmentation.py + - test_image_inpainting.py envs: default: # default env, case not in other env will in default, pytorch. diff --git a/tests/trainers/audio/test_ans_trainer.py b/tests/trainers/audio/test_ans_trainer.py index c0860529..d897e6a9 100644 --- a/tests/trainers/audio/test_ans_trainer.py +++ b/tests/trainers/audio/test_ans_trainer.py @@ -17,7 +17,6 @@ SEGMENT_LENGTH_TEST = 640 class TestANSTrainer(unittest.TestCase): - REVISION = 'beta' def setUp(self): self.tmp_dir = tempfile.TemporaryDirectory().name @@ -25,7 +24,7 @@ class TestANSTrainer(unittest.TestCase): os.makedirs(self.tmp_dir) self.model_id = 'damo/speech_frcrn_ans_cirm_16k' - cfg = read_config(self.model_id, revision=self.REVISION) + cfg = read_config(self.model_id) cfg.train.max_epochs = 2 cfg.train.dataloader.batch_size_per_gpu = 1 self.cfg_file = os.path.join(self.tmp_dir, 'train_config.json') @@ -48,7 +47,6 @@ class TestANSTrainer(unittest.TestCase): def test_trainer(self): kwargs = dict( model=self.model_id, - model_revision=self.REVISION, train_dataset=self.dataset, eval_dataset=self.dataset, max_epochs=2, diff --git a/tests/trainers/audio/test_kws_farfield_trainer.py b/tests/trainers/audio/test_kws_farfield_trainer.py new file mode 100644 index 00000000..70b68a11 --- /dev/null +++ b/tests/trainers/audio/test_kws_farfield_trainer.py @@ -0,0 +1,83 @@ +import os +import shutil +import tempfile +import unittest + +from modelscope.metainfo import Trainers +from modelscope.trainers import build_trainer +from modelscope.utils.test_utils import test_level + +POS_FILE = 'data/test/audios/wake_word_with_label_xyxy.wav' +NEG_FILE = 'data/test/audios/speech_with_noise.wav' +NOISE_FILE = 'data/test/audios/speech_with_noise.wav' +INTERF_FILE = 'data/test/audios/speech_with_noise.wav' +REF_FILE = 'data/test/audios/farend_speech.wav' +NOISE_2CH_FILE = 'data/test/audios/noise_2ch.wav' + + +class TestKwsFarfieldTrainer(unittest.TestCase): + + def setUp(self): + self.tmp_dir = tempfile.TemporaryDirectory().name + print(f'tmp dir: {self.tmp_dir}') + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' + + train_pos_list = self.create_list('pos.list', POS_FILE) + train_neg_list = self.create_list('neg.list', NEG_FILE) + train_noise1_list = self.create_list('noise.list', NOISE_FILE) + train_noise2_list = self.create_list('noise_2ch.list', NOISE_2CH_FILE) + train_interf_list = self.create_list('interf.list', INTERF_FILE) + train_ref_list = self.create_list('ref.list', REF_FILE) + + base_dict = dict( + train_pos_list=train_pos_list, + 
train_neg_list=train_neg_list, + train_noise1_list=train_noise1_list) + fintune_dict = dict( + train_pos_list=train_pos_list, + train_neg_list=train_neg_list, + train_noise1_list=train_noise1_list, + train_noise2_type='1', + train_noise1_ratio='0.2', + train_noise2_list=train_noise2_list, + train_interf_list=train_interf_list, + train_ref_list=train_ref_list) + self.custom_conf = dict( + basetrain_easy=base_dict, + basetrain_normal=base_dict, + basetrain_hard=base_dict, + finetune_easy=fintune_dict, + finetune_normal=fintune_dict, + finetune_hard=fintune_dict) + + def create_list(self, list_name, audio_file): + pos_list_file = os.path.join(self.tmp_dir, list_name) + with open(pos_list_file, 'w') as f: + for i in range(10): + f.write(f'{os.path.join(os.getcwd(), audio_file)}\n') + train_pos_list = f'{pos_list_file}, 1.0' + return train_pos_list + + def tearDown(self) -> None: + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_normal(self): + kwargs = dict( + model=self.model_id, + work_dir=self.tmp_dir, + workers=2, + max_epochs=2, + train_iters_per_epoch=2, + val_iters_per_epoch=1, + custom_conf=self.custom_conf) + + trainer = build_trainer( + Trainers.speech_dfsmn_kws_char_farfield, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files, + f'work_dir:{self.tmp_dir}') diff --git a/tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py b/tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py new file mode 100644 index 00000000..4dffa998 --- /dev/null +++ b/tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode, LogKeys, Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + + +@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') +class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase): + model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment' + + def setUp(self): + self.logger = get_logger() + self.logger.info(('Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + + def _train(self, tmp_dir): + cfg_options = {'train.max_epochs': 2} + + trainer_name = Trainers.easycv + + train_dataset = MsDataset.load( + dataset_name='face_2d_keypoints_dataset', + namespace='modelscope', + split='train', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + eval_dataset = MsDataset.load( + dataset_name='face_2d_keypoints_dataset', + namespace='modelscope', + split='train', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + + kwargs = dict( + model=self.model_id, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=tmp_dir, + cfg_options=cfg_options) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_single_gpu(self): + temp_file_dir = tempfile.TemporaryDirectory() + tmp_dir = temp_file_dir.name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + self._train(tmp_dir) + + results_files = os.listdir(tmp_dir) + json_files = glob.glob(os.path.join(tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + + temp_file_dir.cleanup() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/easycv/test_easycv_trainer_hand_2d_keypoints.py b/tests/trainers/easycv/test_easycv_trainer_hand_2d_keypoints.py new file mode 100644 index 00000000..270ecbc4 --- /dev/null +++ b/tests/trainers/easycv/test_easycv_trainer_hand_2d_keypoints.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode, LogKeys, Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + + +@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') +class EasyCVTrainerTestHand2dKeypoints(unittest.TestCase): + model_id = 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody' + + def setUp(self): + self.logger = get_logger() + self.logger.info(('Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + def _train(self): + cfg_options = {'train.max_epochs': 20} + + trainer_name = Trainers.easycv + + train_dataset = MsDataset.load( + dataset_name='cv_hand_2d_keypoints_coco_wholebody', + namespace='chenhyer', + split='subtrain', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + eval_dataset = MsDataset.load( + dataset_name='cv_hand_2d_keypoints_coco_wholebody', + namespace='chenhyer', + split='subtrain', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + + kwargs = dict( + model=self.model_id, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=self.tmp_dir, + cfg_options=cfg_options) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_single_gpu(self): + self._train() + + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + self.assertIn(f'{LogKeys.EPOCH}_10.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_20.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/hooks/compression/__init__.py b/tests/trainers/hooks/compression/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/hooks/compression/test_sparsity_hook.py b/tests/trainers/hooks/compression/test_sparsity_hook.py new file mode 100644 index 00000000..4af4dcdb --- /dev/null +++ b/tests/trainers/hooks/compression/test_sparsity_hook.py @@ -0,0 +1,113 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import shutil +import tempfile +import unittest + +import json +import numpy as np +import torch +from torch import nn +from torch.optim import SGD +from torch.optim.lr_scheduler import MultiStepLR + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile, TrainerStages +from modelscope.utils.test_utils import create_dummy_test_dataset + +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10) + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 10) + self.bn = nn.BatchNorm1d(10) + + def forward(self, feat, labels): + x = self.linear(feat) + + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class SparsityHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_sparsity_hook(self): + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': + self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'hooks': [{ + 'type': 'SparsityHook', + 'pruning_method': 'pst', + 'config': { + 'weight_rank': 1, + 'mask_rank': 1, + 'final_sparsity': 0.9, + 'frequency': 1, + }, + }], + }, + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel() + optimizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = MultiStepLR(optimizer, milestones=[2, 4]) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + train_dataset=dummy_dataset, + optimizers=(optimizer, lr_scheduler), + max_epochs=5, + device='cpu', + ) + + trainer = build_trainer(trainer_name, kwargs) + train_dataloader = trainer._build_dataloader_with_dataset( + trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) + trainer.register_optimizers_hook() + trainer.register_hook_from_cfg(trainer.cfg.train.hooks) + trainer.train_dataloader = train_dataloader + trainer.data_loader = train_dataloader + trainer.invoke_hook(TrainerStages.before_run) + for i in range(trainer._epoch, trainer._max_epochs): + trainer.invoke_hook(TrainerStages.before_train_epoch) + for _, data_batch in enumerate(train_dataloader): + trainer.invoke_hook(TrainerStages.before_train_iter) + trainer.train_step(trainer.model, data_batch) + trainer.invoke_hook(TrainerStages.after_train_iter) + trainer.invoke_hook(TrainerStages.after_train_epoch) + trainer.invoke_hook(TrainerStages.after_run) + + self.assertEqual( + torch.mean(1.0 * (trainer.model.linear.weight == 0)), 0.9) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_card_detection_scrfd_trainer.py b/tests/trainers/test_card_detection_scrfd_trainer.py new file mode 100644 index 00000000..af87000b --- /dev/null +++ b/tests/trainers/test_card_detection_scrfd_trainer.py @@ -0,0 +1,151 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import DistributedTestCase, test_level + + +def _setup(): + model_id = 'damo/cv_resnet_carddetection_scrfd34gkps' + # mini dataset only for unit test, remove '_mini' for full dataset. + ms_ds_syncards = MsDataset.load( + 'SyntheticCards_mini', namespace='shaoxuan') + + data_path = ms_ds_syncards.config_kwargs['split_config'] + train_dir = data_path['train'] + val_dir = data_path['validation'] + train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/' + val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' + max_epochs = 1 # run epochs in unit test + + cache_path = snapshot_download(model_id) + + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + return train_root, val_root, max_epochs, cache_path, tmp_dir + + +def train_func(**kwargs): + trainer = build_trainer( + name=Trainers.card_detection_scrfd, default_args=kwargs) + trainer.train() + + +class TestCardDetectionScrfdTrainerSingleGPU(unittest.TestCase): + + def setUp(self): + print(('SingleGPU Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def _cfg_modify_fn(self, cfg): + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 # batch size + return cfg + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_from_scratch(self): + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.card_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.card_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1, 'distributed unittest') +class TestCardDetectionScrfdTrainerMultiGpus(DistributedTestCase): + + def setUp(self): + print(('MultiGPUs 
Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py') + cfg = Config.from_file(cfg_file_path) + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 + cfg.dump(cfg_file_path) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_multi_gpus_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + launcher='pytorch') + self.start(train_func, num_gpus=2, **kwargs) + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_face_detection_scrfd_trainer.py b/tests/trainers/test_face_detection_scrfd_trainer.py new file mode 100644 index 00000000..97b0eca7 --- /dev/null +++ b/tests/trainers/test_face_detection_scrfd_trainer.py @@ -0,0 +1,150 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import DistributedTestCase, test_level + + +def _setup(): + model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' + # mini dataset only for unit test, remove '_mini' for full dataset. 
+ ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan') + + data_path = ms_ds_widerface.config_kwargs['split_config'] + train_dir = data_path['train'] + val_dir = data_path['validation'] + train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/' + val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' + max_epochs = 1 # run epochs in unit test + + cache_path = snapshot_download(model_id) + + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + return train_root, val_root, max_epochs, cache_path, tmp_dir + + +def train_func(**kwargs): + trainer = build_trainer( + name=Trainers.face_detection_scrfd, default_args=kwargs) + trainer.train() + + +class TestFaceDetectionScrfdTrainerSingleGPU(unittest.TestCase): + + def setUp(self): + print(('SingleGPU Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def _cfg_modify_fn(self, cfg): + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 # batch size + return cfg + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_from_scratch(self): + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.face_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.face_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1, 'distributed unittest') +class TestFaceDetectionScrfdTrainerMultiGpus(DistributedTestCase): + + def setUp(self): + print(('MultiGPUs Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py') + cfg = Config.from_file(cfg_file_path) + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 + cfg.dump(cfg_file_path) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + 
def test_multi_gpus_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + launcher='pytorch') + self.start(train_func, num_gpus=2, **kwargs) + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py index f2adfa22..ae780793 100644 --- a/tests/trainers/test_finetune_sequence_classification.py +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -8,7 +8,7 @@ from modelscope.metainfo import Preprocessors, Trainers from modelscope.models import Model from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline -from modelscope.trainers import build_trainer +from modelscope.trainers import NlpTrainerArguments, build_trainer from modelscope.trainers.hooks import Hook from modelscope.trainers.nlp_trainer import (EpochBasedTrainer, NlpEpochBasedTrainer) @@ -16,7 +16,8 @@ from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ calculate_fisher from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.data_utils import to_device -from modelscope.utils.regress_test_utils import MsRegressTool +from modelscope.utils.regress_test_utils import (MsRegressTool, + compare_arguments_nested) from modelscope.utils.test_utils import test_level @@ -37,10 +38,84 @@ class TestFinetuneSequenceClassification(unittest.TestCase): shutil.rmtree(self.tmp_dir) super().tearDown() - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_cfg_class(self): + dataset = MsDataset.load('clue', subset_name='tnews') + train_dataset = dataset['train'] + validation_dataset = dataset['validation'] + cfg_modify_fn = NlpTrainerArguments( + task=Tasks.text_classification, + preprocessor_type=Preprocessors.sen_cls_tokenizer, + train_first_sequence='sentence', + train_label='label', + labels=[ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', + '12', '13', '14' + ], + max_epochs=5, + optimizer_args={ + 'lr': 3e-5, + }, + lr_scheduler_args={ + 'total_iters': int(len(train_dataset) / 32) * 5, + }, + checkpoint_saving_type='BestCkptSaverHook', + metric_key='accuracy', + train_batch_size_per_gpu=32, + checkpoint_interval=1, + train_workers_per_gpu=0, + checkpoint_by_epoch=False, + evaluation_interval=1, + evaluation_by_epoch=False, + eval_workers_per_gpu=0, + metrics=['seq-cls-metric'], + ) + + kwargs = dict( + model='damo/nlp_structbert_backbone_base_std', + train_dataset=train_dataset, + eval_dataset=validation_dataset, + work_dir=self.tmp_dir, + seed=42, + cfg_modify_fn=cfg_modify_fn) + + os.environ['LOCAL_RANK'] = '0' + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + trainer.train() + + @unittest.skip( + 'Skip testing trainer repeatable, because it\'s unstable in daily UT') def test_trainer_repeatable(self): import torch # noqa + def 
compare_fn(value1, value2, key, type): + # Ignore the differences between optimizers of two torch versions + if type != 'optimizer': + return None + + match = (value1['type'] == value2['type']) + shared_defaults = set(value1['defaults'].keys()).intersection( + set(value2['defaults'].keys())) + match = all([ + compare_arguments_nested(f'Optimizer defaults {key} not match', + value1['defaults'][key], + value2['defaults'][key]) + for key in shared_defaults + ]) and match + match = (len(value1['state_dict']['param_groups']) == len( + value2['state_dict']['param_groups'])) and match + for group1, group2 in zip(value1['state_dict']['param_groups'], + value2['state_dict']['param_groups']): + shared_keys = set(group1.keys()).intersection( + set(group2.keys())) + match = all([ + compare_arguments_nested( + f'Optimizer param_groups {key} not match', group1[key], + group2[key]) for key in shared_keys + ]) and match + return match + def cfg_modify_fn(cfg): cfg.task = 'nli' cfg['preprocessor'] = {'type': 'nli-tokenizer'} @@ -98,7 +173,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): name=Trainers.nlp_base_trainer, default_args=kwargs) with self.regress_tool.monitor_ms_train( - trainer, 'sbert-base-tnews', level='strict'): + trainer, 'sbert-base-tnews', level='strict', + compare_fn=compare_fn): trainer.train() def finetune(self, @@ -300,7 +376,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): 2, 'dataloader': { 'batch_size_per_gpu': 16, - 'workers_per_gpu': 1 + 'workers_per_gpu': 0 }, 'optimizer': { 'type': 'AdamW', @@ -321,7 +397,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase): 'hooks': [{ 'type': 'CheckpointHook', 'interval': 1, - 'save_dir': '/root' }, { 'type': 'TextLoggerHook', 'interval': 1 @@ -336,7 +411,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): cfg['evaluation'] = { 'dataloader': { 'batch_size_per_gpu': 128, - 'workers_per_gpu': 1, + 'workers_per_gpu': 0, 'shuffle': False } } diff --git a/tests/trainers/test_finetune_passage_ranking.py b/tests/trainers/test_finetune_text_ranking.py similarity index 56% rename from tests/trainers/test_finetune_passage_ranking.py rename to tests/trainers/test_finetune_text_ranking.py index f833f981..6e97310d 100644 --- a/tests/trainers/test_finetune_passage_ranking.py +++ b/tests/trainers/test_finetune_text_ranking.py @@ -14,6 +14,7 @@ from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline from modelscope.trainers import build_trainer from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level class TestFinetuneSequenceClassification(unittest.TestCase): @@ -41,7 +42,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): model_id, train_dataset, eval_dataset, - name=Trainers.nlp_passage_ranking_trainer, + name=Trainers.nlp_text_ranking_trainer, cfg_modify_fn=None, **kwargs): kwargs = dict( @@ -58,11 +59,13 @@ class TestFinetuneSequenceClassification(unittest.TestCase): results_files = os.listdir(self.tmp_dir) self.assertIn(f'{trainer.timestamp}.log.json', results_files) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_finetune_msmarco(self): def cfg_modify_fn(cfg): - cfg.task = 'passage-ranking' - cfg['preprocessor'] = {'type': 'passage-ranking'} + neg_sample = 4 + cfg.task = 'text-ranking' + cfg['preprocessor'] = {'type': 'text-ranking'} cfg.train.optimizer.lr = 2e-5 cfg['dataset'] = { 'train': { @@ -70,19 +73,19 @@ class 
TestFinetuneSequenceClassification(unittest.TestCase): 'query_sequence': 'query', 'pos_sequence': 'positive_passages', 'neg_sequence': 'negative_passages', - 'passage_text_fileds': ['title', 'text'], - 'qid_field': 'query_id' + 'text_fileds': ['title', 'text'], + 'qid_field': 'query_id', + 'neg_sample': neg_sample }, 'val': { 'type': 'bert', 'query_sequence': 'query', 'pos_sequence': 'positive_passages', 'neg_sequence': 'negative_passages', - 'passage_text_fileds': ['title', 'text'], + 'text_fileds': ['title', 'text'], 'qid_field': 'query_id' }, } - cfg['train']['neg_samples'] = 4 cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 cfg.train.max_epochs = 1 cfg.train.train_batch_size = 4 @@ -94,6 +97,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): 'by_epoch': False } } + cfg.model['neg_sample'] = 4 cfg.train.hooks = [{ 'type': 'CheckpointHook', 'interval': 1 @@ -105,27 +109,90 @@ class TestFinetuneSequenceClassification(unittest.TestCase): }, { 'type': 'EvaluationHook', 'by_epoch': False, - 'interval': 3000 + 'interval': 15 }] return cfg # load dataset ds = MsDataset.load('passage-ranking-demo', 'zyznull') train_ds = ds['train'].to_hf_dataset() - dev_ds = ds['train'].to_hf_dataset() + dev_ds = ds['dev'].to_hf_dataset() + model_id = 'damo/nlp_corom_passage-ranking_english-base' self.finetune( - model_id='damo/nlp_corom_passage-ranking_english-base', + model_id=model_id, train_dataset=train_ds, eval_dataset=dev_ds, cfg_modify_fn=cfg_modify_fn) output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) - self.pipeline_passage_ranking(output_dir) + self.pipeline_text_ranking(output_dir) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_finetune_dureader(self): + + def cfg_modify_fn(cfg): + cfg.task = 'text-ranking' + cfg['preprocessor'] = {'type': 'text-ranking'} + cfg.train.optimizer.lr = 2e-5 + cfg['dataset'] = { + 'train': { + 'type': 'bert', + 'query_sequence': 'query', + 'pos_sequence': 'positive_passages', + 'neg_sequence': 'negative_passages', + 'text_fileds': ['text'], + 'qid_field': 'query_id' + }, + 'val': { + 'type': 'bert', + 'query_sequence': 'query', + 'pos_sequence': 'positive_passages', + 'neg_sequence': 'negative_passages', + 'text_fileds': ['text'], + 'qid_field': 'query_id' + }, + } + cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 + cfg.train.max_epochs = 1 + cfg.train.train_batch_size = 4 + cfg.train.lr_scheduler = { + 'type': 'LinearLR', + 'start_factor': 1.0, + 'end_factor': 0.0, + 'options': { + 'by_epoch': False + } + } + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 5000 + }] + return cfg + + # load dataset + ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull') + train_ds = ds['train'].to_hf_dataset().shard(1000, index=0) + dev_ds = ds['dev'].to_hf_dataset() + model_id = 'damo/nlp_rom_passage-ranking_chinese-base' + self.finetune( + model_id=model_id, + train_dataset=train_ds, + eval_dataset=dev_ds, + cfg_modify_fn=cfg_modify_fn) - def pipeline_passage_ranking(self, model_dir): + def pipeline_text_ranking(self, model_dir): model = Model.from_pretrained(model_dir) - pipeline_ins = pipeline(task=Tasks.passage_ranking, model=model) + pipeline_ins = pipeline(task=Tasks.text_ranking, model=model) print(pipeline_ins(input=self.inputs)) diff --git a/tests/trainers/test_image_denoise_trainer.py 
b/tests/trainers/test_image_denoise_trainer.py index 261ee4ed..b742dcae 100644 --- a/tests/trainers/test_image_denoise_trainer.py +++ b/tests/trainers/test_image_denoise_trainer.py @@ -6,10 +6,12 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.cv.image_denoise import NAFNetForImageDenoise -from modelscope.msdatasets.image_denoise_data import PairedImageDataset +from modelscope.msdatasets import MsDataset +from modelscope.msdatasets.task_datasets.sidd_image_denoising import \ + SiddImageDenoisingDataset from modelscope.trainers import build_trainer from modelscope.utils.config import Config -from modelscope.utils.constant import ModelFile +from modelscope.utils.constant import DownloadMode, ModelFile from modelscope.utils.logger import get_logger from modelscope.utils.test_utils import test_level @@ -28,10 +30,22 @@ class ImageDenoiseTrainerTest(unittest.TestCase): self.cache_path = snapshot_download(self.model_id) self.config = Config.from_file( os.path.join(self.cache_path, ModelFile.CONFIGURATION)) - self.dataset_train = PairedImageDataset( - self.config.dataset, self.cache_path, is_train=True) - self.dataset_val = PairedImageDataset( - self.config.dataset, self.cache_path, is_train=False) + dataset_train = MsDataset.load( + 'SIDD', + namespace='huizheng', + subset_name='default', + split='test', + download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds + dataset_val = MsDataset.load( + 'SIDD', + namespace='huizheng', + subset_name='default', + split='test', + download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds + self.dataset_train = SiddImageDenoisingDataset( + dataset_train, self.config.dataset, is_train=True) + self.dataset_val = SiddImageDenoisingDataset( + dataset_val, self.config.dataset, is_train=False) def tearDown(self): shutil.rmtree(self.tmp_dir, ignore_errors=True) diff --git a/tests/trainers/test_image_inpainting_trainer.py b/tests/trainers/test_image_inpainting_trainer.py new file mode 100644 index 00000000..807fe64f --- /dev/null +++ b/tests/trainers/test_image_inpainting_trainer.py @@ -0,0 +1,84 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import shutil +import tempfile +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.cv.image_inpainting import FFTInpainting +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class ImageInpaintingTrainerTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/cv_fft_inpainting_lama' + self.cache_path = snapshot_download(self.model_id) + cfg = Config.from_file( + os.path.join(self.cache_path, ModelFile.CONFIGURATION)) + + train_data_cfg = ConfigDict( + name='PlacesToydataset', + split='train', + mask_gen_kwargs=cfg.dataset.mask_gen_kwargs, + out_size=cfg.dataset.train_out_size, + test_mode=False) + + test_data_cfg = ConfigDict( + name='PlacesToydataset', + split='test', + mask_gen_kwargs=cfg.dataset.mask_gen_kwargs, + out_size=cfg.dataset.val_out_size, + test_mode=True) + + self.train_dataset = MsDataset.load( + dataset_name=train_data_cfg.name, + split=train_data_cfg.split, + mask_gen_kwargs=train_data_cfg.mask_gen_kwargs, + out_size=train_data_cfg.out_size, + test_mode=train_data_cfg.test_mode) + assert next( + iter(self.train_dataset.config_kwargs['split_config'].values())) + + self.test_dataset = MsDataset.load( + dataset_name=test_data_cfg.name, + split=test_data_cfg.split, + mask_gen_kwargs=test_data_cfg.mask_gen_kwargs, + out_size=test_data_cfg.out_size, + test_mode=test_data_cfg.test_mode) + assert next( + iter(self.test_dataset.config_kwargs['split_config'].values())) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset) + + trainer = build_trainer( + name=Trainers.image_inpainting, default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_portrait_enhancement_trainer.py b/tests/trainers/test_image_portrait_enhancement_trainer.py index 049adf7e..123e0098 100644 --- a/tests/trainers/test_image_portrait_enhancement_trainer.py +++ b/tests/trainers/test_image_portrait_enhancement_trainer.py @@ -14,52 +14,14 @@ from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Trainers from modelscope.models.cv.image_portrait_enhancement import \ ImagePortraitEnhancement +from modelscope.msdatasets import MsDataset +from modelscope.msdatasets.task_datasets.image_portrait_enhancement import \ + ImagePortraitEnhancementDataset from modelscope.trainers import build_trainer -from modelscope.utils.constant import ModelFile +from modelscope.utils.constant import DownloadMode, ModelFile from modelscope.utils.test_utils import test_level -class PairedImageDataset(data.Dataset): - - def __init__(self, root, size=512): - super(PairedImageDataset, 
self).__init__() - self.size = size - gt_dir = osp.join(root, 'gt') - lq_dir = osp.join(root, 'lq') - self.gt_filelist = os.listdir(gt_dir) - self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4])) - self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist] - self.lq_filelist = os.listdir(lq_dir) - self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4])) - self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist] - - def _img_to_tensor(self, img): - img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type( - torch.float32) / 255. - return (img - 0.5) / 0.5 - - def __getitem__(self, index): - lq = cv2.imread(self.lq_filelist[index]) - gt = cv2.imread(self.gt_filelist[index]) - lq = cv2.resize( - lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC) - gt = cv2.resize( - gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC) - - return \ - {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} - - def __len__(self): - return len(self.gt_filelist) - - def to_torch_dataset(self, - columns: Union[str, List[str]] = None, - preprocessors: Union[Callable, List[Callable]] = None, - **format_kwargs): - # self.preprocessor = preprocessors - return self - - class TestImagePortraitEnhancementTrainer(unittest.TestCase): def setUp(self): @@ -70,8 +32,23 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): self.model_id = 'damo/cv_gpen_image-portrait-enhancement' - self.dataset = PairedImageDataset( - './data/test/images/face_enhancement/') + dataset_train = MsDataset.load( + 'image-portrait-enhancement-dataset', + namespace='modelscope', + subset_name='default', + split='test', + download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds + dataset_val = MsDataset.load( + 'image-portrait-enhancement-dataset', + namespace='modelscope', + subset_name='default', + split='test', + download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds + + self.dataset_train = ImagePortraitEnhancementDataset( + dataset_train, is_train=True) + self.dataset_val = ImagePortraitEnhancementDataset( + dataset_val, is_train=False) def tearDown(self): shutil.rmtree(self.tmp_dir, ignore_errors=True) @@ -81,8 +58,8 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): def test_trainer(self): kwargs = dict( model=self.model_id, - train_dataset=self.dataset, - eval_dataset=self.dataset, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, device='gpu', work_dir=self.tmp_dir) @@ -101,8 +78,8 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): kwargs = dict( cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), model=model, - train_dataset=self.dataset, - eval_dataset=self.dataset, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, device='gpu', max_epochs=2, work_dir=self.tmp_dir) diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py new file mode 100644 index 00000000..06003625 --- /dev/null +++ b/tests/trainers/test_ofa_trainer.py @@ -0,0 +1,105 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import shutil +import unittest + +import json + +from modelscope.metainfo import Metrics, Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestOfaTrainer(unittest.TestCase): + + def setUp(self) -> None: + self.finetune_cfg = \ + {'framework': 'pytorch', + 'task': 'image-captioning', + 'model': {'type': 'ofa', + 'beam_search': {'beam_size': 5, + 'max_len_b': 16, + 'min_len': 1, + 'no_repeat_ngram_size': 0}, + 'seed': 7, + 'max_src_length': 256, + 'language': 'en', + 'gen_type': 'generation', + 'patch_image_size': 480, + 'max_image_size': 480, + 'imagenet_default_mean_and_std': False}, + 'pipeline': {'type': 'image-captioning'}, + 'dataset': {'column_map': {'text': 'caption'}}, + 'train': {'work_dir': 'work/ckpts/caption', + # 'launcher': 'pytorch', + 'max_epochs': 1, + 'use_fp16': True, + 'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0}, + 'lr_scheduler': {'name': 'polynomial_decay', + 'warmup_proportion': 0.01, + 'lr_end': 1e-07}, + 'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False}, + 'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01}, + 'optimizer_hook': {'type': 'TorchAMPOptimizerHook', + 'cumulative_iters': 1, + 'grad_clip': {'max_norm': 1.0, 'norm_type': 2}, + 'loss_keys': 'loss'}, + 'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion', + 'constraint_range': None, + 'drop_worst_after': 0, + 'drop_worst_ratio': 0.0, + 'ignore_eos': False, + 'ignore_prefix_size': 0, + 'label_smoothing': 0.1, + 'reg_alpha': 1.0, + 'report_accuracy': False, + 'sample_patch_num': 196, + 'sentence_avg': False, + 'use_rdrop': False}, + 'hooks': [{'type': 'BestCkptSaverHook', + 'metric_key': 'bleu-4', + 'interval': 100}, + {'type': 'TextLoggerHook', 'interval': 1}, + {'type': 'IterTimerHook'}, + {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]}, + 'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0}, + 'metrics': [{'type': 'bleu', + 'eval_tokenized_bleu': False, + 'ref_name': 'labels', + 'hyp_name': 'caption'}]}, + 'preprocessor': []} + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_std(self): + WORKSPACE = './workspace/ckpts/caption' + os.makedirs(WORKSPACE, exist_ok=True) + config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) + with open(config_file, 'w') as writer: + json.dump(self.finetune_cfg, writer) + + pretrained_model = 'damo/ofa_image-caption_coco_distilled_en' + args = dict( + model=pretrained_model, + work_dir=WORKSPACE, + train_dataset=MsDataset.load( + 'coco_2014_caption', + namespace='modelscope', + split='train[:20]'), + eval_dataset=MsDataset.load( + 'coco_2014_caption', + namespace='modelscope', + split='validation[:10]'), + metrics=[Metrics.BLEU], + cfg_file=config_file) + trainer = build_trainer(name=Trainers.ofa, default_args=args) + trainer.train() + + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, + os.listdir(os.path.join(WORKSPACE, 'output'))) + shutil.rmtree(WORKSPACE) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py index 6030ada9..8aaa42a3 100644 --- a/tests/trainers/test_trainer_with_nlp.py +++ b/tests/trainers/test_trainer_with_nlp.py @@ -7,8 +7,7 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import 
Metrics from modelscope.models.base import Model -from modelscope.models.nlp.sequence_classification import \ - SbertForSequenceClassification +from modelscope.models.nlp import SbertForSequenceClassification from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline from modelscope.trainers import EpochBasedTrainer, build_trainer @@ -29,7 +28,8 @@ class TestTrainerWithNlp(unittest.TestCase): os.makedirs(self.tmp_dir) self.dataset = MsDataset.load( - 'afqmc_small', namespace='userxiaoming', split='train') + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(2)) def tearDown(self): shutil.rmtree(self.tmp_dir) @@ -37,13 +37,12 @@ class TestTrainerWithNlp(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer(self): - model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' kwargs = dict( model=model_id, train_dataset=self.dataset, eval_dataset=self.dataset, - work_dir=self.tmp_dir, - model_revision='beta') + work_dir=self.tmp_dir) trainer = build_trainer(default_args=kwargs) trainer.train() @@ -73,15 +72,14 @@ class TestTrainerWithNlp(unittest.TestCase): output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) pipeline_sentence_similarity(output_dir) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 3, 'skip test in current test level') def test_trainer_with_backbone_head(self): model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' kwargs = dict( model=model_id, train_dataset=self.dataset, eval_dataset=self.dataset, - work_dir=self.tmp_dir, - model_revision='beta') + work_dir=self.tmp_dir) trainer = build_trainer(default_args=kwargs) trainer.train() @@ -97,8 +95,10 @@ class TestTrainerWithNlp(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_trainer_with_user_defined_config(self): model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' - cfg = read_config(model_id, revision='beta') + cfg = read_config(model_id) cfg.train.max_epochs = 20 + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} cfg.train.work_dir = self.tmp_dir cfg_file = os.path.join(self.tmp_dir, 'config.json') cfg.dump(cfg_file) @@ -106,8 +106,7 @@ class TestTrainerWithNlp(unittest.TestCase): model=model_id, train_dataset=self.dataset, eval_dataset=self.dataset, - cfg_file=cfg_file, - model_revision='beta') + cfg_file=cfg_file) trainer = build_trainer(default_args=kwargs) trainer.train() @@ -120,22 +119,24 @@ class TestTrainerWithNlp(unittest.TestCase): checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) self.assertTrue(Metrics.accuracy in eval_results) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_trainer_with_configured_datasets(self): model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' cfg: Config = read_config(model_id) cfg.train.max_epochs = 20 + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} cfg.train.work_dir = self.tmp_dir cfg.dataset = { 'train': { - 'name': 'afqmc_small', + 'name': 'clue', + 'subset_name': 'afqmc', 'split': 'train', - 'namespace': 'userxiaoming' }, 'val': { - 'name': 'afqmc_small', + 'name': 'clue', + 
'subset_name': 'afqmc', 'split': 'train', - 'namespace': 'userxiaoming' }, } cfg_file = os.path.join(self.tmp_dir, 'config.json') @@ -159,11 +160,30 @@ class TestTrainerWithNlp(unittest.TestCase): model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' cfg: Config = read_config(model_id) cfg.train.max_epochs = 3 + cfg.preprocessor.first_sequence = 'sentence1' + cfg.preprocessor.second_sequence = 'sentence2' + cfg.preprocessor.label = 'label' + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} + cfg.train.dataloader.batch_size_per_gpu = 2 + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 3, + 'by_epoch': False, + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'interval': 1 + }] cfg.train.work_dir = self.tmp_dir cfg_file = os.path.join(self.tmp_dir, 'config.json') cfg.dump(cfg_file) dataset = MsDataset.load('clue', subset_name='afqmc', split='train') - dataset = dataset.to_hf_dataset().select(range(128)) + dataset = dataset.to_hf_dataset().select(range(4)) kwargs = dict( model=model_id, train_dataset=dataset, @@ -180,7 +200,7 @@ class TestTrainerWithNlp(unittest.TestCase): PRIORITY = Priority.VERY_LOW def after_iter(self, trainer): - if trainer.iter == 12: + if trainer.iter == 3: raise MsRegressTool.EarlyStopError('Test finished.') if 'EarlyStopHook' not in [ @@ -197,12 +217,11 @@ class TestTrainerWithNlp(unittest.TestCase): results_files = os.listdir(self.tmp_dir) self.assertIn(f'{trainer.timestamp}.log.json', results_files) - trainer = build_trainer(default_args=kwargs) regress_tool = MsRegressTool(baseline=False) with regress_tool.monitor_ms_train( trainer, 'trainer_continue_train', level='strict'): - trainer.train(os.path.join(self.tmp_dir, 'iter_12.pth')) + trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth')) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_trainer_with_model_and_args(self): @@ -211,7 +230,7 @@ class TestTrainerWithNlp(unittest.TestCase): os.makedirs(tmp_dir) model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' - cache_path = snapshot_download(model_id, revision='beta') + cache_path = snapshot_download(model_id) model = SbertForSequenceClassification.from_pretrained(cache_path) kwargs = dict( cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), diff --git a/tests/utils/test_ast.py b/tests/utils/test_ast.py index 9a8ab828..c0624679 100644 --- a/tests/utils/test_ast.py +++ b/tests/utils/test_ast.py @@ -41,7 +41,7 @@ class AstScaningTest(unittest.TestCase): self.assertIsInstance(from_imports, dict) self.assertIsInstance(decorators, list) self.assertListEqual(list(set(imports.keys()) - set(['torch'])), []) - self.assertEqual(len(from_imports.keys()), 7) + self.assertEqual(len(from_imports.keys()), 9) self.assertTrue(from_imports['modelscope.metainfo'] is not None) self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines']) self.assertEqual(decorators, diff --git a/tests/utils/test_compatibility.py b/tests/utils/test_compatibility.py new file mode 100644 index 00000000..f5222261 --- /dev/null +++ b/tests/utils/test_compatibility.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import unittest + + +class CompatibilityTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def tearDown(self): + super().tearDown() + + def test_xtcocotools(self): + from xtcocotools.coco import COCO + + +if __name__ == '__main__': + unittest.main()