diff --git a/.Jenkinsfile b/.Jenkinsfile
index da1157e9..f2740e5b 100644
--- a/.Jenkinsfile
+++ b/.Jenkinsfile
@@ -1,36 +1,96 @@
pipeline {
- agent {
- docker {
- image 'ubuntu_tester'
- args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci'
- }
+ agent any
+ options {
+ timeout(time:30, unit: 'MINUTES')
}
environment {
- TRAVIS = 1
PJ_NAME = 'fastNLP'
- POST_URL = 'https://open.feishu.cn/open-apis/bot/v2/hook/14719364-818d-4f88-9057-7c9f0eaaf6ae'
+ POST_URL = 'https://open.feishu.cn/open-apis/bot/v2/hook/2f7122e3-3459-43d2-a9e4-ddd77bfc4282'
}
stages {
- stage('Package Installation') {
- steps {
- sh 'python setup.py install'
- }
- }
stage('Parallel Stages') {
parallel {
- stage('Document Building') {
+ stage('Test Other'){
+ agent {
+ docker {
+ image 'fnlp:other'
+ args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci'
+ }
+ }
steps {
- sh 'cd docs && make prod'
- sh 'rm -rf /docs/${PJ_NAME}'
- sh 'mv docs/build/html /docs/${PJ_NAME}'
+ sh 'pytest ./tests --durations=0 --html=other.html --self-contained-html -m "not (torch or paddle or paddledist or jittor or oneflow or deepspeed or oneflowdist or torchpaddle or torchjittor or torchoneflow)"'
+ }
+ post {
+ always {
+ sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv other.html ${html_path}'
+ }
}
}
- stage('Package Testing') {
+ stage('Test Torch-1.11') {
+ agent {
+ docker {
+ image 'fnlp:torch-1.11'
+ args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci --gpus all --shm-size 1G'
+ }
+ }
steps {
- sh 'pip install fitlog'
- sh 'pytest ./tests --html=test_results.html --self-contained-html'
+ sh 'pytest ./tests/ --durations=0 --html=torch-1.11.html --self-contained-html -m torch'
+ }
+ post {
+ always {
+ sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv torch-1.11.html ${html_path}'
+ }
+ }
+ }
+ stage('Test Torch-1.6') {
+ agent {
+ docker {
+ image 'fnlp:torch-1.6'
+ args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci --gpus all --shm-size 1G'
+ }
+ }
+ steps {
+ sh 'pytest ./tests/ --durations=0 --html=torch-1.6.html --self-contained-html -m torch'
+ }
+ post {
+ always {
+ sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv torch-1.6.html ${html_path}'
+ }
+ }
+ }
+ stage('Test Paddle') {
+ agent {
+ docker {
+ image 'fnlp:paddle'
+ args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci --gpus all --shm-size 1G'
+ }
+ }
+ steps {
+ sh 'pytest ./tests --durations=0 --html=paddle.html --self-contained-html -m paddle --co'
+ sh 'FASTNLP_BACKEND=paddle pytest ./tests --durations=0 --html=paddle_with_backend.html --self-contained-html -m paddle --co'
+ sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/drivers/paddle_driver/test_dist_utils.py --durations=0 --html=paddle_dist_utils.html --self-contained-html --co'
+ sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/drivers/paddle_driver/test_fleet.py --durations=0 --html=paddle_fleet.html --self-contained-html --co'
+ sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/controllers/test_trainer_paddle.py --durations=0 --html=paddle_trainer.html --self-contained-html --co'
+ }
+ post {
+ always {
+ sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv paddle*.html ${html_path}'
+ }
}
}
+ // stage('Test Jittor') {
+ // agent {
+ // docker {
+ // image 'fnlp:jittor'
+ // args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci --gpus all --shm-size 1G'
+ // }
+ // }
+ // steps {
+ // // sh 'pip install fitlog'
+ // // sh 'pytest ./tests --html=test_results.html --self-contained-html'
+ // sh 'pytest ./tests --durations=0 --html=jittor.html --self-contained-html -m jittor --co'
+ // }
+ // }
}
}
}
@@ -40,8 +100,7 @@ pipeline {
}
success {
sh 'post 0'
- sh 'post github'
+ // sh 'post github'
}
}
-
}
\ No newline at end of file
diff --git a/.coverage b/.coverage
deleted file mode 100644
index a6d89bc8..00000000
--- a/.coverage
+++ /dev/null
@@ -1 +0,0 @@
-!coverage.py: This is a private format, don't read it directly!{"lines":{"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/__init__.py":[12,14,15,18,19,20,22,23,24,26,27,29,30,31,32,33,34,35,37,38,39,41,42,43,45,46,47,48,50,51,52,53,55,56,57,58,59,60,62,64,66,68,69,70,71,72],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/embeddings/__init__.py":[6,9,10,11,12,13,14,15,16,17,18,21,22,23,24,25,26,27],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/embeddings/embedding.py":[128,129,130,131,4,133,7,8,11,12,13,140,15,141,142,18,146,148,143,144,145,155,157,39,41,169,43,45,174,47,48,177,178,49,51,181,182,52,55,185,186,179,60,61,63,193,68,199,72,201,73,75,76,205,82,85,86,87,89,90,91,93,104,111,119,120,121,122,123,124,125,126,127],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/embeddings/utils.py":[4,5,6,7,9,42,43,12,44,45,46,16,24,57,26,27,28,25,31],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/__init__.py":[13,15,17,19,20,21,22,24,26,27,28,30,32,33,35,36,37,38,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,56,57,58,59,60,61,63,64,65,67,68,69,70,72,73,74,75,78,79,80,83,84,85,86,87,88,89,90,91,92,93,94],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/_logger.py":[1,130,131,4,132,134,7,8,9,10,11,137,13,140,15,16,143,19,20,24,25,26,155,27,29,30,31,32,33,45,46,47,49,50,51,52,53,56,78,79,80,83,84,88,92,94,95,99,100,101,102,103,106,107,108,110,114,119,125,127],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/batch.py":[4,6,7,8,11,13,14,15,16,18,19,20,21,24,29,32,33,34,35,36,37,38,40,42,43,44,45,47,50,57,58,59,60,61,62,63,64,65,67,68,69,70,73,74,75,76,80,81,83,84,85,87,92,99,100,101,102,103,105,106,108,109,112,113,114,115,116,117,119,120,122,124,125,126,127,129,130,131,132,133,135,136,138,139,141,146,171,174,175,176,177,178,181,182,183,184,185,186,187,189,190,193,194,202,204,207,211,215,223,224,225,226,227,228,229,230,233],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/sampler.py":[3,5,6,7,8,134,135,11,140,13,137,16,149,150,151,24,153,26,155,156,158,160,34,162,163,164,166,165,40,167,42,170,43,46,52,54,55,58,186,187,188,190,191,192,
193,68,70,71,72,73,75,83,84,86,87,89,90,91,92,93,94,96,97,98,100,102,103,104,105,106,107,108,109,110,112,113,114,115,117,120],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/dataset.py":[515,516,518,532,543,550,552,554,560,562,570,578,585,586,587,589,590,592,606,607,608,609,610,611,617,619,631,632,633,634,635,640,641,643,660,676,688,694,696,702,704,722,723,725,726,727,728,729,734,737,738,859,740,742,751,752,753,754,755,756,757,758,862,760,761,762,763,764,765,766,767,768,770,771,772,774,791,792,793,794,795,796,285,287,290,291,803,293,294,806,296,297,298,299,300,301,302,303,811,305,809,824,314,316,317,318,319,320,321,834,835,836,837,838,322,323,324,325,326,327,328,334,329,332,337,849,338,339,340,342,335,344,857,858,347,348,861,345,350,346,865,864,863,866,860,351,868,354,353,356,871,867,875,869,870,360,363,364,877,365,367,369,883,884,886,376,377,378,379,380,381,382,383,384,385,386,387,388,894,895,896,897,402,409,410,412,413,415,420,421,422,423,425,426,427,431,432,434,872,441,443,445,447,451,452,453,454,459,873,474,807,486,487,488,490,491,493,499,500,502,503,505,506,507,509],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/field.py":[4,7,8,9,12,13,14,15,16,18,19,21,22,535,25,26,27,28,29,30,33,34,35,36,37,38,41,43,44,557,46,559,560,47,562,48,52,53,54,563,56,57,58,59,60,564,62,63,64,568,570,67,68,572,70,71,72,65,74,578,76,580,78,590,80,591,585,83,592,85,593,87,594,89,595,596,597,598,599,95,96,97,98,99,100,613,101,614,102,615,609,616,617,618,104,106,108,622,624,113,114,115,116,629,117,118,120,119,122,130,131,132,133,134,135,136,137,138,651,139,653,140,141,142,146,147,148,149,150,663,152,659,661,157,158,159,160,162,165,677,167,169,681,682,683,685,686,175,687,177,178,688,180,181,182,183,184,690,691,187,692,693,190,694,192,697,200,201,202,205,206,207,209,211,212,214,220,221,222,226,45,236,242,244,252,254,255,256,257,259,261,278,565,566,567,298,569,571,318,573,574,575,339,576,577,579,359,581,582,379,584,586,626,398,419,695,428,429,430,431,432,433,434,435,436,437,438,439,441,443,444,445
,446,447,448,450,451,452,453,454,455,456,458,459,460,465,482,484,485,487,490,491,610],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/utils.py":[3,5,517,6,7,518,10,11,12,13,14,521,16,17,18,19,20,524,22,23,527,528,26,27,535,29,536,540,541,542,543,35,544,545,547,546,40,548,551,552,553,554,46,550,49,563,564,53,565,567,568,569,59,60,574,62,63,64,67,522,592,599,609,615,530,118,119,120,121,122,124,125,126,127,641,129,130,132,131,134,647,648,649,650,651,652,135,139,140,656,142,144,147,659,145,146,662,663,664,151,666,667,152,669,670,153,672,673,674,163,676,165,678,679,168,681,682,643,685,644,645,192,709,217,218,219,220,222,736,738,227,739,740,226,229,232,233,230,148,231,745,234,235,149,236,237,238,239,240,244,245,241,242,243,246,247,248,249,250,251,252,253,254,255,256,259,260,263,154,271,273,274,156,277,157,280,158,642,288,289,159,291,292,293,294,295,296,297,298,161,301,316,333,334,335,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351,352,353,354,355,356,357,358,359,360,361,364,388,389,390,391,392,393,396,397,405,411,413,416,417,421,430,433,436,437,438,439,440,290,445,449,451,452,454,456,457,458,460,463,465,466,469,470,471,475,476,477,478,479,480,485,496,497,498,499,500,501,502,503,505,506,507,508,509,510,511],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/instance.py":[58,5,37,39,7,11,46,47,48,52,53,55,56,24,26,59,28,30],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/const.py":[4,7,11,29,30,31,32,33,34,35,36,37,39,42,43,45,51,56,61,64,65,67,70,71,73,76,77,79],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/callback.py":[1024,513,1026,1030,1031,1036,529,531,1043,1044,1060,1071,562,51,53,1077,55,56,57,58,59,60,61,62,63,64,65,66,576,68,69,1092,583,72,73,74,76,78,80,81,84,85,87,88,89,91,92,603,606,97,612,106,108,621,109,623,110,113,111,118,120,123,125,128,130,133,135,648,138,140,143,145,660,148,150,153,155,159,161,674,164,166,681,169,171,683,685,686,175,687,177,688,179,692,693,183,696,701,189,191,703,705,706,708,197,710,199,721,722,210,212,723,726,724,728,729,730,220,733,222,741,2
29,231,743,745,746,748,237,749,239,750,751,752,753,756,245,754,247,758,761,759,252,765,254,766,767,768,770,771,260,773,262,774,775,776,778,779,780,781,782,783,777,785,786,275,787,788,789,791,790,273,794,283,795,796,797,287,799,289,800,801,802,805,293,295,303,818,820,821,310,311,312,313,822,315,316,823,318,830,824,321,322,826,836,828,827,831,832,833,841,329,331,332,333,334,839,336,337,842,851,339,340,341,852,855,348,349,350,863,351,353,864,357,870,871,361,875,365,369,881,373,377,889,890,381,385,389,902,393,907,397,912,401,405,410,411,922,929,420,428,945,946,437,961,964,455,968,457,459,461,462,463,468,469,471,472,473,987,479,482,504,489,491,1003,492,493,494,496,1009,497,1014,1016,506,1020],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/tester.py":[34,35,37,38,40,41,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,62,66,92,94,95,97,100,102,103,104,105,106,107,109,110,111,118,119,121,127,131,132,134,138,139,141,148,149,150,151,152,153,154,155,156,158,159,160,162,164,165,166,167,170,171,173,174,176,177,178,181,182,183,184,185,187,188,189,190,191,192,194,195,196,197,199,206,207,209,211,213,214,215,217,223,224,225,226,227,228],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/metrics.py":[4,6,7,8,9,12,13,15,16,18,19,20,21,22,23,24,25,26,29,117,119,120,121,122,124,133,137,138,141,151,156,158,165,166,179,180,181,182,183,185,186,187,188,192,193,194,195,200,208,209,213,215,230,231,235,236,239,240,241,242,246,247,249,250,253,254,255,256,257,258,259,262,263,264,265,267,270,271,272,274,277,278,279,280,281,282,284,285,286,287,288,290,292,295,305,307,309,311,313,314,316,329,330,332,336,340,341,343,345,346,347,348,350,354,355,356,357,359,360,362,369,370,371,372,373,376,386,388,389,390,391,392,393,394,395,396,398,399,400,401,402,406,437,468,477,479,480,481,482,483,484,485,486,488,489,491,492,493,496,504,505,506,507,508,509,510,511,512,514,515,516,520,561,564,566,568,570,573,574,575,576,577,578,579,580,581,582,586,587,588,589,590,592,593,595,597,598,599,601,609,612,616,620,622,623,624,625,633,634,
635,636,637,638,640,641,643,644,646,647,648,649,651,652,653,655,657,658,659,660,661,662,663,664,665,666,667,668,669,670,671,672,673,674,675,676,677,678,679,681,686,687,688,689,690,691,692,694,695,696,697,699,700,702,704,712,713,714,716,719,726,727,728,729,730,732,733,734,736,738,742,743,747,750,759,760,761,762,763,766,776,777,778,779,780,781,784,799,802,804,806,808,810,811,813,814,816,818,819,820,821,823,825,827,836,837,838,839,841,842,845,846,850,851,852,853,855,856,857,858,859,862,863,864,865,867,868,870,873,875,876,878,879,880,881,883,884,885,887,888,891,893,895,897,900,901,903,905,907,909,910,911,912,914,916,918,919,920,921,923,929,932,933,935,936,937,939,940,942,944,945,946,947,949],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/vocabulary.py":[4,7,8,11,12,13,15,16,17,18,21,26,35,40,42,43,44,46,49,54,56,57,58,59,61,62,64,67,90,92,93,94,95,96,97,98,99,100,102,104,105,116,117,118,120,121,133,134,135,137,145,146,147,148,149,150,151,153,154,166,168,169,181,182,184,190,191,192,193,194,195,197,198,199,200,201,202,203,204,205,206,207,209,214,215,217,219,221,229,231,242,244,251,252,253,254,258,259,273,279,280,282,283,285,287,289,291,292,295,296,297,301,302,303,304,305,311,313,317,337,338,342,343,344,345,346,348,349,350,352,354,355,356,358,359,360,361,368,369,370,371,377,379,385,387,398,400,401,406,407,408,410,411,416,417,418,420,428,430,443,447,448,450,451,453,457,458,460,463,465,466],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/_parallel_utils.py":[1,97,3,5,7,8,9,10,11,76,104,14,105,107],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/losses.py":[4,6,8,9,11,12,13,14,17,18,20,21,23,24,25,26,27,28,29,30,31,34,37,39,40,41,43,52,55,62,63,76,77,78,79,80,82,83,84,85,89,90,91,92,102,110,112,113,114,115,119,120,122,123,125,126,127,128,129,130,131,134,135,136,137,139,141,142,143,145,148,149,150,151,152,153,155,156,157,158,160,162,163,165,168,188,190,192,193,194,195,198,201,222,224,225,226,227,228,229,230,232,233,234,235,236,239,240,241,242,243,245,246,249,259,261,262,263,264,265,267,2
68,271,280,282,283,284,285,286,288,289,292,303,305,306,307,308,309,310,312,313,316,323,325,326,327,329,331,332,333,334,335,336,337,338,339,340,341,343,345,347,353,356,357,358,359,360,361,366,374,377,386,387,395,410,432],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/optimizer.py":[4,6,7,8,9,135,138,12,13,14,15,18,151,24,26,27,156,29,30,32,35,41,43,47,48,51,54,61,68,70,71,72,73,75,76,78,80,83,90,92,93,95,96,98,99,101,103,106],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/trainer.py":[517,518,519,521,522,523,524,525,526,527,528,529,530,531,532,533,534,536,537,538,539,540,541,545,547,548,551,552,553,554,555,556,557,558,559,560,561,562,564,565,567,570,571,573,593,594,598,599,600,601,602,603,604,606,607,608,609,619,620,621,622,623,624,625,626,627,628,629,630,634,635,637,639,640,641,643,644,645,646,647,648,649,650,651,652,653,654,656,657,658,659,660,662,663,666,667,668,669,672,673,674,676,677,679,680,681,682,683,685,686,687,688,689,690,691,693,694,695,696,697,698,700,701,705,707,708,711,712,713,715,716,717,720,721,722,723,724,725,727,728,857,730,737,740,742,746,747,352,749,750,751,752,755,757,764,765,766,768,775,777,800,802,812,813,816,818,823,824,825,826,827,829,319,831,321,832,835,324,325,326,833,328,329,330,843,332,333,841,847,336,848,338,851,340,341,342,343,344,339,853,854,855,349,350,351,856,345,346,858,347,348,864,865,868,869,353,354,355,356,358,872,873,875,876,877,878,879,880,881,882,883,884,885,886,887,888,889,890,891,892,893,895,896,898,899,900,901,902,903,904,905,907,908,909,910,911,913,914,915,916,917,918,919,920,924,925,927,928,418,932,936,425,426,427,937,941,942,939,940,431,943,433,944,945,947,437,438,948,949,441,954,950,444,951,958,449,450,961,962,964,454,965,456,968,458,970,971,974,466,482,484,485,489,490,491,498,499,502,503,506,507,510,511],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/embeddings/static_embedding.py":[4,7,9,11,12,13,14,16,17,18,19,20,21,22,25,66,69,70,71,72,75,76,77,78,79,83,84,91,92,119,121,122,123,124,127,128,130,133,134,135,136,140,141,142,143,144,
146,147,148,150,151,153,154,155,156,158,164,165,166,167,168,169,171,179,181,182,186,188,202,204,205,207,209,210,226,227,229,230,231,232,233,237,238,239,240,241,242,243,244,245,246,247,248,249,250,252,254,257,258,259,260,261,262,269,270,271,272,275,277,279,283,284,285,286,287,288,290,292,299,300,301,302,303,304],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/__init__.py":[13,15,17,19,21,22,23,24,25,26,28,29,30,31,32,33,34,35,37,38,40,42,43,44,45,46,48,50,51,52,53,54,55,57,58,59,60,61,63,65,66,67,68,69,70,71,72,73,74,75,76,78,79,83,84,85,87,88],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/embed_loader.py":[4,6,7,10,11,12,14,16,17,20,22,23,24,25,34,39,41,44,45,46,63,64,66,67,68,69,70,71,72,73,75,76,77,78,80,81,82,83,84,86,88,90,91,92,93,100,101,102,103,104,105,106,107,108,109,111,112,114,116,117,118,133,134,135,136,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,157,166,168,169,170,171,173,174,175,176,177,178,180,181,182,183,184,185,187,188,190],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/data_bundle.py":[4,6,9,10,13,142,27,29,30,31,159,33,45,55,184,64,74,203,83,92,117],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/model_io.py":[32,3,5,6,9,42,12,17,19,53,22,55,62],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/loader/__init__.py":[44,47,49,50,51,52,53,54,56,57,58,59,60,61,62,63,65,66,68,70,71,72,73,74,76,77,78,79,80,81,82,83],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/loader/classification.py":[1,259,4,5,6,7,8,9,261,264,12,13,14,15,16,17,19,20,21,279,24,291,164,45,47,304,50,178,180,306,309,183,72,73,201,339,244,119,120],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/loader/loader.py":[65,66,1,4,33,70,7,67,9,10,11,12,68,78,15,19,21,22,24,63],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/file_utils.py":[4,7,8,9,10,11,14,15,16,17,18,19,21,22,23,25,28,29,30,32,33,35,36,38,40,41,43,44,45,46,50,51,52,53,54,58,60,61,62,63,64,65,66,67,68,69,71,73,74,76,77,78,79,83,84,85,86,87,88,89,90,91,92,93,94,96,97,98,99,102,103,104,107,108,109,110,114,159,186,202,228,252,273,293,306,418,427,43
4,443],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/utils.py":[33,34,35,4,36,7,10,11,12,14,17,81],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/loader/conll.py":[1,4,5,6,7,8,9,10,11,12,15,16,17,18,19,146,21,22,23,24,25,150,278,28,279,282,286,287,408,421,175,177,183,62,446,64,448,451,325,204,78,208,92,349,222,273,224,351,354,227,404,117,405,119,125],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/file_reader.py":[33,34,3,35,5,7,9,41,42,12,43,78,47,44,24,25,26,30],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/loader/csv.py":[32,1,34,33,4,35,36,7,8,9,10,37,13,24,26,27,28,29,30],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/loader/cws.py":[1,4,38,7,8,9,10,11,39,13,14,15,47,18,56],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/loader/json.py":[1,4,38,7,8,9,10,13,25,27],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/loader/matching.py":[1,129,4,5,6,7,8,11,12,13,15,16,17,18,19,20,273,277,23,159,35,37,40,170,298,300,303,184,186,189,318,66,216,98,228,109,241,243,246,120,122],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/pipe/__init__.py":[9,11,13,15,16,17,18,19,21,22,23,24,25,26,28,29,30,31,32,33,34,35,36,37,38,39,42,43,44,46,47,48],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/pipe/classification.py":[1,4,5,6,7,8,134,264,11,392,13,15,16,17,18,19,20,21,22,408,24,410,28,414,32,34,37,172,52,182,315,320,449,195,197,70,201,333,335,339,89,218,247,228,104,106,119,249,382],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/pipe/pipe.py":[1,4,7,10,13,14,23],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/pipe/utils.py":[1,66,153,4,5,6,39,9,137,11,12,15,87,121,91],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/pipe/conll.py":[1,4,5,6,7,8,9,12,13,14,15,16,17,18,19,20,141,272,23,286,288,34,36,293,43,306,308,182,313,192,328,330,79,208,210,215,225,98,227,100,233,113,114],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/pipe/matching.py":[128,1,129,259,4,5,6,7,8,9,10,11,12,13,14,15,135,140,18,19,20,21,22,146,147,25,152,260,134,169,42,171,44,177,50,265,266,191,64,141,271,272,247,248,122,123,253,254],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/io/pipe/cws.py":
[1,4,7,8,136,10,11,12,13,14,17,155,157,34,168,50,65,202,84,110,254],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/__init__.py":[18,22,23,25,27,29,31,33,34,35,37,38,39,40,42,44,45,46,47,49,52,53,54,55,56],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/decoder/__init__.py":[4,6,7,8,9,12,13,14,15],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/decoder/crf.py":[1,4,5,8,9,11,12,15,29,31,32,33,34,35,36,37,38,40,41,42,43,44,46,47,48,50,51,52,53,54,55,56,57,58,59,60,63,73,74,75,76,93,94,95,96,97,98,102,121,122,123,124,125,126,127,128,157,170,173,175,177,178,181,182,183,184,186,187,192,194,196,204,205,206,207,209,211,212,213,214,215,216,218,219,221,223,231,232,233,236,237,238,240,242,243,244,245,246,247,248,250,252,261,262,263,264,265,267,269,282,283,284,287,288,289,290,291,295,296,297,298,299,300,301,302,303,304,306,310,311,312,314,316,317,318,319,320,321,322,323,328,329],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/utils.py":[4,134,7,8,11,12,14,15,16,19,35,37,39,41,43,45,47,49,52,54,56,57,60,61,62,63,64,65,67,68,69,70,72,73,74,75,77,80,83,120],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/decoder/mlp.py":[1,4,7,8,10,13,44,46,47,48,49,50,51,52,53,55,57,60,61,62,64,65,71,72,73,75,76,79,86,88,93,94,95,96,98,99],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/decoder/utils.py":[1,4,6,9],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/__init__.py":[4,9,10,12,14,16,18,20,21,22,24,25,26,27,29,32,33,34,35,36,37,38,39,40],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/attention.py":[128,1,4,132,7,9,10,11,13,16,20,22,23,24,25,26,27,28,30,38,39,40,41,42,43,46,175,55,184,57,186,58,59,60,61,62,64,65,66,67,69,198,70,71,73,74,75,76,77,78,80,212,88,89,90,92,93,94,97,98,99,100,101,102,105,106,107,110,126],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/bert.py":[512,4,517,7,10,11,12,13,14,15,17,18,20,21,22,24,25,28,30,44,571,70,586,587,75,76,77,591,78,79,80,81,82,83,84,85,600,86,87,92,100,107,621,110,115,119,632,125,126,129,133,136,654,149,150,153,154,667,155,156,1
58,159,160,161,162,165,167,169,170,171,172,173,689,177,178,180,181,182,183,184,187,188,189,703,191,192,193,194,197,198,199,200,715,204,205,206,208,209,210,212,214,727,215,216,217,219,220,221,222,224,225,226,229,230,743,744,232,747,235,239,241,242,243,244,245,248,249,250,251,252,253,255,256,257,258,259,262,263,776,264,265,266,268,269,270,271,274,275,786,276,277,278,279,283,796,284,285,286,289,290,291,292,293,294,296,809,297,298,299,300,303,304,816,305,306,307,308,310,311,312,313,314,317,318,319,320,833,321,323,324,325,326,327,328,329,330,331,334,335,848,336,337,338,340,852,854,343,344,345,346,349,877,374,376,377,378,385,386,387,388,389,390,391,393,396,909,399,400,401,402,403,404,406,407,409,410,417,424,425,427,428,429,430,431,432,433,434,435,437,500,509,510],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/char_encoder.py":[1,4,5,7,8,10,14,25,27,28,29,30,32,34,36,41,43,45,47,48,49,50,52,54,55,57,58,61,68,70,77,78,80,81,82,83,84,85,87,92,93,94,95,96,98,99],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/conv_maxpool.py":[1,4,6,7,8,11,23,25,26,28,29,32,33,36,37,38,43,52,59,60,69,77,79,80,81,82,84,85,86],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/lstm.py":[4,7,10,11,12,15,30,33,34,35,36,37,38,40,41,42,44,45,46,47,49,51,61,62,65,66,67,68,69,72,73,74,75,76,77,82],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/pooling.py":[1,129,4,5,6,7,135,9,10,137,13,141,25,27,38,62,67,69,73,85,86,88,92,102,107,109,114],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/star_transformer.py":[3,6,9,10,11,12,15,32,34,35,36,38,40,41,42,43,44,45,46,48,49,53,63,65,67,68,69,71,72,76,77,78,79,80,81,82,83,85,87,89,91,94,95,96,99,100,101,102,104,107,109,111,112,114,116,117,118,119,120,121,122,123,124,125,126,127,129,130,132,134,137,138,140,141,142,143,144,146,149,151,153,154,156,158,159,160,161,162,163,164,165,166],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/transformer.py":[1,4,6,8,9,12,26,28,29,30,31,32,33,34,35,36,37,39,46,47,48,49,50,51,52,54,55
,56,58,65,66,69,70,71,72,73],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/dropout.py":[1,4,7,10,14,16,17,18,19,20,24],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/variational_rnn.py":[3,6,7,8,11,12,13,15,16,25,28,31,33,34,35,36,37,38,40,52,53,54,55,56,58,59,60,61,62,63,64,66,67,69,70,73,74,75,76,77,79,80,81,82,83,84,85,86,87,88,89,96,97,98,99,102,120,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,146,147,148,149,150,151,152,153,155,163,164,165,166,167,168,169,170,172,173,175,176,177,178,179,181,182,183,184,185,186,187,188,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,210,212,213,215,216,218,219,221,224,239,241,242,243,245,246,249,264,266,270,274,289,291,295],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/embeddings/elmo_embedding.py":[4,7,136,10,11,12,13,14,15,141,17,18,19,20,21,23,155,163,171,173,305,58,61,92,99,111,119],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/modules/encoder/_elmo.py":[514,3,515,5,7,263,9,10,11,12,264,14,528,17,409,410,309,56,65,453,327,328,85,98,493,239,240,251,510],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/embeddings/contextual_embedding.py":[99,4,7,104,10,12,76,14,15,16,17,18,19,20,23,24,27],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/embeddings/bert_embedding.py":[4,7,8,135,11,12,14,15,16,17,271,19,20,21,22,23,24,149,273,27,157,168,171,186,67,198,71,203,207,211,215,95,98,227,361,115,250],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/embeddings/char_embedding.py":[4,7,8,11,12,13,14,16,17,18,19,20,21,22,25,57,61,62,64,65,67,68,70,71,72,85,87,88,89,91,92,93,94,95,98,99,101,104,106,108,109,110,111,113,120,121,122,123,124,125,127,128,129,130,131,132,133,134,135,136,137,138,142,143,145,161,168,169,170,172,173,174,175,177,180,211,216,217,219,221,222,224,225,226,239,241,242,243,245,246,247,248,249,252,253,255,258,260,261,263,264,265,267,274,275,276,277,278,279,281,282,283,284,285,286,289,290,291,292,297,299,301,318],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/embeddings/stack_embedding.py":[4,7,10,12,13,15
,18,37,39,40,41,42,43,44,45,46,48,49,50,51,52,53,55,64,71,75,87,92,99,100,101,102,103,104],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/models/__init__.py":[32,33,34,9,11,13,14,16,18,19,20,21,23,24,27,28,30,31],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/models/base_model.py":[32,1,33,3,5,7,10,12,14,15,17,20,24,25,26,27,29,30],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/models/bert.py":[4,6,8,10,11,13,14,15,16,17,20,57,58,59,60,61,65,67,68,69,71,77,78,80,81,82,83,84,86,91,93,98,135,136,137,138,139,142,144,145,146,148,154,155,156,157,158,159,160,161,162,164,169,171,176,215,216,217,218,219,222,224,225,226,228,234,235,236,237,239,251,253,258,300,301,302,303,306,308,311,313,319,320,321,322,323,324,326,343,345],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/models/biaffine_parser.py":[3,5,517,6,520,9,10,11,12,522,14,523,16,17,18,19,20,21,22,23,24,25,530,536,28,539,542,534,544,33,34,35,36,37,38,39,40,41,545,546,547,548,46,47,48,49,50,51,52,53,45,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,73,74,75,76,524,77,78,79,80,525,81,84,82,87,526,527,92,93,94,95,528,96,97,99,101,102,103,104,105,107,108,109,110,111,112,531,114,115,116,117,118,119,120,121,122,532,124,125,126,533,128,131,136,138,139,141,142,151,152,153,154,155,156,157,158,160,161,170,171,172,173,174,175,176,177,178,179,182,188,190,191,192,193,194,195,198,200,549,207,208,209,210,211,42,214,43,222,44,224,225,226,227,229,236,237,238,241,262,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,305,306,307,308,310,311,312,313,314,315,316,317,318,322,323,324,325,326,327,328,329,330,331,333,334,335,336,337,338,339,341,342,344,362,366,368,369,371,372,373,376,377,378,379,380,381,382,383,385,386,387,391,392,393,394,397,400,402,403,405,406,416,417,418,419,420,421,422,424,437,438,439,440,441,442,443,444,445,446,447,449,450,451,452,453,454,456,469,470,471,472,473,474,477,489,493,494,495,496,497,498,499,502],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/models/cnn_text_classification.py":[4,7,10,11,13,14,15,16,19,32,3
8,39,42,43,44,45,46,47,48,50,57,58,59,60,62,63,64,65,67,74,75,76],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/models/sequence_labeling.py":[3,5,6,10,11,12,14,15,16,17,18,19,20,21,22,25,39,41,61,75,78,82,93,95,96,98,99,100,101,102,104,112,113,114,116,118,120,122,124,132,134,136,138,140,141,143,151,152,153,154,155,156,158,159,160,161,162,163,165,170,171,174,189,191,193,195,196,197,198,199,200,201,202,203,204,206,207,213,218,219,221,229,230,231,232,233,234,236,237,238,239,240,241,243,252,253,254,257,259,263,264,267,269,270,271,272,273,274,275,277,279,287,289,296],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/models/snli.py":[4,6,9,10,11,12,14,15,16,17,20,32,35,36,38,41,42,43,44,45,48,49,50,51,52,54,57,58,59,60,61,63,65,66,68,77,78,79,80,81,82,83,87,89,90,91,92,94,95,99,100,101,102,104,105,107,113,115,116,117,121,122,123,124,126,127,128,129,130,131,134,136,137,138,139,142,143,144,145,146,147,148,149,151,153,154,155,156,158,160,162,165,167,168,169,174,177,178,179,182,183,184,185,186,187,189,190,193,194,195,196,197,198,199,202,204,205,208,209,211,213,214,215,216,217,218,220],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/models/star_transformer.py":[3,5,6,7,8,11,12,14,15,16,17,20,36,38,46,47,48,49,51,52,53,54,55,56,58,67,68,69,70,73,74,75,76,77,78,79,80,83,84,85,88,89,90,91,92,93,94,95,96,99,100,101,102,105,123,133,134,135,136,137,138,139,140,141,142,143,145,152,153,154,155,156,158,165,166,167,170,188,198,199,200,201,202,203,204,205,206,207,208,210,217,218,219,220,221,223,230,231,232,235,253,263,264,265,266,267,268,269,270,271,272,273,275,284,285,287,288,289,291,292,293,294,296,305,306,307],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/dist_trainer.py":[3,4,5,6,7,9,10,11,12,13,14,15,16,18,19,20,21,22,23,24,25,26,152,29,30,157,34,169,47,304,50,179,183,312,58,320,332,343,355,229],"/hdd/fudanNLP/fastNLP/fastNLP/fastNLP/core/predictor.py":[1,4,7,9,11,12,13,14,17,25,27,28,31,32,33,35,42,44,47,48,49,50,51,53,56,58,59,60,61,62,64,67,68,69,70,80,81]}}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index 17f7654f..00000000
--- a/.gitignore
+++ /dev/null
@@ -1,18 +0,0 @@
-.gitignore
-
-.DS_Store
-.ipynb_checkpoints
-*.pyc
-__pycache__
-*.swp
-.vscode/
-.idea/**
-
-caches
-
-# fitlog
-.fitlog
-logs/
-.fitconfig
-
-docs/build
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 4fc99810..00000000
--- a/.travis.yml
+++ /dev/null
@@ -1,30 +0,0 @@
-language: python
-python:
- - "3.6"
-
-env:
- - TRAVIS=1
-
-# command to install dependencies
-install:
- - pip install --quiet -r requirements.txt
- - pip install --quiet fitlog
- - pip install pytest>=3.6
- - pip install pytest-cov
-# command to run tests
-script:
-# - python -m spacy download en
- - pytest --cov=fastNLP tests/
-
-after_success:
- - bash <(curl -s https://codecov.io/bash)
-
-notifications:
- webhooks:
- urls:
- - https://open.feishu.cn/officialapp/notify/55ba4b15d04608e875c122f11484a4e2fa807c42b9ca074509bea654d1b99ca6
- on_success: always # default: always
- on_failure: always # default: always
- on_start: never # default: never
- on_cancel: always # default: always
- on_error: always # default: always
diff --git a/MANIFEST.in b/MANIFEST.in
index 61279be1..52f2484f 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -2,6 +2,4 @@ include requirements.txt
include LICENSE
include README.md
prune tests/
-prune reproduction/
-prune fastNLP/api
-prune fastNLP/automl
\ No newline at end of file
+prune tutorials/
\ No newline at end of file
diff --git a/README.md b/README.md
index 017eae52..c0af9ebc 100644
--- a/README.md
+++ b/README.md
@@ -1,110 +1,239 @@
# fastNLP
-[![Build Status](https://travis-ci.org/fastnlp/fastNLP.svg?branch=master)](https://travis-ci.org/fastnlp/fastNLP)
-[![codecov](https://codecov.io/gh/fastnlp/fastNLP/branch/master/graph/badge.svg)](https://codecov.io/gh/fastnlp/fastNLP)
-[![Pypi](https://img.shields.io/pypi/v/fastNLP.svg)](https://pypi.org/project/fastNLP)
-![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
-[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
-fastNLP是一款面向自然语言处理(NLP)的轻量级框架,目标是快速实现NLP任务以及构建复杂模型。
+[//]: # ([![Build Status](https://travis-ci.org/fastnlp/fastNLP.svg?branch=master)](https://travis-ci.org/fastnlp/fastNLP))
-fastNLP具有如下的特性:
+[//]: # ([![codecov](https://codecov.io/gh/fastnlp/fastNLP/branch/master/graph/badge.svg)](https://codecov.io/gh/fastnlp/fastNLP))
-- 统一的Tabular式数据容器,简化数据预处理过程;
-- 内置多种数据集的Loader和Pipe,省去预处理代码;
-- 各种方便的NLP工具,例如Embedding加载(包括ELMo和BERT)、中间数据cache等;
-- 部分[数据集与预训练模型](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0)的自动下载;
-- 提供多种神经网络组件以及复现模型(涵盖中文分词、命名实体识别、句法分析、文本分类、文本匹配、指代消解、摘要等任务);
-- Trainer提供多种内置Callback函数,方便实验记录、异常捕获等。
+[//]: # ([![Pypi](https://img.shields.io/pypi/v/fastNLP.svg)](https://pypi.org/project/fastNLP))
-## 安装指南
+[//]: # (![Hex.pm](https://img.shields.io/hexpm/l/plug.svg))
-fastNLP 依赖以下包:
+[//]: # ([![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest))
-+ numpy>=1.14.2
-+ torch>=1.0.0
-+ tqdm>=4.28.1
-+ nltk>=3.4.1
-+ requests
-+ spacy
-+ prettytable>=0.7.2
-其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。
-在依赖包安装完成后,您可以在命令行执行如下指令完成安装
+fastNLP是一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等。
-```shell
-pip install fastNLP
-python -m spacy download en
-```
+fastNLP具有如下的特性:
+- 便捷。在数据处理中可以通过apply函数避免循环、使用多进程提速等;在训练循环阶段可以很方便定制操作。
+- 高效。无需改动代码,实现fp16切换、多卡、ZeRO优化等。
+- 兼容。fastNLP支持多种深度学习框架作为后端。
-## fastNLP教程
-中文[文档](http://www.fastnlp.top/docs/fastNLP/)、 [教程](http://www.fastnlp.top/docs/fastNLP/user/quickstart.html)
+> :warning: **为了实现对不同深度学习架构的兼容,fastNLP 1.0.0之后的版本重新设计了架构,因此与过去的fastNLP版本不完全兼容,
+> 基于更早的fastNLP代码需要做一定的调整**:
-### 快速入门
+## fastNLP文档
+[中文文档](http://www.fastnlp.top/docs/fastNLP/master/index.html)
-- [Quick-1. 文本分类](http://www.fastnlp.top/docs/fastNLP/tutorials/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB.html)
-- [Quick-2. 序列标注](http://www.fastnlp.top/docs/fastNLP/tutorials/%E5%BA%8F%E5%88%97%E6%A0%87%E6%B3%A8.html)
+## 安装指南
+fastNLP可以通过以下的命令进行安装
+```shell
+pip install "fastNLP>=1.0.0alpha"
+```
+如果需要安装更早版本的fastNLP请指定版本号,例如
+```shell
+pip install fastNLP==0.7.1
+```
+另外,请根据使用的深度学习框架,安装相应的深度学习框架。
+
+
+Pytorch
+下面是使用pytorch来进行文本分类的例子。需要安装torch>=1.6.0。
+
+```python
+from fastNLP.io import ChnSentiCorpLoader
+from functools import partial
+from fastNLP import cache_results
+from fastNLP.transformers.torch import BertTokenizer
+
+# 使用cache_results装饰器装饰函数,将prepare_data的返回结果缓存到caches/cache.pkl,再次运行时,如果
+# 该文件还存在,将自动读取缓存文件,而不再次运行预处理代码。
+@cache_results('caches/cache.pkl')
+def prepare_data():
+ # 会自动下载数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的
+ data_bundle = ChnSentiCorpLoader().load()
+ # 使用tokenizer对数据进行tokenize
+ tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm')
+ tokenize = partial(tokenizer, max_length=256) # 限制数据的最大长度
+ data_bundle.apply_field_more(tokenize, field_name='raw_chars', num_proc=4) # 会新增"input_ids", "attention_mask"等field进入dataset中
+ data_bundle.apply_field(int, field_name='target', new_field_name='labels') # 将int函数应用到每个target上,并且放入新的labels field中
+ return data_bundle
+data_bundle = prepare_data()
+print(data_bundle.get_dataset('train')[:4])
+
+# 初始化model, optimizer
+from fastNLP.transformers.torch import BertForSequenceClassification
+from torch import optim
+model = BertForSequenceClassification.from_pretrained('hfl/chinese-bert-wwm')
+optimizer = optim.AdamW(model.parameters(), lr=2e-5)
+
+# 准备dataloader
+from fastNLP import prepare_dataloader
+dls = prepare_dataloader(data_bundle, batch_size=32)
+
+# 准备训练
+from fastNLP import Trainer, Accuracy, LoadBestModelCallback, TorchWarmupCallback, Event
+callbacks = [
+ TorchWarmupCallback(warmup=0.1, schedule='linear'), # 训练过程中调整学习率。
+ LoadBestModelCallback() # 将在训练结束之后,加载性能最优的model
+]
+# 在训练特定时机加入一些操作, 不同时机能够获取到的参数不一样,可以通过Trainer.on函数的文档查看每个时机的参数
+@Trainer.on(Event.on_before_backward())
+def print_loss(trainer, outputs):
+ if trainer.global_forward_batches % 10 == 0: # 每10个batch打印一次loss。
+ print(outputs.loss.item())
+
+trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer,
+ device=0, evaluate_dataloaders=dls['dev'], metrics={'acc': Accuracy()},
+                  callbacks=callbacks, monitor='acc#acc', n_epochs=5,
+ # Accuracy的update()函数需要pred,target两个参数,它们实际对应的就是以下的field。
+ evaluate_input_mapping={'labels': 'target'}, # 在评测时,将dataloader中会输入到模型的labels重新命名为target
+ evaluate_output_mapping={'logits': 'pred'} # 在评测时,将model输出中的logits重新命名为pred
+ )
+trainer.run()
+
+# 在测试集合上进行评测
+from fastNLP import Evaluator
+evaluator = Evaluator(model=model, dataloaders=dls['test'], metrics={'acc': Accuracy()},
+ # Accuracy的update()函数需要pred,target两个参数,它们实际对应的就是以下的field。
+ output_mapping={'logits': 'pred'},
+ input_mapping={'labels': 'target'})
+evaluator.run()
+```
-### 详细使用教程
+更多内容可以参考如下的链接
+### 快速入门
-- [1. 使用DataSet预处理文本](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_1_data_preprocess.html)
-- [2. 使用Vocabulary转换文本与index](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html)
-- [3. 使用Embedding模块将文本转成向量](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_3_embedding.html)
-- [4. 使用Loader和Pipe加载并处理数据集](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_4_load_dataset.html)
-- [5. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_5_loss_optimizer.html)
-- [6. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html)
-- [7. 使用Metric快速评测你的模型](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_7_metrics.html)
-- [8. 使用Modules和Models快速搭建自定义模型](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_8_modules_models.html)
-- [9. 使用Callback自定义你的训练过程](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_9_callback.html)
+- [0. 10 分钟快速上手 fastNLP torch](http://www.fastnlp.top/docs/fastNLP/master/tutorials/torch/fastnlp_torch_tutorial.html)
-### 扩展教程
+### 详细使用教程
-- [Extend-1. BertEmbedding的各种用法](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_1_bert_embedding.html)
-- [Extend-2. 分布式训练简介](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_2_dist.html)
-- [Extend-3. 使用fitlog 辅助 fastNLP 进行科研](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_3_fitlog.html)
+- [1. Trainer 和 Evaluator 的基本使用](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_0.html)
+- [2. DataSet 和 Vocabulary 的基本使用](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_1.html)
+- [3. DataBundle 和 Tokenizer 的基本使用](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_2.html)
+- [4. TorchDataloader 的内部结构和基本使用](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_3.html)
+- [5. fastNLP 中的预定义模型](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_4.html)
+- [6. Trainer 和 Evaluator 的深入介绍](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_4.html)
+- [7. fastNLP 与 paddle 或 jittor 的结合](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_5.html)
+- [8. 使用 Bert + fine-tuning 完成 SST-2 分类](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_e1.html)
+- [9. 使用 Bert + prompt 完成 SST-2 分类](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_e2.html)
+
+
+
+
+
+Paddle
+下面是使用paddle来进行文本分类的例子。需要安装paddle>=2.2.0以及paddlenlp>=2.3.3。
+
+```python
+from fastNLP.io import ChnSentiCorpLoader
+from functools import partial
+
+# 会自动下载数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的
+data_bundle = ChnSentiCorpLoader().load()
+
+# 使用tokenizer对数据进行tokenize
+from paddlenlp.transformers import BertTokenizer
+tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm')
+tokenize = partial(tokenizer, max_length=256) # 限制一下最大长度
+data_bundle.apply_field_more(tokenize, field_name='raw_chars', num_proc=4) # 会新增"input_ids", "attention_mask"等field进入dataset中
+data_bundle.apply_field(int, field_name='target', new_field_name='labels') # 将int函数应用到每个target上,并且放入新的labels field中
+print(data_bundle.get_dataset('train')[:4])
+
+# 初始化 model
+from paddlenlp.transformers import BertForSequenceClassification, LinearDecayWithWarmup
+from paddle import optimizer, nn
+class SeqClsModel(nn.Layer):
+ def __init__(self, model_checkpoint, num_labels):
+ super(SeqClsModel, self).__init__()
+ self.num_labels = num_labels
+ self.bert = BertForSequenceClassification.from_pretrained(model_checkpoint)
+
+ def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
+ logits = self.bert(input_ids, token_type_ids, position_ids, attention_mask)
+ return logits
+
+ def train_step(self, input_ids, labels, token_type_ids=None, position_ids=None, attention_mask=None):
+ logits = self(input_ids, token_type_ids, position_ids, attention_mask)
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(logits.reshape((-1, self.num_labels)), labels.reshape((-1, )))
+ return {
+ "logits": logits,
+ "loss": loss,
+ }
+
+ def evaluate_step(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
+ logits = self(input_ids, token_type_ids, position_ids, attention_mask)
+ return {
+ "logits": logits,
+ }
+
+model = SeqClsModel('hfl/chinese-bert-wwm', num_labels=2)
+
+# 准备dataloader
+from fastNLP import prepare_dataloader
+dls = prepare_dataloader(data_bundle, batch_size=16)
+
+# 训练过程中调整学习率。
+scheduler = LinearDecayWithWarmup(2e-5, total_steps=20 * len(dls['train']), warmup=0.1)
+optimizer = optimizer.AdamW(parameters=model.parameters(), learning_rate=scheduler)
+
+# 准备训练
+from fastNLP import Trainer, Accuracy, LoadBestModelCallback, Event
+callbacks = [
+ LoadBestModelCallback() # 将在训练结束之后,加载性能最优的model
+]
+# 在训练特定时机加入一些操作, 不同时机能够获取到的参数不一样,可以通过Trainer.on函数的文档查看每个时机的参数
+@Trainer.on(Event.on_before_backward())
+def print_loss(trainer, outputs):
+ if trainer.global_forward_batches % 10 == 0: # 每10个batch打印一次loss。
+ print(outputs["loss"].item())
+
+trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer,
+ device=0, evaluate_dataloaders=dls['dev'], metrics={'acc': Accuracy()},
+ callbacks=callbacks, monitor='acc#acc',
+ # Accuracy的update()函数需要pred,target两个参数,它们实际对应的就是以下的field。
+ evaluate_output_mapping={'logits': 'pred'},
+ evaluate_input_mapping={'labels': 'target'}
+ )
+trainer.run()
+
+# 在测试集合上进行评测
+from fastNLP import Evaluator
+evaluator = Evaluator(model=model, dataloaders=dls['test'], metrics={'acc': Accuracy()},
+ # Accuracy的update()函数需要pred,target两个参数,它们实际对应的就是以下的field。
+ output_mapping={'logits': 'pred'},
+ input_mapping={'labels': 'target'})
+evaluator.run()
+```
+更多内容可以参考如下的链接
+### 快速入门
-## 内置组件
+- [0. 10 分钟快速上手 fastNLP paddle](http://www.fastnlp.top/docs/fastNLP/master/tutorials/torch/fastnlp_torch_tutorial.html)
-大部分用于的 NLP 任务神经网络都可以看做由词嵌入(embeddings)和两种模块:编码器(encoder)、解码器(decoder)组成。
+### 详细使用教程
-以文本分类任务为例,下图展示了一个BiLSTM+Attention实现文本分类器的模型流程图:
+- [1. 使用 paddlenlp 和 fastNLP 实现中文文本情感分析](http://www.fastnlp.top/docs/fastNLP/master/tutorials/paddle/fastnlp_tutorial_paddle_e1.html)
+- [2. 使用 paddlenlp 和 fastNLP 训练中文阅读理解任务](http://www.fastnlp.top/docs/fastNLP/master/tutorials/paddle/fastnlp_tutorial_paddle_e2.html)
+
-![](./docs/source/figures/text_classification.png)
+
+oneflow
+
-fastNLP 在 embeddings 模块中内置了几种不同的embedding:静态embedding(GloVe、word2vec)、上下文相关embedding
-(ELMo、BERT)、字符embedding(基于CNN或者LSTM的CharEmbedding)
-与此同时,fastNLP 在 modules 模块中内置了两种模块的诸多组件,可以帮助用户快速搭建自己所需的网络。 两种模块的功能和常见组件如下:
-
-
- 类型 |
- 功能 |
- 例子 |
-
-
- encoder |
- 将输入编码为具有具有表示能力的向量 |
- Embedding, RNN, CNN, Transformer, ...
- |
-
- decoder |
- 将具有某种表示意义的向量解码为需要的输出形式 |
- MLP, CRF, ... |
-
-
+
+jittor
+
## 项目结构
-
-
-
-
-fastNLP的大致工作流程如上图所示,而项目结构如下:
+fastNLP的项目结构如下:
@@ -135,4 +264,3 @@ fastNLP的大致工作流程如上图所示,而项目结构如下:
-*In memory of @FengZiYjun. May his soul rest in peace. We will miss you very very much!*
diff --git a/codecov.yml b/codecov.yml
deleted file mode 100644
index f91e0445..00000000
--- a/codecov.yml
+++ /dev/null
@@ -1,5 +0,0 @@
-ignore:
-- "reproduction" # ignore folders and all its contents
-- "setup.py"
-- "docs"
-- "tutorials"
\ No newline at end of file
diff --git a/docs/Makefile b/docs/Makefile
index 35306867..fd4035db 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -6,24 +6,35 @@ SPHINXOPTS =
SPHINXAPIDOC = sphinx-apidoc
SPHINXBUILD = sphinx-build
SPHINXPROJ = fastNLP
+SPHINXEXCLUDE = ../fastNLP/transformers/*
SOURCEDIR = source
BUILDDIR = build
+PORT = 8000
# Put it first so that "make" without argument is like "make help".
help:
- @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS)
apidoc:
- $(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ)
+ $(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) $(SPHINXEXCLUDE)
server:
- cd build/html && python -m http.server
+ cd build/html && python -m http.server $(PORT)
+
+delete:
+ rm -f source/$(SPHINXPROJ).* source/modules.rst && rm -rf build
+
+web:
+ make html && make server
dev:
- rm -f source/$(SPHINXPROJ).* source/modules.rst && rm -rf build && make apidoc && make html && make server
+ make delete && make apidoc && make html && make server
+
+versions:
+ sphinx-multiversion "$(SOURCEDIR)" "$(BUILDDIR)" && cd build && python -m http.server $(PORT)
prod:
- make apidoc && make html
+ make apidoc && make html
.PHONY: help Makefile
diff --git a/docs/README.md b/docs/README.md
deleted file mode 100644
index 2bb6953c..00000000
--- a/docs/README.md
+++ /dev/null
@@ -1,40 +0,0 @@
-# 快速入门 fastNLP 文档编写
-
-本教程为 fastNLP 文档编写者创建,文档编写者包括合作开发人员和文档维护人员。您在一般情况下属于前者,
-只需要了解整个框架的部分内容即可。
-
-## 合作开发人员
-
-FastNLP的文档使用基于[reStructuredText标记语言](http://docutils.sourceforge.net/rst.html)的
-[Sphinx](http://sphinx.pocoo.org/)工具生成,由[Read the Docs](https://readthedocs.org/)网站自动维护生成。
-一般开发者只要编写符合reStructuredText语法规范的文档并通过[PR](https://help.github.com/en/articles/about-pull-requests),
-就可以为fastNLP的文档贡献一份力量。
-
-如果你想在本地编译文档并进行大段文档的编写,您需要安装Sphinx工具以及sphinx-rtd-theme主题:
-```bash
-fastNLP/docs> pip install sphinx
-fastNLP/docs> pip install sphinx-rtd-theme
-```
-然后在本目录下执行 `make dev` 命令。该命令只支持Linux和MacOS系统,期望看到如下输出:
-```bash
-fastNLP/docs> make dev
-rm -rf build/html && make html && make server
-Running Sphinx v1.5.6
-making output directory...
-......
-Build finished. The HTML pages are in build/html.
-cd build/html && python -m http.server
-Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
-```
-现在您浏览器访问 http://localhost:8000/ 查看文档。如果你在远程服务器尚进行工作,则访问地址为 http://{服务器的ip地址}:8000/ 。
-但您必须保证服务器的8000端口是开放的。如果您的电脑或远程服务器的8000端口被占用,程序会顺延使用8001、8002……等端口。
-当你结束访问时,您可以使用Control(Ctrl) + C 来结束进程。
-
-我们在[这里](./source/user/example.rst)列举了fastNLP文档经常用到的reStructuredText语法(网页查看请结合Raw模式),
-您可以通过阅读它进行快速上手。FastNLP大部分的文档都是写在代码中通过Sphinx工具进行抽取生成的,
-
-## 文档维护人员
-
-文档维护人员需要了解 Makefile 中全部命令的含义,并了解到目前的文档结构
-是在 sphinx-apidoc 自动抽取的基础上进行手动修改得到的。
-文档维护人员应进一步提升整个框架的自动化程度,并监督合作开发人员不要破坏文档项目的整体结构。
\ No newline at end of file
diff --git a/docs/check_tools.py b/docs/check_tools.py
deleted file mode 100644
index 59d942fc..00000000
--- a/docs/check_tools.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import inspect
-import os
-import sys
-
-
-def _colored_string(string: str, color: str or int) -> str:
- """在终端中显示一串有颜色的文字
- :param string: 在终端中显示的文字
- :param color: 文字的颜色
- :return:
- """
- if isinstance(color, str):
- color = {
- "black": 30, "Black": 30, "BLACK": 30,
- "red": 31, "Red": 31, "RED": 31,
- "green": 32, "Green": 32, "GREEN": 32,
- "yellow": 33, "Yellow": 33, "YELLOW": 33,
- "blue": 34, "Blue": 34, "BLUE": 34,
- "purple": 35, "Purple": 35, "PURPLE": 35,
- "cyan": 36, "Cyan": 36, "CYAN": 36,
- "white": 37, "White": 37, "WHITE": 37
- }[color]
- return "\033[%dm%s\033[0m" % (color, string)
-
-
-def gr(string, flag):
- if flag:
- return _colored_string(string, "green")
- else:
- return _colored_string(string, "red")
-
-
-def find_all_modules():
- modules = {}
- children = {}
- to_doc = set()
- root = '../fastNLP'
- for path, dirs, files in os.walk(root):
- for file in files:
- if file.endswith('.py'):
- name = ".".join(path.split('/')[1:])
- if file.split('.')[0] != "__init__":
- name = name + '.' + file.split('.')[0]
- __import__(name)
- m = sys.modules[name]
- modules[name] = m
- try:
- m.__all__
- except:
- print(name, "__all__ missing")
- continue
- if m.__doc__ is None:
- print(name, "__doc__ missing")
- continue
- if "undocumented" not in m.__doc__:
- to_doc.add(name)
- for module in to_doc:
- t = ".".join(module.split('.')[:-1])
- if t in to_doc:
- if t not in children:
- children[t] = set()
- children[t].add(module)
- for m in children:
- children[m] = sorted(children[m])
- return modules, to_doc, children
-
-
-def create_rst_file(modules, name, children):
- m = modules[name]
- with open("./source/" + name + ".rst", "w") as fout:
- t = "=" * len(name)
- fout.write(name + "\n")
- fout.write(t + "\n")
- fout.write("\n")
- fout.write(".. automodule:: " + name + "\n")
- if name != "fastNLP.core" and len(m.__all__) > 0:
- fout.write(" :members: " + ", ".join(m.__all__) + "\n")
- short = name[len("fastNLP."):]
- if not (short.startswith('models') or short.startswith('modules') or short.startswith('embeddings')):
- fout.write(" :inherited-members:\n")
- fout.write("\n")
- if name in children:
- fout.write("子模块\n------\n\n.. toctree::\n :maxdepth: 1\n\n")
- for module in children[name]:
- fout.write(" " + module + "\n")
-
-
-def check_file(m, name):
- names = name.split('.')
- test_name = "test." + ".".join(names[1:-1]) + ".test_" + names[-1]
- try:
- __import__(test_name)
- tm = sys.modules[test_name]
- except ModuleNotFoundError:
- tm = None
- tested = tm is not None
- funcs = {}
- classes = {}
- for item, obj in inspect.getmembers(m):
- if inspect.isclass(obj) and obj.__module__ == name and not obj.__name__.startswith('_'):
- this = (obj.__doc__ is not None, tested and obj.__name__ in dir(tm), {})
- for i in dir(obj):
- func = getattr(obj, i)
- if inspect.isfunction(func) and not i.startswith('_'):
- this[2][i] = (func.__doc__ is not None, False)
- classes[obj.__name__] = this
- if inspect.isfunction(obj) and obj.__module__ == name and not obj.__name__.startswith('_'):
- this = (obj.__doc__ is not None, tested and obj.__name__ in dir(tm)) # docs
- funcs[obj.__name__] = this
- return funcs, classes
-
-
-def check_files(modules, out=None):
- for name in sorted(modules.keys()):
- print(name, file=out)
- funcs, classes = check_file(modules[name], name)
- if out is None:
- for f in funcs:
- print("%-30s \t %s \t %s" % (f, gr("文档", funcs[f][0]), gr("测试", funcs[f][1])))
- for c in classes:
- print("%-30s \t %s \t %s" % (c, gr("文档", classes[c][0]), gr("测试", classes[c][1])))
- methods = classes[c][2]
- for f in methods:
- print(" %-28s \t %s" % (f, gr("文档", methods[f][0])))
- else:
- for f in funcs:
- if not funcs[f][0]:
- print("缺少文档 %s" % (f), file=out)
- if not funcs[f][1]:
- print("缺少测试 %s" % (f), file=out)
- for c in classes:
- if not classes[c][0]:
- print("缺少文档 %s" % (c), file=out)
- if not classes[c][1]:
- print("缺少测试 %s" % (c), file=out)
- methods = classes[c][2]
- for f in methods:
- if not methods[f][0]:
- print("缺少文档 %s" % (c + "." + f), file=out)
- print(file=out)
-
-
-def main_check():
- sys.path.append("..")
- print(_colored_string('Getting modules...', "Blue"))
- modules, to_doc, children = find_all_modules()
- print(_colored_string('Done!', "Green"))
- print(_colored_string('Creating rst files...', "Blue"))
- for name in to_doc:
- create_rst_file(modules, name, children)
- print(_colored_string('Done!', "Green"))
- print(_colored_string('Checking all files...', "Blue"))
- check_files(modules, out=open("results.txt", "w"))
- print(_colored_string('Done!', "Green"))
-
-
-def check_file_r(file_path):
- with open(file_path) as fin:
- content = fin.read()
- index = -3
- cuts = []
- while index != -1:
- index = content.find('"""',index+3)
- cuts.append(index)
- cuts = cuts[:-1]
- assert len(cuts)%2 == 0
- write_content = ""
- last = 0
- for i in range(len(cuts)//2):
- start, end = cuts[i+i], cuts[i+i+1]
- if content[start-1] == "r":
- write_content += content[last:end+3]
- else:
- write_content += content[last:start] + "r"
- write_content += content[start:end+3]
- last = end + 3
- write_content += content[last:]
- with open(file_path, "w") as fout:
- fout.write(write_content)
-
-
-def add_r(base_path='../fastNLP'):
- for path, _, files in os.walk(base_path):
- for f in files:
- if f.endswith(".py"):
- check_file_r(os.path.abspath(os.path.join(path,f)))
- # sys.exit(0)
-
-
-if __name__ == "__main__":
- add_r()
diff --git a/docs/requirements.txt b/docs/requirements.txt
index cfa9c93a..91e78913 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,4 @@
-sphinx==3.2.1
-docutils==0.16
-sphinx-rtd-theme==0.5.0
-readthedocs-sphinx-search==0.1.0rc3
\ No newline at end of file
+sphinx
+sphinx_rtd_theme
+sphinx_autodoc_typehints
+sphinx-multiversion
\ No newline at end of file
diff --git a/docs/source/_static/notebooks/extend_1_bert_embedding.ipynb b/docs/source/_static/notebooks/extend_1_bert_embedding.ipynb
deleted file mode 100644
index 2169c8b5..00000000
--- a/docs/source/_static/notebooks/extend_1_bert_embedding.ipynb
+++ /dev/null
@@ -1,260 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# BertEmbedding的各种用法\n",
- "Bert自从在 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 中被提出后,因其性能卓越受到了极大的关注,在这里我们展示一下在fastNLP中如何使用Bert进行各类任务。其中中文Bert我们使用的模型的权重来自于 中文Bert预训练 。\n",
- "\n",
- "为了方便大家的使用,fastNLP提供了预训练的Embedding权重及数据集的自动下载,支持自动下载的Embedding和数据集见 数据集 。或您可从 使用Embedding模块将文本转成向量 与 使用Loader和Pipe加载并处理数据集 了解更多相关信息\n",
- "\n",
- "\n",
- "下面我们将介绍通过使用Bert来进行文本分类, 中文命名实体识别, 文本匹配, 中文问答。\n",
- "\n",
- "## 1. 使用Bert进行文本分类\n",
- "\n",
- "文本分类是指给定一段文字,判定其所属的类别。例如下面的文本情感分类\n",
- "\n",
- " *1, 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!*\n",
- "\n",
- "这里我们使用fastNLP提供自动下载的微博分类进行测试"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.io import WeiboSenti100kPipe\n",
- "from fastNLP.embeddings import BertEmbedding\n",
- "from fastNLP.models import BertForSequenceClassification\n",
- "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
- "import torch\n",
- "\n",
- "data_bundle =WeiboSenti100kPipe().process_from_file()\n",
- "data_bundle.rename_field('chars', 'words')\n",
- "\n",
- "# 载入BertEmbedding\n",
- "embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)\n",
- "\n",
- "# 载入模型\n",
- "model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))\n",
- "\n",
- "# 训练模型\n",
- "device = 0 if torch.cuda.is_available() else 'cpu' \n",
- "trainer = Trainer(data_bundle.get_dataset('train'), model,\n",
- " optimizer=Adam(model_params=model.parameters(), lr=2e-5),\n",
- " loss=CrossEntropyLoss(), device=device,\n",
- " batch_size=8, dev_data=data_bundle.get_dataset('dev'),\n",
- " metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n",
- "trainer.train()\n",
- "\n",
- "# 测试结果\n",
- "from fastNLP import Tester\n",
- "\n",
- "tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())\n",
- "tester.test()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 2. 使用Bert进行命名实体识别\n",
- "\n",
- "命名实体识别是给定一句话,标记出其中的实体。一般序列标注的任务都使用conll格式,conll格式是至一行中通过制表符分隔不同的内容,使用空行分隔 两句话,例如下面的例子\n",
- "\n",
- "```\n",
- " 中 B-ORG\n",
- " 共 I-ORG\n",
- " 中 I-ORG\n",
- " 央 I-ORG\n",
- " 致 O\n",
- " 中 B-ORG\n",
- " 国 I-ORG\n",
- " 致 I-ORG\n",
- " 公 I-ORG\n",
- " 党 I-ORG\n",
- " 十 I-ORG\n",
- " 一 I-ORG\n",
- " 大 I-ORG\n",
- " 的 O\n",
- " 贺 O\n",
- " 词 O\n",
- "```\n",
- "\n",
- "这部分内容请参考 快速实现序列标注模型\n",
- "\n",
- "## 3. 使用Bert进行文本匹配\n",
- "\n",
- "文本匹配任务是指给定两句话判断他们的关系。比如,给定两句话判断前一句是否和后一句具有因果关系或是否是矛盾关系;或者给定两句话判断两句话是否 具有相同的意思。这里我们使用"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.io import CNXNLIBertPipe\n",
- "from fastNLP.embeddings import BertEmbedding\n",
- "from fastNLP.models import BertForSentenceMatching\n",
- "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
- "from fastNLP.core.optimizer import AdamW\n",
- "from fastNLP.core.callback import WarmupCallback\n",
- "from fastNLP import Tester\n",
- "import torch\n",
- "\n",
- "data_bundle = CNXNLIBertPipe().process_from_file()\n",
- "data_bundle.rename_field('chars', 'words')\n",
- "print(data_bundle)\n",
- "\n",
- "# 载入BertEmbedding\n",
- "embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)\n",
- "\n",
- "# 载入模型\n",
- "model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))\n",
- "\n",
- "# 训练模型\n",
- "callbacks = [WarmupCallback(warmup=0.1, schedule='linear'), ]\n",
- "device = 0 if torch.cuda.is_available() else 'cpu' \n",
- "trainer = Trainer(data_bundle.get_dataset('train'), model,\n",
- " optimizer=AdamW(params=model.parameters(), lr=4e-5),\n",
- " loss=CrossEntropyLoss(), device=device,\n",
- " batch_size=8, dev_data=data_bundle.get_dataset('dev'),\n",
- " metrics=AccuracyMetric(), n_epochs=5, print_every=1,\n",
- " update_every=8, callbacks=callbacks)\n",
- "trainer.train()\n",
- "\n",
- "tester = Tester(data_bundle.get_dataset('test'), model, batch_size=8, metrics=AccuracyMetric())\n",
- "tester.test()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 4. 使用Bert进行中文问答\n",
- "\n",
- "问答任务是给定一段内容,以及一个问题,需要从这段内容中找到答案。 例如:\n",
- "\n",
- "```\n",
- "\"context\": \"锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法,以中文字的声音模拟敲击乐的声音,纪录打击乐的各种不同的演奏方法。常\n",
- "用的节奏型称为「锣鼓点」。而锣鼓是戏曲节奏的支柱,除了加强演员身段动作的节奏感,也作为音乐的引子和尾声,提示音乐的板式和速度,以及\n",
- "作为唱腔和念白的伴奏,令诗句的韵律更加抑扬顿锉,段落分明。锣鼓的运用有约定俗成的程式,依照角色行当的身份、性格、情绪以及环境,配合\n",
- "相应的锣鼓点。锣鼓亦可以模仿大自然的音响效果,如雷电、波浪等等。戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型:鼓类包括有单\n",
- "皮鼓(板鼓)、大鼓、大堂鼓(唐鼓)、小堂鼓、怀鼓、花盆鼓等;锣类有大锣、小锣(手锣)、钲锣、筛锣、马锣、镗锣、云锣;钹类有铙钹、大\n",
- "钹、小钹、水钹、齐钹、镲钹、铰子、碰钟等;打拍子用的檀板、木鱼、梆子等。因为京剧的锣鼓通常由四位乐师负责,又称为四大件,领奏的师\n",
- "傅称为:「鼓佬」,其职责有如西方乐队的指挥,负责控制速度以及利用各种手势提示乐师演奏不同的锣鼓点。粤剧吸收了部份京剧的锣鼓,但以木鱼\n",
- "和沙的代替了京剧的板和鼓,作为打拍子的主要乐器。以下是京剧、昆剧和粤剧锣鼓中乐器对应的口诀用字:\",\n",
- "\"question\": \"锣鼓经是什么?\",\n",
- "\"answers\": [\n",
- " {\n",
- " \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
- " \"answer_start\": 4\n",
- " },\n",
- " {\n",
- " \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
- " \"answer_start\": 4\n",
- " },\n",
- " {\n",
- " \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
- " \"answer_start\": 4\n",
- " }\n",
- "]\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "您可以通过以下的代码训练 (原文代码:[CMRC2018](https://github.com/ymcui/cmrc2018) )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.embeddings import BertEmbedding\n",
- "from fastNLP.models import BertForQuestionAnswering\n",
- "from fastNLP.core.losses import CMRC2018Loss\n",
- "from fastNLP.core.metrics import CMRC2018Metric\n",
- "from fastNLP.io.pipe.qa import CMRC2018BertPipe\n",
- "from fastNLP import Trainer, BucketSampler\n",
- "from fastNLP import WarmupCallback, GradientClipCallback\n",
- "from fastNLP.core.optimizer import AdamW\n",
- "import torch\n",
- "\n",
- "data_bundle = CMRC2018BertPipe().process_from_file()\n",
- "data_bundle.rename_field('chars', 'words')\n",
- "\n",
- "print(data_bundle)\n",
- "\n",
- "embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn', requires_grad=True, include_cls_sep=False, auto_truncate=True,\n",
- " dropout=0.5, word_dropout=0.01)\n",
- "model = BertForQuestionAnswering(embed)\n",
- "loss = CMRC2018Loss()\n",
- "metric = CMRC2018Metric()\n",
- "\n",
- "wm_callback = WarmupCallback(schedule='linear')\n",
- "gc_callback = GradientClipCallback(clip_value=1, clip_type='norm')\n",
- "callbacks = [wm_callback, gc_callback]\n",
- "\n",
- "optimizer = AdamW(model.parameters(), lr=5e-5)\n",
- "\n",
- "device = 0 if torch.cuda.is_available() else 'cpu' \n",
- "trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,\n",
- " sampler=BucketSampler(seq_len_field_name='context_len'),\n",
- " dev_data=data_bundle.get_dataset('dev'), metrics=metric,\n",
- " callbacks=callbacks, device=device, batch_size=6, num_workers=2, n_epochs=2, print_every=1,\n",
- " test_use_tqdm=False, update_every=10)\n",
- "trainer.train(load_best_model=False)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "训练结果(和论文中报道的基本一致):\n",
- "\n",
- "```\n",
- " In Epoch:2/Step:1692, got best dev performance:\n",
- " CMRC2018Metric: f1=85.61, em=66.08\n",
- "```"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python Now",
- "language": "python",
- "name": "now"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_static/notebooks/tutorial_1_data_preprocess.ipynb b/docs/source/_static/notebooks/tutorial_1_data_preprocess.ipynb
deleted file mode 100644
index a987e7f2..00000000
--- a/docs/source/_static/notebooks/tutorial_1_data_preprocess.ipynb
+++ /dev/null
@@ -1,292 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# fastNLP中的DataSet"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------+---------------------------------------------+---------+\n",
- "| raw_words | words | seq_len |\n",
- "+------------------------------+---------------------------------------------+---------+\n",
- "| This is the first instance . | ['this', 'is', 'the', 'first', 'instance... | 6 |\n",
- "| Second instance . | ['Second', 'instance', '.'] | 3 |\n",
- "| Third instance . | ['Third', 'instance', '.'] | 3 |\n",
- "+------------------------------+---------------------------------------------+---------+\n"
- ]
- }
- ],
- "source": [
- "from fastNLP import DataSet\n",
- "data = {'raw_words':[\"This is the first instance .\", \"Second instance .\", \"Third instance .\"],\n",
- " 'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.']],\n",
- " 'seq_len': [6, 3, 3]}\n",
- "dataset = DataSet(data)\n",
- "# 传入的dict的每个key的value应该为具有相同长度的list\n",
- "print(dataset)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## DataSet的构建"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "+----------------------------+---------------------------------------------+---------+\n",
- "| raw_words | words | seq_len |\n",
- "+----------------------------+---------------------------------------------+---------+\n",
- "| This is the first instance | ['this', 'is', 'the', 'first', 'instance... | 6 |\n",
- "+----------------------------+---------------------------------------------+---------+"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import DataSet\n",
- "from fastNLP import Instance\n",
- "dataset = DataSet()\n",
- "instance = Instance(raw_words=\"This is the first instance\",\n",
- " words=['this', 'is', 'the', 'first', 'instance', '.'],\n",
- " seq_len=6)\n",
- "dataset.append(instance)\n",
- "dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "+----------------------------+---------------------------------------------+---------+\n",
- "| raw_words | words | seq_len |\n",
- "+----------------------------+---------------------------------------------+---------+\n",
- "| This is the first instance | ['this', 'is', 'the', 'first', 'instance... | 6 |\n",
- "| Second instance . | ['Second', 'instance', '.'] | 3 |\n",
- "+----------------------------+---------------------------------------------+---------+"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import DataSet\n",
- "from fastNLP import Instance\n",
- "dataset = DataSet([\n",
- " Instance(raw_words=\"This is the first instance\",\n",
- " words=['this', 'is', 'the', 'first', 'instance', '.'],\n",
- " seq_len=6),\n",
- " Instance(raw_words=\"Second instance .\",\n",
- " words=['Second', 'instance', '.'],\n",
- " seq_len=3)\n",
- " ])\n",
- "dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## DataSet的删除"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "+----+---+\n",
- "| a | c |\n",
- "+----+---+\n",
- "| -5 | 0 |\n",
- "| -4 | 0 |\n",
- "| -3 | 0 |\n",
- "| -2 | 0 |\n",
- "| -1 | 0 |\n",
- "| 0 | 0 |\n",
- "| 1 | 0 |\n",
- "| 2 | 0 |\n",
- "| 3 | 0 |\n",
- "| 4 | 0 |\n",
- "+----+---+"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import DataSet\n",
- "dataset = DataSet({'a': range(-5, 5), 'c': [0]*10})\n",
- "dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "+---+\n",
- "| c |\n",
- "+---+\n",
- "| 0 |\n",
- "| 0 |\n",
- "| 0 |\n",
- "| 0 |\n",
- "+---+"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 不改变dataset,生成一个删除了满足条件的instance的新 DataSet\n",
- "dropped_dataset = dataset.drop(lambda ins:ins['a']<0, inplace=False)\n",
- "# 在dataset中删除满足条件的instance\n",
- "dataset.drop(lambda ins:ins['a']<0)\n",
- "# 删除第3个instance\n",
- "dataset.delete_instance(2)\n",
- "# 删除名为'a'的field\n",
- "dataset.delete_field('a')\n",
- "dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 简单的数据预处理"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "False\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "4"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 检查是否存在名为'a'的field\n",
- "print(dataset.has_field('a')) # 或 ('a' in dataset)\n",
- "# 将名为'a'的field改名为'b'\n",
- "dataset.rename_field('c', 'b')\n",
- "# DataSet的长度\n",
- "len(dataset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "+------------------------------+-------------------------------------------------+\n",
- "| raw_words | words |\n",
- "+------------------------------+-------------------------------------------------+\n",
- "| This is the first instance . | ['This', 'is', 'the', 'first', 'instance', '.'] |\n",
- "| Second instance . | ['Second', 'instance', '.'] |\n",
- "| Third instance . | ['Third', 'instance', '.'] |\n",
- "+------------------------------+-------------------------------------------------+"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import DataSet\n",
- "data = {'raw_words':[\"This is the first instance .\", \"Second instance .\", \"Third instance .\"]}\n",
- "dataset = DataSet(data)\n",
- "\n",
- "# 将句子分成单词形式, 详见DataSet.apply()方法\n",
- "dataset.apply(lambda ins: ins['raw_words'].split(), new_field_name='words')\n",
- "\n",
- "# 或使用DataSet.apply_field()\n",
- "dataset.apply_field(lambda sent:sent.split(), field_name='raw_words', new_field_name='words')\n",
- "\n",
- "# 除了匿名函数,也可以定义函数传递进去\n",
- "def get_words(instance):\n",
- " sentence = instance['raw_words']\n",
- " words = sentence.split()\n",
- " return words\n",
- "dataset.apply(get_words, new_field_name='words')\n",
- "dataset"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python Now",
- "language": "python",
- "name": "now"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_static/notebooks/tutorial_2_vocabulary.ipynb b/docs/source/_static/notebooks/tutorial_2_vocabulary.ipynb
deleted file mode 100644
index 50862293..00000000
--- a/docs/source/_static/notebooks/tutorial_2_vocabulary.ipynb
+++ /dev/null
@@ -1,343 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# fastNLP中的 Vocabulary\n",
- "## 构建 Vocabulary"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.add_word_lst(['复', '旦', '大', '学']) # 加入新的字\n",
- "vocab.add_word('上海') # `上海`会作为一个整体\n",
- "vocab.to_index('复') # 应该会为3\n",
- "vocab.to_index('我') # 会输出1,Vocabulary中默认pad的index为0, unk(没有找到的词)的index为1\n",
- "\n",
- "# 在构建target的Vocabulary时,词表中应该用不上pad和unk,可以通过以下的初始化\n",
- "vocab = Vocabulary(unknown=None, padding=None)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Vocabulary(['positive', 'negative']...)"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "vocab.add_word_lst(['positive', 'negative'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "0"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "vocab.to_index('positive')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 没有设置 unk 的情况"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "ename": "ValueError",
- "evalue": "word `neutral` not in vocabulary",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'neutral'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 会报错,因为没有unk这种情况\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36mto_index\u001b[0;34m(self, w)\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mreturn\u001b[0m \u001b[0mint\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mnumber\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 415\u001b[0m \"\"\"\n\u001b[0;32m--> 416\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 417\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36m_wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_word2idx\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrebuild\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_vocab\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_wrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, w)\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_word2idx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munknown\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"word `{}` not in vocabulary\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 276\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_check_build_vocab\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mValueError\u001b[0m: word `neutral` not in vocabulary"
- ]
- }
- ],
- "source": [
- "vocab.to_index('neutral') # 会报错,因为没有unk这种情况"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 设置 unk 的情况"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(0, '')"
- ]
- },
- "execution_count": 25,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary(unknown='', padding=None)\n",
- "vocab.add_word_lst(['positive', 'negative'])\n",
- "vocab.to_index('neutral'), vocab.to_word(vocab.to_index('neutral'))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Vocabulary(['positive', 'negative']...)"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "vocab"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+---------------------------------------------------+--------+\n",
- "| chars | target |\n",
- "+---------------------------------------------------+--------+\n",
- "| [4, 2, 2, 5, 6, 7, 3] | 0 |\n",
- "| [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 3] | 1 |\n",
- "+---------------------------------------------------+--------+\n"
- ]
- }
- ],
- "source": [
- "from fastNLP import Vocabulary\n",
- "from fastNLP import DataSet\n",
- "\n",
- "dataset = DataSet({'chars': [\n",
- " ['今', '天', '天', '气', '很', '好', '。'],\n",
- " ['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。']\n",
- " ],\n",
- " 'target': ['neutral', 'negative']\n",
- "})\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.from_dataset(dataset, field_name='chars')\n",
- "vocab.index_dataset(dataset, field_name='chars')\n",
- "\n",
- "target_vocab = Vocabulary(padding=None, unknown=None)\n",
- "target_vocab.from_dataset(dataset, field_name='target')\n",
- "target_vocab.index_dataset(dataset, field_name='target')\n",
- "print(dataset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Vocabulary(['今', '天', '心', '情', '很']...)"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import Vocabulary\n",
- "from fastNLP import DataSet\n",
- "\n",
- "tr_data = DataSet({'chars': [\n",
- " ['今', '天', '心', '情', '很', '好', '。'],\n",
- " ['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。']\n",
- " ],\n",
- " 'target': ['positive', 'negative']\n",
- "})\n",
- "dev_data = DataSet({'chars': [\n",
- " ['住', '宿', '条', '件', '还', '不', '错'],\n",
- " ['糟', '糕', '的', '天', '气', ',', '无', '法', '出', '行', '。']\n",
- " ],\n",
- " 'target': ['positive', 'negative']\n",
- "})\n",
- "\n",
- "vocab = Vocabulary()\n",
- "# 将验证集或者测试集在建立词表是放入no_create_entry_dataset这个参数中。\n",
- "vocab.from_dataset(tr_data, field_name='chars', no_create_entry_dataset=[dev_data])\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " 4%|▎ | 2.31M/63.5M [00:00<00:02, 22.9MB/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "http://212.129.155.247/embedding/glove.6B.50d.zip not found in cache, downloading to /tmp/tmpvziobj_e\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 63.5M/63.5M [00:01<00:00, 41.3MB/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Finish download from http://212.129.155.247/embedding/glove.6B.50d.zip\n",
- "Copy file to /remote-home/ynzheng/.fastNLP/embedding/glove.6B.50d\n",
- "Found 2 out of 6 words in the pre-training embedding.\n",
- "tensor([[ 0.9497, 0.3433, 0.8450, -0.8852, -0.7208, -0.2931, -0.7468, 0.6512,\n",
- " 0.4730, -0.7401, 0.1877, -0.3828, -0.5590, 0.4295, -0.2698, -0.4238,\n",
- " -0.3124, 1.3423, -0.7857, -0.6302, 0.9182, 0.2113, -0.5744, 1.4549,\n",
- " 0.7546, -1.6165, -0.0085, 0.0029, 0.5130, -0.4745, 2.5306, 0.8594,\n",
- " -0.3067, 0.0578, 0.6623, 0.2080, 0.6424, -0.5246, -0.0534, 1.1404,\n",
- " -0.1370, -0.1836, 0.4546, -0.5096, -0.0255, -0.0286, 0.1805, -0.4483,\n",
- " 0.4053, -0.3682]], grad_fn=)\n",
- "tensor([[ 0.1320, -0.2392, 0.1732, -0.2390, -0.0463, 0.0494, 0.0488, -0.0886,\n",
- " 0.0224, -0.1300, 0.0369, 0.1800, 0.0750, -0.0183, 0.2264, 0.1628,\n",
- " 0.1261, -0.1259, 0.1663, -0.1230, -0.1904, -0.0532, 0.1397, -0.0259,\n",
- " -0.1799, 0.0226, 0.1858, 0.1981, 0.1338, 0.2394, 0.0248, 0.0203,\n",
- " -0.1722, -0.1683, -0.1892, 0.0874, 0.0562, -0.0394, 0.0306, -0.1761,\n",
- " 0.1015, -0.0171, 0.1172, 0.1357, 0.1519, -0.0011, 0.1572, 0.1265,\n",
- " -0.2391, -0.0258]], grad_fn=)\n",
- "tensor([[ 0.1318, -0.2552, -0.0679, 0.2619, -0.2616, 0.2357, 0.1308, -0.0118,\n",
- " 1.7659, 0.2078, 0.2620, -0.1643, -0.8464, 0.0201, 0.0702, 0.3978,\n",
- " 0.1528, -0.2021, -1.6184, -0.5433, -0.1786, 0.5389, 0.4987, -0.1017,\n",
- " 0.6626, -1.7051, 0.0572, -0.3241, -0.6683, 0.2665, 2.8420, 0.2684,\n",
- " -0.5954, -0.5004, 1.5199, 0.0396, 1.6659, 0.9976, -0.5597, -0.7049,\n",
- " -0.0309, -0.2830, -0.1356, 0.6429, 0.4149, 1.2362, 0.7659, 0.9780,\n",
- " 0.5851, -0.3018]], grad_fn=)\n",
- "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0.]], grad_fn=)\n",
- "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
- " 0., 0.]], grad_fn=)\n"
- ]
- }
- ],
- "source": [
- "import torch\n",
- "from fastNLP.embeddings import StaticEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.add_word('train')\n",
- "vocab.add_word('only_in_train') # 仅在train出现,但肯定在预训练词表中不存在\n",
- "vocab.add_word('test', no_create_entry=True) # 该词只在dev或test中出现\n",
- "vocab.add_word('only_in_test', no_create_entry=True) # 这个词在预训练的词表中找不到\n",
- "\n",
- "embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n",
- "print(embed(torch.LongTensor([vocab.to_index('train')])))\n",
- "print(embed(torch.LongTensor([vocab.to_index('only_in_train')])))\n",
- "print(embed(torch.LongTensor([vocab.to_index('test')])))\n",
- "print(embed(torch.LongTensor([vocab.to_index('only_in_test')])))\n",
- "print(embed(torch.LongTensor([vocab.unknown_idx])))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python Now",
- "language": "python",
- "name": "now"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_static/notebooks/tutorial_3_embedding.ipynb b/docs/source/_static/notebooks/tutorial_3_embedding.ipynb
deleted file mode 100644
index 154a0756..00000000
--- a/docs/source/_static/notebooks/tutorial_3_embedding.ipynb
+++ /dev/null
@@ -1,524 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Found 5 out of 7 words in the pre-training embedding.\n",
- "torch.Size([1, 5, 50])\n"
- ]
- }
- ],
- "source": [
- "import torch\n",
- "from fastNLP.embeddings import StaticEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.add_word_lst(\"this is a demo .\".split())\n",
- "\n",
- "embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n",
- "\n",
- "words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]]) # 将文本转为index\n",
- "print(embed(words).size()) # StaticEmbedding的使用和pytorch的nn.Embedding是类似的"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "torch.Size([1, 5, 30])\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import StaticEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.add_word_lst(\"this is a demo .\".split())\n",
- "\n",
- "embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=30)\n",
- "\n",
- "words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n",
- "print(embed(words).size())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "22 out of 22 characters were found in pretrained elmo embedding.\n",
- "torch.Size([1, 5, 256])\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import ElmoEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.add_word_lst(\"this is a demo .\".split())\n",
- "\n",
- "embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False)\n",
- "words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n",
- "print(embed(words).size())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "22 out of 22 characters were found in pretrained elmo embedding.\n",
- "torch.Size([1, 5, 512])\n"
- ]
- }
- ],
- "source": [
- "embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False, layers='1,2')\n",
- "print(embed(words).size())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "22 out of 22 characters were found in pretrained elmo embedding.\n",
- "torch.Size([1, 5, 256])\n"
- ]
- }
- ],
- "source": [
- "embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=True, layers='mix')\n",
- "print(embed(words).size()) # 三层输出按照权重element-wise的加起来"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
- "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
- "Start to generate word pieces for word.\n",
- "Found(Or segment into word pieces) 7 words out of 7.\n",
- "torch.Size([1, 5, 768])\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import BertEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.add_word_lst(\"this is a demo .\".split())\n",
- "\n",
- "embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased')\n",
- "words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n",
- "print(embed(words).size())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
- "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
- "Start to generate word pieces for word.\n",
- "Found(Or segment into word pieces) 7 words out of 7.\n",
- "torch.Size([1, 5, 1536])\n"
- ]
- }
- ],
- "source": [
- "# 使用后面两层的输出\n",
- "embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='10,11')\n",
- "print(embed(words).size()) # 结果将是在最后一维做拼接"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
- "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
- "Start to generate word pieces for word.\n",
- "Found(Or segment into word pieces) 7 words out of 7.\n",
- "torch.Size([1, 7, 768])\n"
- ]
- }
- ],
- "source": [
- "embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', include_cls_sep=True)\n",
- "print(embed(words).size()) # 结果将在序列维度上增加2\n",
- "# 取出句子的cls表示\n",
- "cls_reps = embed(words)[:, 0] # shape: [batch_size, 768]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
- "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
- "Start to generate word pieces for word.\n",
- "Found(Or segment into word pieces) 7 words out of 7.\n",
- "torch.Size([1, 5, 768])\n"
- ]
- }
- ],
- "source": [
- "embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max')\n",
- "print(embed(words).size())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
- "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
- "Start to generate word pieces for word.\n",
- "Found(Or segment into word pieces) 10 words out of 10.\n",
- "torch.Size([1, 9, 768])\n"
- ]
- }
- ],
- "source": [
- "vocab = Vocabulary()\n",
- "vocab.add_word_lst(\"this is a demo . [SEP] another sentence .\".split())\n",
- "\n",
- "embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max')\n",
- "words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo . [SEP] another sentence .\".split()]])\n",
- "print(embed(words).size())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Start constructing character vocabulary.\n",
- "In total, there are 8 distinct characters.\n",
- "torch.Size([1, 5, 64])\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import CNNCharEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.add_word_lst(\"this is a demo .\".split())\n",
- "\n",
- "# character的embedding维度大小为50,返回的embedding结果维度大小为64。\n",
- "embed = CNNCharEmbedding(vocab, embed_size=64, char_emb_size=50)\n",
- "words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n",
- "print(embed(words).size())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Start constructing character vocabulary.\n",
- "In total, there are 8 distinct characters.\n",
- "torch.Size([1, 5, 64])\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import LSTMCharEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.add_word_lst(\"this is a demo .\".split())\n",
- "\n",
- "# character的embedding维度大小为50,返回的embedding结果维度大小为64。\n",
- "embed = LSTMCharEmbedding(vocab, embed_size=64, char_emb_size=50)\n",
- "words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n",
- "print(embed(words).size())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Found 5 out of 7 words in the pre-training embedding.\n",
- "50\n",
- "Start constructing character vocabulary.\n",
- "In total, there are 8 distinct characters.\n",
- "30\n",
- "22 out of 22 characters were found in pretrained elmo embedding.\n",
- "256\n",
- "22 out of 22 characters were found in pretrained elmo embedding.\n",
- "512\n",
- "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
- "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
- "Start to generate word pieces for word.\n",
- "Found(Or segment into word pieces) 7 words out of 7.\n",
- "768\n",
- "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
- "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
- "Start to generate word pieces for word.\n",
- "Found(Or segment into word pieces) 7 words out of 7.\n",
- "1536\n",
- "80\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import *\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.add_word_lst(\"this is a demo .\".split())\n",
- "\n",
- "static_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n",
- "print(static_embed.embedding_dim) # 50\n",
- "char_embed = CNNCharEmbedding(vocab, embed_size=30)\n",
- "print(char_embed.embedding_dim) # 30\n",
- "elmo_embed_1 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='2')\n",
- "print(elmo_embed_1.embedding_dim) # 256\n",
- "elmo_embed_2 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='1,2')\n",
- "print(elmo_embed_2.embedding_dim) # 512\n",
- "bert_embed_1 = BertEmbedding(vocab, layers='-1', model_dir_or_name='en-base-cased')\n",
- "print(bert_embed_1.embedding_dim) # 768\n",
- "bert_embed_2 = BertEmbedding(vocab, layers='2,-1', model_dir_or_name='en-base-cased')\n",
- "print(bert_embed_2.embedding_dim) # 1536\n",
- "stack_embed = StackEmbedding([static_embed, char_embed])\n",
- "print(stack_embed.embedding_dim) # 80"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
- "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
- "Start to generate word pieces for word.\n",
- "Found(Or segment into word pieces) 7 words out of 7.\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import *\n",
- "\n",
- "vocab = Vocabulary()\n",
- "vocab.add_word_lst(\"this is a demo .\".split())\n",
- "\n",
- "embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', requires_grad=True) # 初始化时设定为需要更新\n",
- "embed.requires_grad = False # 修改BertEmbedding的权重为不更新"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([[ 0.3633, -0.2091, -0.0353, -0.3771, -0.5193]],\n",
- " grad_fn=)\n",
- "tensor([[ 0.0926, -0.4812, -0.7744, 0.4836, -0.5475]],\n",
- " grad_fn=)\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import StaticEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary().add_word_lst(\"The the a A\".split())\n",
- "# 下面用随机的StaticEmbedding演示,但与使用预训练词向量时效果是一致的\n",
- "embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5)\n",
- "print(embed(torch.LongTensor([vocab.to_index('The')])))\n",
- "print(embed(torch.LongTensor([vocab.to_index('the')])))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "All word in the vocab have been lowered. There are 6 words, 4 unique lowered words.\n",
- "tensor([[ 0.4530, -0.1558, -0.1941, 0.3203, 0.0355]],\n",
- " grad_fn=)\n",
- "tensor([[ 0.4530, -0.1558, -0.1941, 0.3203, 0.0355]],\n",
- " grad_fn=)\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import StaticEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary().add_word_lst(\"The the a A\".split())\n",
- "# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的\n",
- "embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, lower=True)\n",
- "print(embed(torch.LongTensor([vocab.to_index('The')])))\n",
- "print(embed(torch.LongTensor([vocab.to_index('the')])))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "1 out of 4 words have frequency less than 2.\n",
- "tensor([[ 0.4724, -0.7277, -0.6350, -0.5258, -0.6063]],\n",
- " grad_fn=)\n",
- "tensor([[ 0.7638, -0.0552, 0.1625, -0.2210, 0.4993]],\n",
- " grad_fn=)\n",
- "tensor([[ 0.7638, -0.0552, 0.1625, -0.2210, 0.4993]],\n",
- " grad_fn=)\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import StaticEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary().add_word_lst(\"the the the a\".split())\n",
- "# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的\n",
- "embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, min_freq=2)\n",
- "print(embed(torch.LongTensor([vocab.to_index('the')])))\n",
- "print(embed(torch.LongTensor([vocab.to_index('a')])))\n",
- "print(embed(torch.LongTensor([vocab.unknown_idx])))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0 out of 5 words have frequency less than 2.\n",
- "All word in the vocab have been lowered. There are 5 words, 4 unique lowered words.\n",
- "tensor([[ 0.1943, 0.3739, 0.2769, -0.4746, -0.3181]],\n",
- " grad_fn=)\n",
- "tensor([[ 0.5892, -0.6916, 0.7319, -0.3803, 0.4979]],\n",
- " grad_fn=)\n",
- "tensor([[ 0.5892, -0.6916, 0.7319, -0.3803, 0.4979]],\n",
- " grad_fn=)\n",
- "tensor([[-0.1348, -0.2172, -0.0071, 0.5704, -0.2607]],\n",
- " grad_fn=)\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import StaticEmbedding\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "vocab = Vocabulary().add_word_lst(\"the the the a A\".split())\n",
- "# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的\n",
- "embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, min_freq=2, lower=True)\n",
- "print(embed(torch.LongTensor([vocab.to_index('the')])))\n",
- "print(embed(torch.LongTensor([vocab.to_index('a')])))\n",
- "print(embed(torch.LongTensor([vocab.to_index('A')])))\n",
- "print(embed(torch.LongTensor([vocab.unknown_idx])))"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python Now",
- "language": "python",
- "name": "now"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_static/notebooks/tutorial_4_load_dataset.ipynb b/docs/source/_static/notebooks/tutorial_4_load_dataset.ipynb
deleted file mode 100644
index f6de83bc..00000000
--- a/docs/source/_static/notebooks/tutorial_4_load_dataset.ipynb
+++ /dev/null
@@ -1,309 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 使用Loader和Pipe加载并处理数据集\n",
- "\n",
- "这一部分是关于如何加载数据集的教程\n",
- "\n",
- "## Part I: 数据集容器DataBundle\n",
- "\n",
- "而由于对于同一个任务,训练集,验证集和测试集会共用同一个词表以及具有相同的目标值,所以在fastNLP中我们使用了 DataBundle 来承载同一个任务的多个数据集 DataSet 以及它们的词表 Vocabulary 。下面会有例子介绍 DataBundle 的相关使用。\n",
- "\n",
- "DataBundle 在fastNLP中主要在各个 Loader 和 Pipe 中被使用。 下面我们先介绍一下 Loader 和 Pipe 。\n",
- "\n",
- "## Part II: 加载的各种数据集的Loader\n",
- "\n",
- "在fastNLP中,所有的 Loader 都可以通过其文档判断其支持读取的数据格式,以及读取之后返回的 DataSet 的格式, 例如 ChnSentiCorpLoader \n",
- "\n",
- "- download() 函数:自动将该数据集下载到缓存地址,默认缓存地址为~/.fastNLP/datasets/。由于版权等原因,不是所有的Loader都实现了该方法。该方法会返回下载后文件所处的缓存地址。\n",
- "\n",
- "- _load() 函数:从一个数据文件中读取数据,返回一个 DataSet 。返回的DataSet的格式可从Loader文档判断。\n",
- "\n",
- "- load() 函数:从文件或者文件夹中读取数据为 DataSet 并将它们组装成 DataBundle。支持接受的参数类型有以下的几种\n",
- "\n",
- " - None, 将尝试读取自动缓存的数据,仅支持提供了自动下载数据的Loader\n",
- " - 文件夹路径, 默认将尝试在该文件夹下匹配文件名中含有 train , test , dev 的文件,如果有多个文件含有相同的关键字,将无法通过该方式读取\n",
- " - dict, 例如{'train':\"/path/to/tr.conll\", 'dev':\"/to/validate.conll\", \"test\":\"/to/te.conll\"}。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "In total 3 datasets:\n",
- "\ttest has 1944 instances.\n",
- "\ttrain has 17196 instances.\n",
- "\tdev has 1858 instances.\n",
- "\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.io import CWSLoader\n",
- "\n",
- "loader = CWSLoader(dataset_name='pku')\n",
- "data_bundle = loader.load()\n",
- "print(data_bundle)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "这里表示一共有3个数据集。其中:\n",
- "\n",
- " 3个数据集的名称分别为train、dev、test,分别有17223、1831、1944个instance\n",
- "\n",
- "也可以取出DataSet,并打印DataSet中的具体内容"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+----------------------------------------------------------------+\n",
- "| raw_words |\n",
- "+----------------------------------------------------------------+\n",
- "| 迈向 充满 希望 的 新 世纪 —— 一九九八年 新年 讲话 ... |\n",
- "| 中共中央 总书记 、 国家 主席 江 泽民 |\n",
- "+----------------------------------------------------------------+\n"
- ]
- }
- ],
- "source": [
- "tr_data = data_bundle.get_dataset('train')\n",
- "print(tr_data[:2])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Part III: 使用Pipe对数据集进行预处理\n",
- "\n",
- "通过 Loader 可以将文本数据读入,但并不能直接被神经网络使用,还需要进行一定的预处理。\n",
- "\n",
- "在fastNLP中,我们使用 Pipe 的子类作为数据预处理的类, Loader 和 Pipe 一般具备一一对应的关系,该关系可以从其名称判断, 例如 CWSLoader 与 CWSPipe 是一一对应的。一般情况下Pipe处理包含以下的几个过程,\n",
- "1. 将raw_words或 raw_chars进行tokenize以切分成不同的词或字; \n",
- "2. 再建立词或字的 Vocabulary , 并将词或字转换为index; \n",
- "3. 将target 列建立词表并将target列转为index;\n",
- "\n",
- "所有的Pipe都可通过其文档查看该Pipe支持处理的 DataSet 以及返回的 DataBundle 中的Vocabulary的情况; 如 OntoNotesNERPipe\n",
- "\n",
- "各种数据集的Pipe当中,都包含了以下的两个函数:\n",
- "\n",
- "- process() 函数:对输入的 DataBundle 进行处理, 然后返回处理之后的 DataBundle 。process函数的文档中包含了该Pipe支持处理的DataSet的格式。\n",
- "- process_from_file() 函数:输入数据集所在文件夹,使用对应的Loader读取数据(所以该函数支持的参数类型是由于其对应的Loader的load函数决定的),然后调用相对应的process函数对数据进行预处理。相当于是把Load和process放在一个函数中执行。\n",
- "\n",
- "接着上面 CWSLoader 的例子,我们展示一下 CWSPipe 的功能:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "In total 3 datasets:\n",
- "\ttest has 1944 instances.\n",
- "\ttrain has 17196 instances.\n",
- "\tdev has 1858 instances.\n",
- "In total 2 vocabs:\n",
- "\tchars has 4777 entries.\n",
- "\ttarget has 4 entries.\n",
- "\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.io import CWSPipe\n",
- "\n",
- "data_bundle = CWSPipe().process(data_bundle)\n",
- "print(data_bundle)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "表示一共有3个数据集和2个词表。其中:\n",
- "\n",
- "- 3个数据集的名称分别为train、dev、test,分别有17223、1831、1944个instance\n",
- "- 2个词表分别为chars词表与target词表。其中chars词表为句子文本所构建的词表,一共有4777个不同的字;target词表为目标标签所构建的词表,一共有4种标签。\n",
- "\n",
- "相较于之前CWSLoader读取的DataBundle,新增了两个Vocabulary。 我们可以打印一下处理之后的DataSet"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+---------------------+---------------------+---------------------+---------+\n",
- "| raw_words | chars | target | seq_len |\n",
- "+---------------------+---------------------+---------------------+---------+\n",
- "| 迈向 充满 希望... | [1224, 178, 674,... | [0, 1, 0, 1, 0, ... | 29 |\n",
- "| 中共中央 总书记... | [11, 212, 11, 33... | [0, 3, 3, 1, 0, ... | 15 |\n",
- "+---------------------+---------------------+---------------------+---------+\n"
- ]
- }
- ],
- "source": [
- "tr_data = data_bundle.get_dataset('train')\n",
- "print(tr_data[:2])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "可以看到有两列为int的field: chars和target。这两列的名称同时也是DataBundle中的Vocabulary的名称。可以通过下列的代码获取并查看Vocabulary的 信息"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Vocabulary(['B', 'E', 'S', 'M']...)\n"
- ]
- }
- ],
- "source": [
- "vocab = data_bundle.get_vocab('target')\n",
- "print(vocab)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Part IV: fastNLP封装好的Loader和Pipe\n",
- "\n",
- "fastNLP封装了多种任务/数据集的 Loader 和 Pipe 并提供自动下载功能,具体参见文档 [数据集](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0)\n",
- "\n",
- "## Part V: 不同格式类型的基础Loader\n",
- "\n",
- "除了上面提到的针对具体任务的Loader,我们还提供了CSV格式和JSON格式的Loader\n",
- "\n",
- "**CSVLoader** 读取CSV类型的数据集文件。例子如下:\n",
- "\n",
- "```python\n",
- "from fastNLP.io.loader import CSVLoader\n",
- "data_set_loader = CSVLoader(\n",
- " headers=('raw_words', 'target'), sep='\\t'\n",
- ")\n",
- "```\n",
- "\n",
- "表示将CSV文件中每一行的第一项将填入'raw_words' field,第二项填入'target' field。其中项之间由'\\t'分割开来\n",
- "\n",
- "```python\n",
- "data_set = data_set_loader._load('path/to/your/file')\n",
- "```\n",
- "\n",
- "文件内容样例如下\n",
- "\n",
- "```csv\n",
- "But it does not leave you with much . 1\n",
- "You could hate it for the same reason . 1\n",
- "The performances are an absolute joy . 4\n",
- "```\n",
- "\n",
- "读取之后的DataSet具有以下的field\n",
- "\n",
- "| raw_words | target |\n",
- "| --------------------------------------- | ------ |\n",
- "| But it does not leave you with much . | 1 |\n",
- "| You could hate it for the same reason . | 1 |\n",
- "| The performances are an absolute joy . | 4 |\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**JsonLoader** 读取Json类型的数据集文件,数据必须按行存储,每行是一个包含各类属性的Json对象。例子如下\n",
- "\n",
- "```python\n",
- "from fastNLP.io.loader import JsonLoader\n",
- "loader = JsonLoader(\n",
- " fields={'sentence1': 'raw_words1', 'sentence2': 'raw_words2', 'gold_label': 'target'}\n",
- ")\n",
- "```\n",
- "\n",
- "表示将Json对象中'sentence1'、'sentence2'和'gold_label'对应的值赋给'raw_words1'、'raw_words2'、'target'这三个fields\n",
- "\n",
- "```python\n",
- "data_set = loader._load('path/to/your/file')\n",
- "```\n",
- "\n",
- "数据集内容样例如下\n",
- "```\n",
- "{\"annotator_labels\": [\"neutral\"], \"captionID\": \"3416050480.jpg#4\", \"gold_label\": \"neutral\", ... }\n",
- "{\"annotator_labels\": [\"contradiction\"], \"captionID\": \"3416050480.jpg#4\", \"gold_label\": \"contradiction\", ... }\n",
- "{\"annotator_labels\": [\"entailment\"], \"captionID\": \"3416050480.jpg#4\", \"gold_label\": \"entailment\", ... }\n",
- "```\n",
- "\n",
- "读取之后的DataSet具有以下的field\n",
- "\n",
- "| raw_words0 | raw_words1 | target |\n",
- "| ------------------------------------------------------ | ------------------------------------------------- | ------------- |\n",
- "| A person on a horse jumps over a broken down airplane. | A person is training his horse for a competition. | neutral |\n",
- "| A person on a horse jumps over a broken down airplane. | A person is at a diner, ordering an omelette. | contradiction |\n",
- "| A person on a horse jumps over a broken down airplane. | A person is outdoors, on a horse. | entailment |"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python Now",
- "language": "python",
- "name": "now"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_static/notebooks/tutorial_5_loss_optimizer.ipynb b/docs/source/_static/notebooks/tutorial_5_loss_optimizer.ipynb
deleted file mode 100644
index cba78175..00000000
--- a/docs/source/_static/notebooks/tutorial_5_loss_optimizer.ipynb
+++ /dev/null
@@ -1,603 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 使用Trainer和Tester快速训练和测试"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 数据读入和处理"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "In total 3 datasets:\n",
- "\ttest has 1821 instances.\n",
- "\ttrain has 67349 instances.\n",
- "\tdev has 872 instances.\n",
- "In total 2 vocabs:\n",
- "\twords has 16292 entries.\n",
- "\ttarget has 2 entries.\n",
- "\n",
- "+-----------------------------------+--------+-----------------------------------+---------+\n",
- "| raw_words | target | words | seq_len |\n",
- "+-----------------------------------+--------+-----------------------------------+---------+\n",
- "| hide new secretions from the p... | 1 | [4110, 97, 12009, 39, 2, 6843,... | 7 |\n",
- "+-----------------------------------+--------+-----------------------------------+---------+\n",
- "Vocabulary(['hide', 'new', 'secretions', 'from', 'the']...)\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.io import SST2Pipe\n",
- "\n",
- "pipe = SST2Pipe()\n",
- "databundle = pipe.process_from_file()\n",
- "vocab = databundle.get_vocab('words')\n",
- "print(databundle)\n",
- "print(databundle.get_dataset('train')[0])\n",
- "print(databundle.get_vocab('words'))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "4925 872 75\n"
- ]
- }
- ],
- "source": [
- "train_data = databundle.get_dataset('train')[:5000]\n",
- "train_data, test_data = train_data.split(0.015)\n",
- "dev_data = databundle.get_dataset('dev')\n",
- "print(len(train_data),len(dev_data),len(test_data))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+-------------+-----------+--------+-------+---------+\n",
- "| field_names | raw_words | target | words | seq_len |\n",
- "+-------------+-----------+--------+-------+---------+\n",
- "| is_input | False | False | True | True |\n",
- "| is_target | False | True | False | False |\n",
- "| ignore_type | | False | False | False |\n",
- "| pad_value | | 0 | 0 | 0 |\n",
- "+-------------+-----------+--------+-------+---------+\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "train_data.print_field_meta()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 使用内置模型训练"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.models import CNNText\n",
- "\n",
- "#词嵌入的维度\n",
- "EMBED_DIM = 100\n",
- "\n",
- "#使用CNNText的时候第一个参数输入一个tuple,作为模型定义embedding的参数\n",
- "#还可以传入 kernel_nums, kernel_sizes, padding, dropout的自定义值\n",
- "model_cnn = CNNText((len(vocab),EMBED_DIM), num_classes=2, dropout=0.1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP import AccuracyMetric\n",
- "from fastNLP import Const\n",
- "\n",
- "# metrics=AccuracyMetric() 在本例中与下面这行代码等价\n",
- "metrics=AccuracyMetric(pred=Const.OUTPUT, target=Const.TARGET)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP import CrossEntropyLoss\n",
- "\n",
- "# loss = CrossEntropyLoss() 在本例中与下面这行代码等价\n",
- "loss = CrossEntropyLoss(pred=Const.OUTPUT, target=Const.TARGET)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 这表示构建了一个损失函数类,由func计算损失函数,其中将从模型返回值或者DataSet的target=True的field\n",
- "# 当中找到一个参数名为`pred`的参数传入func一个参数名为`input`的参数;找到一个参数名为`label`的参数\n",
- "# 传入func作为一个名为`target`的参数\n",
- "#下面自己构建了一个交叉熵函数,和之后直接使用fastNLP中的交叉熵函数是一个效果\n",
- "import torch\n",
- "from fastNLP import LossFunc\n",
- "func = torch.nn.functional.cross_entropy\n",
- "loss_func = LossFunc(func, input=Const.OUTPUT, target=Const.TARGET)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch.optim as optim\n",
- "\n",
- "#使用 torch.optim 定义优化器\n",
- "optimizer=optim.RMSprop(model_cnn.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input fields after batch(if batch size is 2):\n",
- "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "target fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\n",
- "training epochs started 2020-02-27-11-31-25\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=3080.0), HTML(value='')), layout=Layout(d…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.75 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 1/10. Step:308/3080: \n",
- "\r",
- "AccuracyMetric: acc=0.751147\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.83 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 2/10. Step:616/3080: \n",
- "\r",
- "AccuracyMetric: acc=0.755734\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 1.32 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 3/10. Step:924/3080: \n",
- "\r",
- "AccuracyMetric: acc=0.758028\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.88 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 4/10. Step:1232/3080: \n",
- "\r",
- "AccuracyMetric: acc=0.741972\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.96 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 5/10. Step:1540/3080: \n",
- "\r",
- "AccuracyMetric: acc=0.728211\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.87 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 6/10. Step:1848/3080: \n",
- "\r",
- "AccuracyMetric: acc=0.755734\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 1.04 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 7/10. Step:2156/3080: \n",
- "\r",
- "AccuracyMetric: acc=0.732798\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.57 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 8/10. Step:2464/3080: \n",
- "\r",
- "AccuracyMetric: acc=0.747706\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.48 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 9/10. Step:2772/3080: \n",
- "\r",
- "AccuracyMetric: acc=0.732798\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.48 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 10/10. Step:3080/3080: \n",
- "\r",
- "AccuracyMetric: acc=0.740826\n",
- "\n",
- "\r\n",
- "In Epoch:3/Step:924, got best dev performance:\n",
- "AccuracyMetric: acc=0.758028\n",
- "Reloaded the best model.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'best_eval': {'AccuracyMetric': {'acc': 0.758028}},\n",
- " 'best_epoch': 3,\n",
- " 'best_step': 924,\n",
- " 'seconds': 160.58}"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import Trainer\n",
- "\n",
- "#训练的轮数和batch size\n",
- "N_EPOCHS = 10\n",
- "BATCH_SIZE = 16\n",
- "\n",
- "#如果在定义trainer的时候没有传入optimizer参数,模型默认的优化器为torch.optim.Adam且learning rate为lr=4e-3\n",
- "#这里只使用了loss作为损失函数输入,感兴趣可以尝试其他损失函数(如之前自定义的loss_func)作为输入\n",
- "trainer = Trainer(model=model_cnn, train_data=train_data, dev_data=dev_data, loss=loss, metrics=metrics,\n",
- "optimizer=optimizer,n_epochs=N_EPOCHS, batch_size=BATCH_SIZE)\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.43 seconds!\n",
- "[tester] \n",
- "AccuracyMetric: acc=0.773333\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'AccuracyMetric': {'acc': 0.773333}}"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import Tester\n",
- "\n",
- "tester = Tester(test_data, model_cnn, metrics=AccuracyMetric())\n",
- "tester.test()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python Now",
- "language": "python",
- "name": "now"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_static/notebooks/tutorial_6_datasetiter.ipynb b/docs/source/_static/notebooks/tutorial_6_datasetiter.ipynb
deleted file mode 100644
index 2caa4cc2..00000000
--- a/docs/source/_static/notebooks/tutorial_6_datasetiter.ipynb
+++ /dev/null
@@ -1,681 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 使用Trainer和Tester快速训练和测试"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 数据读入和处理"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "In total 3 datasets:\n",
- "\ttest has 1821 instances.\n",
- "\ttrain has 67349 instances.\n",
- "\tdev has 872 instances.\n",
- "In total 2 vocabs:\n",
- "\twords has 16292 entries.\n",
- "\ttarget has 2 entries.\n",
- "\n",
- "+-----------------------------------+--------+-----------------------------------+---------+\n",
- "| raw_words | target | words | seq_len |\n",
- "+-----------------------------------+--------+-----------------------------------+---------+\n",
- "| hide new secretions from the p... | 1 | [4110, 97, 12009, 39, 2, 6843,... | 7 |\n",
- "+-----------------------------------+--------+-----------------------------------+---------+\n",
- "Vocabulary(['hide', 'new', 'secretions', 'from', 'the']...)\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.io import SST2Pipe\n",
- "\n",
- "pipe = SST2Pipe()\n",
- "databundle = pipe.process_from_file()\n",
- "vocab = databundle.get_vocab('words')\n",
- "print(databundle)\n",
- "print(databundle.get_dataset('train')[0])\n",
- "print(databundle.get_vocab('words'))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "4925 872 75\n"
- ]
- }
- ],
- "source": [
- "train_data = databundle.get_dataset('train')[:5000]\n",
- "train_data, test_data = train_data.split(0.015)\n",
- "dev_data = databundle.get_dataset('dev')\n",
- "print(len(train_data),len(dev_data),len(test_data))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+-------------+-----------+--------+-------+---------+\n",
- "| field_names | raw_words | target | words | seq_len |\n",
- "+-------------+-----------+--------+-------+---------+\n",
- "| is_input | False | False | True | True |\n",
- "| is_target | False | True | False | False |\n",
- "| ignore_type | | False | False | False |\n",
- "| pad_value | | 0 | 0 | 0 |\n",
- "+-------------+-----------+--------+-------+---------+\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "train_data.print_field_meta()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP import AccuracyMetric\n",
- "from fastNLP import Const\n",
- "\n",
- "# metrics=AccuracyMetric() 在本例中与下面这行代码等价\n",
- "metrics=AccuracyMetric(pred=Const.OUTPUT, target=Const.TARGET)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## DataSetIter初探"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n",
- " 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n",
- " 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n",
- " 1323, 4398, 7],\n",
- " [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n",
- " 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n",
- " 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0]]), 'seq_len': tensor([33, 21])}\n",
- "batch_y: {'target': tensor([1, 0])}\n",
- "batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7],\n",
- " [ 14, 10, 437, 32, 78, 3, 78, 437, 7]]), 'seq_len': tensor([9, 9])}\n",
- "batch_y: {'target': tensor([0, 1])}\n",
- "batch_x: {'words': tensor([[ 4, 277, 685, 18, 7],\n",
- " [15618, 3204, 5, 1675, 0]]), 'seq_len': tensor([5, 4])}\n",
- "batch_y: {'target': tensor([1, 1])}\n",
- "batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n",
- " 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7],\n",
- " [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n",
- " 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7]]), 'seq_len': tensor([20, 20])}\n",
- "batch_y: {'target': tensor([0, 0])}\n",
- "batch_x: {'words': tensor([[ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n",
- " 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7],\n",
- " [ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n",
- " 1217, 7, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([20, 12])}\n",
- "batch_y: {'target': tensor([0, 1])}\n"
- ]
- }
- ],
- "source": [
- "from fastNLP import BucketSampler\n",
- "from fastNLP import DataSetIter\n",
- "\n",
- "tmp_data = dev_data[:10]\n",
- "# 定义一个Batch,传入DataSet,规定batch_size和去batch的规则。\n",
- "# 顺序(Sequential),随机(Random),相似长度组成一个batch(Bucket)\n",
- "sampler = BucketSampler(batch_size=2, seq_len_field_name='seq_len')\n",
- "batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n",
- "for batch_x, batch_y in batch:\n",
- " print(\"batch_x: \",batch_x)\n",
- " print(\"batch_y: \", batch_y)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n",
- " 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n",
- " 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n",
- " 1323, 4398, 7],\n",
- " [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n",
- " 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n",
- " 7, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
- " -1, -1, -1]]), 'seq_len': tensor([33, 21])}\n",
- "batch_y: {'target': tensor([1, 0])}\n",
- "batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7],\n",
- " [ 14, 10, 437, 32, 78, 3, 78, 437, 7]]), 'seq_len': tensor([9, 9])}\n",
- "batch_y: {'target': tensor([0, 1])}\n",
- "batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n",
- " 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7],\n",
- " [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n",
- " 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7]]), 'seq_len': tensor([20, 20])}\n",
- "batch_y: {'target': tensor([0, 0])}\n",
- "batch_x: {'words': tensor([[ 4, 277, 685, 18, 7],\n",
- " [15618, 3204, 5, 1675, -1]]), 'seq_len': tensor([5, 4])}\n",
- "batch_y: {'target': tensor([1, 1])}\n",
- "batch_x: {'words': tensor([[ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n",
- " 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7],\n",
- " [ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n",
- " 1217, 7, -1, -1, -1, -1, -1, -1, -1, -1]]), 'seq_len': tensor([20, 12])}\n",
- "batch_y: {'target': tensor([0, 1])}\n"
- ]
- }
- ],
- "source": [
- "tmp_data.set_pad_val('words',-1)\n",
- "batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n",
- "for batch_x, batch_y in batch:\n",
- " print(\"batch_x: \",batch_x)\n",
- " print(\"batch_y: \", batch_y)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "batch_x: {'words': tensor([[ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n",
- " 1217, 7, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
- " [ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n",
- " 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([12, 20])}\n",
- "batch_y: {'target': tensor([1, 0])}\n",
- "batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n",
- " 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n",
- " 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n",
- " 1323, 4398, 7, 0, 0, 0, 0, 0, 0, 0],\n",
- " [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n",
- " 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n",
- " 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([33, 21])}\n",
- "batch_y: {'target': tensor([1, 0])}\n",
- "batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0],\n",
- " [ 14, 10, 437, 32, 78, 3, 78, 437, 7, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0]]), 'seq_len': tensor([9, 9])}\n",
- "batch_y: {'target': tensor([0, 1])}\n",
- "batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n",
- " 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
- " [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n",
- " 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([20, 20])}\n",
- "batch_y: {'target': tensor([0, 0])}\n",
- "batch_x: {'words': tensor([[ 4, 277, 685, 18, 7, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
- " [15618, 3204, 5, 1675, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
- " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([5, 4])}\n",
- "batch_y: {'target': tensor([1, 1])}\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.core.field import Padder\n",
- "import numpy as np\n",
- "class FixLengthPadder(Padder):\n",
- " def __init__(self, pad_val=0, length=None):\n",
- " super().__init__(pad_val=pad_val)\n",
- " self.length = length\n",
- " assert self.length is not None, \"Creating FixLengthPadder with no specific length!\"\n",
- "\n",
- " def __call__(self, contents, field_name, field_ele_dtype, dim):\n",
- " #计算当前contents中的最大长度\n",
- " max_len = max(map(len, contents))\n",
- " #如果当前contents中的最大长度大于指定的padder length的话就报错\n",
- " assert max_len <= self.length, \"Fixed padder length smaller than actual length! with length {}\".format(max_len)\n",
- " array = np.full((len(contents), self.length), self.pad_val, dtype=field_ele_dtype)\n",
- " for i, content_i in enumerate(contents):\n",
- " array[i, :len(content_i)] = content_i\n",
- " return array\n",
- "\n",
- "#设定FixLengthPadder的固定长度为40\n",
- "tmp_padder = FixLengthPadder(pad_val=0,length=40)\n",
- "#利用dataset的set_padder函数设定words field的padder\n",
- "tmp_data.set_padder('words',tmp_padder)\n",
- "batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n",
- "for batch_x, batch_y in batch:\n",
- " print(\"batch_x: \",batch_x)\n",
- " print(\"batch_y: \", batch_y)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 使用DataSetIter自己编写训练过程\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "-----start training-----\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 2.68 seconds!\n",
- "Epoch 0 Avg Loss: 0.66 AccuracyMetric: acc=0.708716 29307ms\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.38 seconds!\n",
- "Epoch 1 Avg Loss: 0.41 AccuracyMetric: acc=0.770642 52200ms\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.51 seconds!\n",
- "Epoch 2 Avg Loss: 0.16 AccuracyMetric: acc=0.747706 70268ms\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.96 seconds!\n",
- "Epoch 3 Avg Loss: 0.06 AccuracyMetric: acc=0.741972 90349ms\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 1.04 seconds!\n",
- "Epoch 4 Avg Loss: 0.03 AccuracyMetric: acc=0.740826 114250ms\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.8 seconds!\n",
- "Epoch 5 Avg Loss: 0.02 AccuracyMetric: acc=0.738532 134742ms\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.65 seconds!\n",
- "Epoch 6 Avg Loss: 0.01 AccuracyMetric: acc=0.731651 154503ms\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.8 seconds!\n",
- "Epoch 7 Avg Loss: 0.01 AccuracyMetric: acc=0.738532 175397ms\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.36 seconds!\n",
- "Epoch 8 Avg Loss: 0.01 AccuracyMetric: acc=0.733945 192384ms\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.84 seconds!\n",
- "Epoch 9 Avg Loss: 0.01 AccuracyMetric: acc=0.744266 214417ms\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.04 seconds!\n",
- "[tester] \n",
- "AccuracyMetric: acc=0.786667\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'AccuracyMetric': {'acc': 0.786667}}"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import BucketSampler\n",
- "from fastNLP import DataSetIter\n",
- "from fastNLP.models import CNNText\n",
- "from fastNLP import Tester\n",
- "import torch\n",
- "import time\n",
- "\n",
- "embed_dim = 100\n",
- "model = CNNText((len(vocab),embed_dim), num_classes=2, dropout=0.1)\n",
- "\n",
- "def train(epoch, data, devdata):\n",
- " optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
- " lossfunc = torch.nn.CrossEntropyLoss()\n",
- " batch_size = 32\n",
- "\n",
- " # 定义一个Batch,传入DataSet,规定batch_size和去batch的规则。\n",
- " # 顺序(Sequential),随机(Random),相似长度组成一个batch(Bucket)\n",
- " train_sampler = BucketSampler(batch_size=batch_size, seq_len_field_name='seq_len')\n",
- " train_batch = DataSetIter(batch_size=batch_size, dataset=data, sampler=train_sampler)\n",
- "\n",
- " start_time = time.time()\n",
- " print(\"-\"*5+\"start training\"+\"-\"*5)\n",
- " for i in range(epoch):\n",
- " loss_list = []\n",
- " for batch_x, batch_y in train_batch:\n",
- " optimizer.zero_grad()\n",
- " output = model(batch_x['words'])\n",
- " loss = lossfunc(output['pred'], batch_y['target'])\n",
- " loss.backward()\n",
- " optimizer.step()\n",
- " loss_list.append(loss.item())\n",
- "\n",
- " #这里verbose如果为0,在调用Tester对象的test()函数时不输出任何信息,返回评估信息; 如果为1,打印出验证结果,返回评估信息\n",
- " #在调用过Tester对象的test()函数后,调用其_format_eval_results(res)函数,结构化输出验证结果\n",
- " tester_tmp = Tester(devdata, model, metrics=AccuracyMetric(), verbose=0)\n",
- " res=tester_tmp.test()\n",
- "\n",
- " print('Epoch {:d} Avg Loss: {:.2f}'.format(i, sum(loss_list) / len(loss_list)),end=\" \")\n",
- " print(tester_tmp._format_eval_results(res),end=\" \")\n",
- " print('{:d}ms'.format(round((time.time()-start_time)*1000)))\n",
- " loss_list.clear()\n",
- "\n",
- "train(10, train_data, dev_data)\n",
- "#使用tester进行快速测试\n",
- "tester = Tester(test_data, model, metrics=AccuracyMetric())\n",
- "tester.test()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python Now",
- "language": "python",
- "name": "now"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_static/notebooks/tutorial_7_metrics.ipynb b/docs/source/_static/notebooks/tutorial_7_metrics.ipynb
deleted file mode 100644
index ef791683..00000000
--- a/docs/source/_static/notebooks/tutorial_7_metrics.ipynb
+++ /dev/null
@@ -1,1206 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 使用Metric快速评测你的模型\n",
- "\n",
- "和上一篇教程一样的实验准备代码"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.io import SST2Pipe\n",
- "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
- "from fastNLP.models import CNNText\n",
- "import torch\n",
- "\n",
- "databundle = SST2Pipe().process_from_file()\n",
- "vocab = databundle.get_vocab('words')\n",
- "train_data = databundle.get_dataset('train')[:5000]\n",
- "train_data, test_data = train_data.split(0.015)\n",
- "dev_data = databundle.get_dataset('dev')\n",
- "\n",
- "model = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n",
- "loss = CrossEntropyLoss()\n",
- "metric = AccuracyMetric()\n",
- "device = 0 if torch.cuda.is_available() else 'cpu'"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "进行训练时,fastNLP提供了各种各样的 metrics 。 如前面的教程中所介绍,AccuracyMetric 类的对象被直接传到 Trainer 中用于训练"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input fields after batch(if batch size is 2):\n",
- "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "target fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\n",
- "training epochs started 2020-02-28-00-37-08\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.28 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.747706\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.17 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.745413\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.19 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.74656\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.15 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.762615\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.42 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.736239\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.16 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.761468\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.42 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.727064\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.21 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.731651\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.52 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.752294\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.44 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.760321\n",
- "\n",
- "\r\n",
- "In Epoch:4/Step:616, got best dev performance:\n",
- "AccuracyMetric: acc=0.762615\n",
- "Reloaded the best model.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'best_eval': {'AccuracyMetric': {'acc': 0.762615}},\n",
- " 'best_epoch': 4,\n",
- " 'best_step': 616,\n",
- " 'seconds': 32.63}"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
- " loss=loss, device=device, metrics=metric)\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "除了 AccuracyMetric 之外,SpanFPreRecMetric 也是一种非常见的评价指标, 例如在序列标注问题中,常以span的方式计算 F-measure, precision, recall。\n",
- "\n",
- "另外,fastNLP 还实现了用于抽取式QA(如SQuAD)的metric ExtractiveQAMetric。 用户可以参考下面这个表格。\n",
- "\n",
- "| 名称 | 介绍 |\n",
- "| -------------------- | ------------------------------------------------- |\n",
- "| `MetricBase` | 自定义metrics需继承的基类 |\n",
- "| `AccuracyMetric` | 简单的正确率metric |\n",
- "| `SpanFPreRecMetric` | 同时计算 F-measure, precision, recall 值的 metric |\n",
- "| `ExtractiveQAMetric` | 用于抽取式QA任务 的metric |\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 定义自己的metrics\n",
- "\n",
- "在定义自己的metrics类时需继承 fastNLP 的 MetricBase, 并覆盖写入 evaluate 和 get_metric 方法。\n",
- "\n",
- "- evaluate(xxx) 中传入一个批次的数据,将针对一个批次的预测结果做评价指标的累计\n",
- "\n",
- "- get_metric(xxx) 当所有数据处理完毕时调用该方法,它将根据 evaluate函数累计的评价指标统计量来计算最终的评价结果\n",
- "\n",
- "以分类问题中,Accuracy计算为例,假设model的forward返回dict中包含 pred 这个key, 并且该key需要用于Accuracy:\n",
- "\n",
- "```python\n",
- "class Model(nn.Module):\n",
- " def __init__(xxx):\n",
- " # do something\n",
- " def forward(self, xxx):\n",
- " # do something\n",
- " return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Version 1\n",
- "\n",
- "假设dataset中 `target` 这个 field 是需要预测的值,并且该 field 被设置为了 target 对应的 `AccMetric` 可以按如下的定义"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP import MetricBase\n",
- "\n",
- "class AccMetric(MetricBase):\n",
- "\n",
- " def __init__(self):\n",
- " super().__init__()\n",
- " # 根据你的情况自定义指标\n",
- " self.total = 0\n",
- " self.acc_count = 0\n",
- "\n",
- " # evaluate的参数需要和DataSet 中 field 名以及模型输出的结果 field 名一致,不然找不到对应的value\n",
- " # pred, target 的参数是 fastNLP 的默认配置\n",
- " def evaluate(self, pred, target):\n",
- " # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric\n",
- " self.total += target.size(0)\n",
- " self.acc_count += target.eq(pred).sum().item()\n",
- "\n",
- " def get_metric(self, reset=True): # 在这里定义如何计算metric\n",
- " acc = self.acc_count/self.total\n",
- " if reset: # 是否清零以便重新计算\n",
- " self.acc_count = 0\n",
- " self.total = 0\n",
- " return {'acc': acc}\n",
- " # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input fields after batch(if batch size is 2):\n",
- "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "target fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\n",
- "training epochs started 2020-02-28-00-37-41\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.27 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
- "\r",
- "AccMetric: acc=0.7431192660550459\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.42 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
- "\r",
- "AccMetric: acc=0.7522935779816514\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.51 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
- "\r",
- "AccMetric: acc=0.7477064220183486\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.48 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
- "\r",
- "AccMetric: acc=0.7442660550458715\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.5 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
- "\r",
- "AccMetric: acc=0.7362385321100917\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.45 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
- "\r",
- "AccMetric: acc=0.7293577981651376\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.33 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
- "\r",
- "AccMetric: acc=0.7190366972477065\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.29 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
- "\r",
- "AccMetric: acc=0.7419724770642202\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.34 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
- "\r",
- "AccMetric: acc=0.7350917431192661\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.18 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
- "\r",
- "AccMetric: acc=0.6846330275229358\n",
- "\n",
- "\r\n",
- "In Epoch:2/Step:308, got best dev performance:\n",
- "AccMetric: acc=0.7522935779816514\n",
- "Reloaded the best model.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'best_eval': {'AccMetric': {'acc': 0.7522935779816514}},\n",
- " 'best_epoch': 2,\n",
- " 'best_step': 308,\n",
- " 'seconds': 42.7}"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
- " loss=loss, device=device, metrics=AccMetric())\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Version 2\n",
- "\n",
- "如果需要复用 metric,比如下一次使用 `AccMetric` 时,dataset中目标field不叫 `target` 而叫 `y` ,或者model的输出不是 `pred`\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "class AccMetric(MetricBase):\n",
- " def __init__(self, pred=None, target=None):\n",
- " \"\"\"\n",
- " 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时,\n",
- " acc_metric = AccMetric(pred='pred_y', target='y')即可。\n",
- " 当初始化为acc_metric = AccMetric() 时,fastNLP会直接使用 'pred', 'target' 作为key去索取对应的的值\n",
- " \"\"\"\n",
- "\n",
- " super().__init__()\n",
- "\n",
- " # 如果没有注册该则效果与 Version 1 就是一样的\n",
- " self._init_param_map(pred=pred, target=target) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可\n",
- "\n",
- " # 根据你的情况自定义指标\n",
- " self.total = 0\n",
- " self.acc_count = 0\n",
- "\n",
- " # evaluate的参数需要和DataSet 中 field 名以及模型输出的结果 field 名一致,不然找不到对应的value\n",
- " # pred, target 的参数是 fastNLP 的默认配置\n",
- " def evaluate(self, pred, target):\n",
- " # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric\n",
- " self.total += target.size(0)\n",
- " self.acc_count += target.eq(pred).sum().item()\n",
- "\n",
- " def get_metric(self, reset=True): # 在这里定义如何计算metric\n",
- " acc = self.acc_count/self.total\n",
- " if reset: # 是否清零以便重新计算\n",
- " self.acc_count = 0\n",
- " self.total = 0\n",
- " return {'acc': acc}\n",
- " # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input fields after batch(if batch size is 2):\n",
- "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "target fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\n",
- "training epochs started 2020-02-28-00-38-24\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.32 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
- "\r",
- "AccMetric: acc=0.7511467889908257\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.29 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
- "\r",
- "AccMetric: acc=0.7454128440366973\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.42 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
- "\r",
- "AccMetric: acc=0.7224770642201835\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.4 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
- "\r",
- "AccMetric: acc=0.7534403669724771\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.41 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
- "\r",
- "AccMetric: acc=0.7396788990825688\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.22 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
- "\r",
- "AccMetric: acc=0.7442660550458715\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.45 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
- "\r",
- "AccMetric: acc=0.6903669724770642\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.25 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
- "\r",
- "AccMetric: acc=0.7293577981651376\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.4 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
- "\r",
- "AccMetric: acc=0.7006880733944955\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.48 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
- "\r",
- "AccMetric: acc=0.7339449541284404\n",
- "\n",
- "\r\n",
- "In Epoch:4/Step:616, got best dev performance:\n",
- "AccMetric: acc=0.7534403669724771\n",
- "Reloaded the best model.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'best_eval': {'AccMetric': {'acc': 0.7534403669724771}},\n",
- " 'best_epoch': 4,\n",
- " 'best_step': 616,\n",
- " 'seconds': 34.74}"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
- " loss=loss, device=device, metrics=AccMetric())\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "``MetricBase`` 将会在输入的字典 ``pred_dict`` 和 ``target_dict`` 中进行检查.\n",
- "``pred_dict`` 是模型当中 ``forward()`` 函数或者 ``predict()`` 函数的返回值.\n",
- "``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的 ``is_target`` 被设置为True.\n",
- "\n",
- "``MetricBase`` 会进行以下的类型检测:\n",
- "\n",
- "1. self.evaluate当中是否有 varargs, 这是不支持的.\n",
- "2. self.evaluate当中所需要的参数是否既不在 ``pred_dict`` 也不在 ``target_dict`` .\n",
- "3. self.evaluate当中所需要的参数是否既在 ``pred_dict`` 也在 ``target_dict`` .\n",
- "\n",
- "除此以外,在参数被传入self.evaluate以前,这个函数会检测 ``pred_dict`` 和 ``target_dict`` 当中没有被用到的参数\n",
- "如果kwargs是self.evaluate的参数,则不会检测\n",
- "\n",
- "self.evaluate将计算一个批次(batch)的评价指标,并累计。 没有返回值\n",
- "self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python Now",
- "language": "python",
- "name": "now"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_static/notebooks/tutorial_8_modules_models.ipynb b/docs/source/_static/notebooks/tutorial_8_modules_models.ipynb
deleted file mode 100644
index 2784cca1..00000000
--- a/docs/source/_static/notebooks/tutorial_8_modules_models.ipynb
+++ /dev/null
@@ -1,1014 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 使用Modules和Models快速搭建自定义模型\n",
- "\n",
- "modules 和 models 用于构建 fastNLP 所需的神经网络模型,它可以和 torch.nn 中的模型一起使用。 下面我们会分三节介绍编写构建模型的具体方法。\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "我们首先准备好和上篇教程一样的基础实验代码"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.io import SST2Pipe\n",
- "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
- "import torch\n",
- "\n",
- "databundle = SST2Pipe().process_from_file()\n",
- "vocab = databundle.get_vocab('words')\n",
- "train_data = databundle.get_dataset('train')[:5000]\n",
- "train_data, test_data = train_data.split(0.015)\n",
- "dev_data = databundle.get_dataset('dev')\n",
- "\n",
- "loss = CrossEntropyLoss()\n",
- "metric = AccuracyMetric()\n",
- "device = 0 if torch.cuda.is_available() else 'cpu'"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 使用 models 中的模型\n",
- "\n",
- "fastNLP 在 models 模块中内置了如 CNNText 、 SeqLabeling 等完整的模型,以供用户直接使用。 以文本分类的任务为例,我们从 models 中导入 CNNText 模型,用它进行训练。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input fields after batch(if batch size is 2):\n",
- "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 41]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "target fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\n",
- "training epochs started 2020-02-28-00-56-04\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.22 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.760321\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.29 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.727064\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.48 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.758028\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.24 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.759174\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.47 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.743119\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.22 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.756881\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.21 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.752294\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.21 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.756881\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.15 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.75344\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.12 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.752294\n",
- "\n",
- "\r\n",
- "In Epoch:1/Step:154, got best dev performance:\n",
- "AccuracyMetric: acc=0.760321\n",
- "Reloaded the best model.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'best_eval': {'AccuracyMetric': {'acc': 0.760321}},\n",
- " 'best_epoch': 1,\n",
- " 'best_step': 154,\n",
- " 'seconds': 29.3}"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP.models import CNNText\n",
- "\n",
- "model_cnn = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n",
- "\n",
- "trainer = Trainer(train_data=train_data, dev_data=dev_data, metrics=metric,\n",
- " loss=loss, device=device, model=model_cnn)\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "在 iPython 环境输入 model_cnn ,我们可以看到 model_cnn 的网络结构"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "CNNText(\n",
- " (embed): Embedding(\n",
- " (embed): Embedding(16292, 100)\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (conv_pool): ConvMaxpool(\n",
- " (convs): ModuleList(\n",
- " (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
- " (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
- " (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
- " )\n",
- " )\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (fc): Linear(in_features=120, out_features=2, bias=True)\n",
- ")"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "model_cnn"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 使用 nn.torch 编写模型\n",
- "\n",
- "FastNLP 完全支持使用 pyTorch 编写的模型,但与 pyTorch 中编写模型的常见方法不同, 用于 fastNLP 的模型中 forward 函数需要返回一个字典,字典中至少需要包含 pred 这个字段。\n",
- "\n",
- "下面是使用 pyTorch 中的 torch.nn 模块编写的文本分类,注意观察代码中标注的向量维度。 由于 pyTorch 使用了约定俗成的维度设置,使得 forward 中需要多次处理维度顺序"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "import torch.nn as nn\n",
- "\n",
- "class LSTMText(nn.Module):\n",
- " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
- " super().__init__()\n",
- "\n",
- " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
- " self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True, dropout=dropout)\n",
- " self.fc = nn.Linear(hidden_dim * 2, output_dim)\n",
- " self.dropout = nn.Dropout(dropout)\n",
- "\n",
- " def forward(self, words):\n",
- " # (input) words : (batch_size, seq_len)\n",
- " words = words.permute(1,0)\n",
- " # words : (seq_len, batch_size)\n",
- "\n",
- " embedded = self.dropout(self.embedding(words))\n",
- " # embedded : (seq_len, batch_size, embedding_dim)\n",
- " output, (hidden, cell) = self.lstm(embedded)\n",
- " # output: (seq_len, batch_size, hidden_dim * 2)\n",
- " # hidden: (num_layers * 2, batch_size, hidden_dim)\n",
- " # cell: (num_layers * 2, batch_size, hidden_dim)\n",
- "\n",
- " hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)\n",
- " hidden = self.dropout(hidden)\n",
- " # hidden: (batch_size, hidden_dim * 2)\n",
- "\n",
- " pred = self.fc(hidden.squeeze(0))\n",
- " # result: (batch_size, output_dim)\n",
- " return {\"pred\":pred}"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "我们同样可以在 iPython 环境中查看这个模型的网络结构"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "LSTMText(\n",
- " (embedding): Embedding(16292, 100)\n",
- " (lstm): LSTM(100, 64, num_layers=2, dropout=0.5, bidirectional=True)\n",
- " (fc): Linear(in_features=128, out_features=2, bias=True)\n",
- " (dropout): Dropout(p=0.5, inplace=False)\n",
- ")"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "model_lstm = LSTMText(len(vocab), 100, 2)\n",
- "model_lstm "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input fields after batch(if batch size is 2):\n",
- "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 41]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "target fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\n",
- "training epochs started 2020-02-28-00-56-34\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.36 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.59289\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.35 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.674312\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.21 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.724771\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.4 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.748853\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.24 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.756881\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.29 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.741972\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.32 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.754587\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.24 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.756881\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.28 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.740826\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.23 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.751147\n",
- "\n",
- "\r\n",
- "In Epoch:5/Step:770, got best dev performance:\n",
- "AccuracyMetric: acc=0.756881\n",
- "Reloaded the best model.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'best_eval': {'AccuracyMetric': {'acc': 0.756881}},\n",
- " 'best_epoch': 5,\n",
- " 'best_step': 770,\n",
- " 'seconds': 45.69}"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "trainer = Trainer(train_data=train_data, dev_data=dev_data, metrics=metric,\n",
- " loss=loss, device=device, model=model_lstm)\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 使用 modules 编写模型\n",
- "\n",
- "下面我们使用 fastNLP.modules 中的组件来构建同样的网络。由于 fastNLP 统一把 batch_size 放在第一维, 在编写代码的过程中会有一定的便利。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "MyText(\n",
- " (embedding): Embedding(\n",
- " (embed): Embedding(16292, 100)\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " )\n",
- " (lstm): LSTM(\n",
- " (lstm): LSTM(100, 64, num_layers=2, batch_first=True, bidirectional=True)\n",
- " )\n",
- " (mlp): MLP(\n",
- " (hiddens): ModuleList()\n",
- " (output): Linear(in_features=128, out_features=2, bias=True)\n",
- " (dropout): Dropout(p=0.5, inplace=False)\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP.modules import LSTM, MLP\n",
- "from fastNLP.embeddings import Embedding\n",
- "\n",
- "\n",
- "class MyText(nn.Module):\n",
- " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
- " super().__init__()\n",
- "\n",
- " self.embedding = Embedding((vocab_size, embedding_dim))\n",
- " self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n",
- " self.mlp = MLP([hidden_dim*2,output_dim], dropout=dropout)\n",
- "\n",
- " def forward(self, words):\n",
- " embedded = self.embedding(words)\n",
- " _,(hidden,_) = self.lstm(embedded)\n",
- " pred = self.mlp(torch.cat((hidden[-1],hidden[-2]),dim=1))\n",
- " return {\"pred\":pred}\n",
- " \n",
- "model_text = MyText(len(vocab), 100, 2)\n",
- "model_text"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input fields after batch(if batch size is 2):\n",
- "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 41]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "target fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\n",
- "training epochs started 2020-02-28-00-57-19\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "16a35f2b0ef0457dae15c5f240a19a3a",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.38 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.767202\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.22 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
- "\r",
- "AccuracyMetric: acc=0.743119\n",
- "\n"
- ]
- }
- ],
- "source": [
- "trainer = Trainer(train_data=train_data, dev_data=dev_data, metrics=metric,\n",
- " loss=loss, device=device, model=model_lstm)\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python Now",
- "language": "python",
- "name": "now"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_static/notebooks/tutorial_9_callback.ipynb b/docs/source/_static/notebooks/tutorial_9_callback.ipynb
deleted file mode 100644
index ed71a9b0..00000000
--- a/docs/source/_static/notebooks/tutorial_9_callback.ipynb
+++ /dev/null
@@ -1,622 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 使用 Callback 自定义你的训练过程"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "- 什么是 Callback\n",
- "- 使用 Callback \n",
- "- 一些常用的 Callback\n",
- "- 自定义实现 Callback"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "什么是Callback\n",
- "------\n",
- "\n",
- "Callback 是与 Trainer 紧密结合的模块,利用 Callback 可以在 Trainer 训练时,加入自定义的操作,比如梯度裁剪,学习率调节,测试模型的性能等。定义的 Callback 会在训练的特定阶段被调用。\n",
- "\n",
- "fastNLP 中提供了很多常用的 Callback ,开箱即用。"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "使用 Callback\n",
- " ------\n",
- "\n",
- "使用 Callback 很简单,将需要的 callback 按 list 存储,以对应参数 ``callbacks`` 传入对应的 Trainer。Trainer 在训练时就会自动执行这些 Callback 指定的操作了。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2019-09-17T07:34:46.465871Z",
- "start_time": "2019-09-17T07:34:30.648758Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "In total 3 datasets:\n",
- "\ttest has 1200 instances.\n",
- "\ttrain has 9600 instances.\n",
- "\tdev has 1200 instances.\n",
- "In total 2 vocabs:\n",
- "\tchars has 4409 entries.\n",
- "\ttarget has 2 entries.\n",
- "\n",
- "training epochs started 2019-09-17-03-34-34\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.1 seconds!\n",
- "Evaluation on dev at Epoch 1/3. Step:300/900: \n",
- "AccuracyMetric: acc=0.863333\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.11 seconds!\n",
- "Evaluation on dev at Epoch 2/3. Step:600/900: \n",
- "AccuracyMetric: acc=0.886667\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.1 seconds!\n",
- "Evaluation on dev at Epoch 3/3. Step:900/900: \n",
- "AccuracyMetric: acc=0.890833\n",
- "\n",
- "\r\n",
- "In Epoch:3/Step:900, got best dev performance:\n",
- "AccuracyMetric: acc=0.890833\n",
- "Reloaded the best model.\n"
- ]
- }
- ],
- "source": [
- "from fastNLP import (Callback, EarlyStopCallback,\n",
- " Trainer, CrossEntropyLoss, AccuracyMetric)\n",
- "from fastNLP.models import CNNText\n",
- "import torch.cuda\n",
- "\n",
- "# prepare data\n",
- "def get_data():\n",
- " from fastNLP.io import ChnSentiCorpPipe as pipe\n",
- " data = pipe().process_from_file()\n",
- " print(data)\n",
- " data.rename_field('chars', 'words')\n",
- " train_data = data.datasets['train']\n",
- " dev_data = data.datasets['dev']\n",
- " test_data = data.datasets['test']\n",
- " vocab = data.vocabs['words']\n",
- " tgt_vocab = data.vocabs['target']\n",
- " return train_data, dev_data, test_data, vocab, tgt_vocab\n",
- "\n",
- "# prepare model\n",
- "train_data, dev_data, _, vocab, tgt_vocab = get_data()\n",
- "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
- "model = CNNText((len(vocab),50), num_classes=len(tgt_vocab))\n",
- "\n",
- "# define callback\n",
- "callbacks=[EarlyStopCallback(5)]\n",
- "\n",
- "# pass callbacks to Trainer\n",
- "def train_with_callback(cb_list):\n",
- " trainer = Trainer(\n",
- " device=device,\n",
- " n_epochs=3,\n",
- " model=model, \n",
- " train_data=train_data, \n",
- " dev_data=dev_data, \n",
- " loss=CrossEntropyLoss(), \n",
- " metrics=AccuracyMetric(), \n",
- " callbacks=cb_list, \n",
- " check_code_level=-1\n",
- " )\n",
- " trainer.train()\n",
- "\n",
- "train_with_callback(callbacks)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "fastNLP 中的 Callback\n",
- "-------\n",
- "fastNLP 中提供了很多常用的 Callback,如梯度裁剪,训练时早停和测试验证集,fitlog 等等。具体 Callback 请参考 fastNLP.core.callbacks"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2019-09-17T07:35:02.182727Z",
- "start_time": "2019-09-17T07:34:49.443863Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "training epochs started 2019-09-17-03-34-49\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.13 seconds!\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.12 seconds!\n",
- "Evaluation on data-test:\n",
- "AccuracyMetric: acc=0.890833\n",
- "Evaluation on dev at Epoch 1/3. Step:300/900: \n",
- "AccuracyMetric: acc=0.890833\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.09 seconds!\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.09 seconds!\n",
- "Evaluation on data-test:\n",
- "AccuracyMetric: acc=0.8875\n",
- "Evaluation on dev at Epoch 2/3. Step:600/900: \n",
- "AccuracyMetric: acc=0.8875\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.11 seconds!\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.1 seconds!\n",
- "Evaluation on data-test:\n",
- "AccuracyMetric: acc=0.885\n",
- "Evaluation on dev at Epoch 3/3. Step:900/900: \n",
- "AccuracyMetric: acc=0.885\n",
- "\n",
- "\r\n",
- "In Epoch:1/Step:300, got best dev performance:\n",
- "AccuracyMetric: acc=0.890833\n",
- "Reloaded the best model.\n"
- ]
- }
- ],
- "source": [
- "from fastNLP import EarlyStopCallback, GradientClipCallback, EvaluateCallback\n",
- "callbacks = [\n",
- " EarlyStopCallback(5),\n",
- " GradientClipCallback(clip_value=5, clip_type='value'),\n",
- " EvaluateCallback(dev_data)\n",
- "]\n",
- "\n",
- "train_with_callback(callbacks)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "自定义 Callback\n",
- "------\n",
- "\n",
- "这里我们以一个简单的 Callback作为例子,它的作用是打印每一个 Epoch 平均训练 loss。\n",
- "\n",
- "#### 创建 Callback\n",
- " \n",
- "要自定义 Callback,我们要实现一个类,继承 fastNLP.Callback。\n",
- "\n",
- "这里我们定义 MyCallBack ,继承 fastNLP.Callback 。\n",
- "\n",
- "#### 指定 Callback 调用的阶段\n",
- " \n",
- "Callback 中所有以 on_ 开头的类方法会在 Trainer 的训练中在特定阶段调用。 如 on_train_begin() 会在训练开始时被调用,on_epoch_end() 会在每个 epoch 结束时调用。 具体有哪些类方法,参见 Callback 文档。\n",
- "\n",
- "这里, MyCallBack 在求得loss时调用 on_backward_begin() 记录当前 loss ,在每一个 epoch 结束时调用 on_epoch_end() ,求当前 epoch 平均loss并输出。\n",
- "\n",
- "#### 使用 Callback 的属性访问 Trainer 的内部信息\n",
- " \n",
- "为了方便使用,可以使用 Callback 的属性,访问 Trainer 中的对应信息,如 optimizer, epoch, n_epochs,分别对应训练时的优化器,当前 epoch 数,和总 epoch 数。 具体可访问的属性,参见文档 Callback 。\n",
- "\n",
- "这里, MyCallBack 为了求平均 loss ,需要知道当前 epoch 的总步数,可以通过 self.step 属性得到当前训练了多少步。\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2019-09-17T07:43:10.907139Z",
- "start_time": "2019-09-17T07:42:58.488177Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "training epochs started 2019-09-17-03-42-58\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.11 seconds!\n",
- "Evaluation on dev at Epoch 1/3. Step:300/900: \n",
- "AccuracyMetric: acc=0.883333\n",
- "\n",
- "Avg loss at epoch 1, 0.100254\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.1 seconds!\n",
- "Evaluation on dev at Epoch 2/3. Step:600/900: \n",
- "AccuracyMetric: acc=0.8775\n",
- "\n",
- "Avg loss at epoch 2, 0.183511\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 0.13 seconds!\n",
- "Evaluation on dev at Epoch 3/3. Step:900/900: \n",
- "AccuracyMetric: acc=0.875833\n",
- "\n",
- "Avg loss at epoch 3, 0.257103\n",
- "\r\n",
- "In Epoch:1/Step:300, got best dev performance:\n",
- "AccuracyMetric: acc=0.883333\n",
- "Reloaded the best model.\n"
- ]
- }
- ],
- "source": [
- "from fastNLP import Callback\n",
- "from fastNLP import logger\n",
- "\n",
- "class MyCallBack(Callback):\n",
- " \"\"\"Print average loss in each epoch\"\"\"\n",
- " def __init__(self):\n",
- " super().__init__()\n",
- " self.total_loss = 0\n",
- " self.start_step = 0\n",
- " \n",
- " def on_backward_begin(self, loss):\n",
- " self.total_loss += loss.item()\n",
- " \n",
- " def on_epoch_end(self):\n",
- " n_steps = self.step - self.start_step\n",
- " avg_loss = self.total_loss / n_steps\n",
- " logger.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)\n",
- " self.start_step = self.step\n",
- "\n",
- "callbacks = [MyCallBack()]\n",
- "train_with_callback(callbacks)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.3"
- },
- "varInspector": {
- "cols": {
- "lenName": 16,
- "lenType": 16,
- "lenVar": 40
- },
- "kernels_config": {
- "python": {
- "delete_cmd_postfix": "",
- "delete_cmd_prefix": "del ",
- "library": "var_list.py",
- "varRefreshCmd": "print(var_dic_list())"
- },
- "r": {
- "delete_cmd_postfix": ") ",
- "delete_cmd_prefix": "rm(",
- "library": "var_list.r",
- "varRefreshCmd": "cat(var_dic_list()) "
- }
- },
- "types_to_exclude": [
- "module",
- "function",
- "builtin_function_or_method",
- "instance",
- "_Feature"
- ],
- "window_display": false
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/docs/source/_static/notebooks/序列标注.ipynb b/docs/source/_static/notebooks/序列标注.ipynb
deleted file mode 100644
index 15118708..00000000
--- a/docs/source/_static/notebooks/序列标注.ipynb
+++ /dev/null
@@ -1,912 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 序列标注\n",
- "\n",
- "这一部分的内容主要展示如何使用fastNLP实现序列标注(Sequence labeling)任务。您可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。 在阅读这篇教程前,希望您已经熟悉了fastNLP的基础使用,尤其是数据的载入以及模型的构建,通过这个小任务的能让您进一步熟悉fastNLP的使用。\n",
- "\n",
- "## 命名实体识别(name entity recognition, NER)\n",
- "\n",
- "命名实体识别任务是从文本中抽取出具有特殊意义或者指代性非常强的实体,通常包括人名、地名、机构名和时间等。 如下面的例子中\n",
- "\n",
- "*我来自复旦大学*\n",
- "\n",
- "其中“复旦大学”就是一个机构名,命名实体识别就是要从中识别出“复旦大学”这四个字是一个整体,且属于机构名这个类别。这个问题在实际做的时候会被 转换为序列标注问题\n",
- "\n",
- "针对\"我来自复旦大学\"这句话,我们的预测目标将是[O, O, O, B-ORG, I-ORG, I-ORG, I-ORG],其中O表示out,即不是一个实体,B-ORG是ORG( organization的缩写)这个类别的开头(Begin),I-ORG是ORG类别的中间(Inside)。\n",
- "\n",
- "在本tutorial中我们将通过fastNLP尝试写出一个能够执行以上任务的模型。\n",
- "\n",
- "## 载入数据\n",
- "\n",
- "fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的,您可以通过《使用Loader和Pipe处理数据》了解如何使用fastNLP提供的数据加载函数。下面我们以微博命名实体任务来演示一下在fastNLP进行序列标注任务。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n",
- "| raw_chars | target | chars | seq_len |\n",
- "+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n",
- "| ['科', '技', '全', '方', '位',... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... | [792, 1015, 156, 198, 291, 714... | 26 |\n",
- "| ['对', ',', '输', '给', '一',... | [0, 0, 0, 0, 0, 0, 3, 1, 0, 0,... | [123, 2, 1205, 115, 8, 24, 101... | 15 |\n",
- "+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.io import WeiboNERPipe\n",
- "data_bundle = WeiboNERPipe().process_from_file()\n",
- "print(data_bundle.get_dataset('train')[:2])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 模型构建\n",
- "\n",
- "首先选择需要使用的Embedding类型。关于Embedding的相关说明可以参见《使用Embedding模块将文本转成向量》。 在这里我们使用通过word2vec预训练的中文汉字embedding。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Found 3321 out of 3471 words in the pre-training embedding.\n"
- ]
- }
- ],
- "source": [
- "from fastNLP.embeddings import StaticEmbedding\n",
- "\n",
- "embed = StaticEmbedding(vocab=data_bundle.get_vocab('chars'), model_dir_or_name='cn-char-fastnlp-100d')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "选择好Embedding之后,我们可以使用fastNLP中自带的 fastNLP.models.BiLSTMCRF 作为模型。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.models import BiLSTMCRF\n",
- "\n",
- "data_bundle.rename_field('chars', 'words') # 这是由于BiLSTMCRF模型的forward函数接受的words,而不是chars,所以需要把这一列重新命名\n",
- "model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5,\n",
- " target_vocab=data_bundle.get_vocab('target'))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 进行训练\n",
- "下面我们选择用来评估模型的metric,以及优化用到的优化函数。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP import SpanFPreRecMetric\n",
- "from torch.optim import Adam\n",
- "from fastNLP import LossInForward\n",
- "\n",
- "metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))\n",
- "optimizer = Adam(model.parameters(), lr=1e-2)\n",
- "loss = LossInForward()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "使用Trainer进行训练, 您可以通过修改 device 的值来选择显卡。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
- "target fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\n",
- "training epochs started 2020-02-27-13-53-24\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=430.0), HTML(value='')), layout=Layout(di…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.89 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 1/10. Step:43/430: \n",
- "\r",
- "SpanFPreRecMetric: f=0.067797, pre=0.192771, rec=0.041131\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.9 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 2/10. Step:86/430: \n",
- "\r",
- "SpanFPreRecMetric: f=0.344086, pre=0.568047, rec=0.246787\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.88 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 3/10. Step:129/430: \n",
- "\r",
- "SpanFPreRecMetric: f=0.446701, pre=0.653465, rec=0.339332\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.81 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 4/10. Step:172/430: \n",
- "\r",
- "SpanFPreRecMetric: f=0.479871, pre=0.642241, rec=0.383033\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.91 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 5/10. Step:215/430: \n",
- "\r",
- "SpanFPreRecMetric: f=0.486312, pre=0.650862, rec=0.388175\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.87 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 6/10. Step:258/430: \n",
- "\r",
- "SpanFPreRecMetric: f=0.541401, pre=0.711297, rec=0.437018\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.86 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 7/10. Step:301/430: \n",
- "\r",
- "SpanFPreRecMetric: f=0.430335, pre=0.685393, rec=0.313625\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.82 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 8/10. Step:344/430: \n",
- "\r",
- "SpanFPreRecMetric: f=0.477759, pre=0.665138, rec=0.372751\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.81 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 9/10. Step:387/430: \n",
- "\r",
- "SpanFPreRecMetric: f=0.500759, pre=0.611111, rec=0.424165\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 0.8 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 10/10. Step:430/430: \n",
- "\r",
- "SpanFPreRecMetric: f=0.496025, pre=0.65, rec=0.401028\n",
- "\n",
- "\r\n",
- "In Epoch:6/Step:258, got best dev performance:\n",
- "SpanFPreRecMetric: f=0.541401, pre=0.711297, rec=0.437018\n",
- "Reloaded the best model.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'best_eval': {'SpanFPreRecMetric': {'f': 0.541401,\n",
- " 'pre': 0.711297,\n",
- " 'rec': 0.437018}},\n",
- " 'best_epoch': 6,\n",
- " 'best_step': 258,\n",
- " 'seconds': 121.39}"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import Trainer\n",
- "import torch\n",
- "\n",
- "device= 0 if torch.cuda.is_available() else 'cpu'\n",
- "trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,\n",
- " dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device)\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 进行测试\n",
- "训练结束之后过,可以通过 Tester 测试其在测试集上的性能"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=17.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 1.54 seconds!\n",
- "[tester] \n",
- "SpanFPreRecMetric: f=0.439024, pre=0.685279, rec=0.322967\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'SpanFPreRecMetric': {'f': 0.439024, 'pre': 0.685279, 'rec': 0.322967}}"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import Tester\n",
- "tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)\n",
- "tester.test()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 使用更强的Bert做序列标注\n",
- "\n",
- "在fastNLP使用Bert进行任务,您只需要把fastNLP.embeddings.StaticEmbedding 切换为 fastNLP.embeddings.BertEmbedding(可修改 device 选择显卡)。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n",
- "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n",
- "Start to generate word pieces for word.\n",
- "Found(Or segment into word pieces) 3384 words out of 3471.\n",
- "input fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
- "target fields after batch(if batch size is 2):\n",
- "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
- "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
- "\n",
- "training epochs started 2020-02-27-13-58-51\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1130.0), HTML(value='')), layout=Layout(d…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 2.7 seconds!\n",
- "Evaluation on dev at Epoch 1/10. Step:113/1130: \n",
- "SpanFPreRecMetric: f=0.008114, pre=0.019231, rec=0.005141\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 2.49 seconds!\n",
- "Evaluation on dev at Epoch 2/10. Step:226/1130: \n",
- "SpanFPreRecMetric: f=0.467866, pre=0.467866, rec=0.467866\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 2.6 seconds!\n",
- "Evaluation on dev at Epoch 3/10. Step:339/1130: \n",
- "SpanFPreRecMetric: f=0.566879, pre=0.482821, rec=0.686375\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 2.56 seconds!\n",
- "Evaluation on dev at Epoch 4/10. Step:452/1130: \n",
- "SpanFPreRecMetric: f=0.651972, pre=0.59408, rec=0.722365\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 2.69 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 5/10. Step:565/1130: \n",
- "\r",
- "SpanFPreRecMetric: f=0.640909, pre=0.574338, rec=0.724936\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 2.52 seconds!\n",
- "Evaluation on dev at Epoch 6/10. Step:678/1130: \n",
- "SpanFPreRecMetric: f=0.661836, pre=0.624146, rec=0.70437\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 2.67 seconds!\n",
- "Evaluation on dev at Epoch 7/10. Step:791/1130: \n",
- "SpanFPreRecMetric: f=0.683429, pre=0.615226, rec=0.768638\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 2.37 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 8/10. Step:904/1130: \n",
- "\r",
- "SpanFPreRecMetric: f=0.674699, pre=0.634921, rec=0.719794\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluate data in 2.42 seconds!\n",
- "Evaluation on dev at Epoch 9/10. Step:1017/1130: \n",
- "SpanFPreRecMetric: f=0.693878, pre=0.650901, rec=0.742931\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 2.46 seconds!\n",
- "\r",
- "Evaluation on dev at Epoch 10/10. Step:1130/1130: \n",
- "\r",
- "SpanFPreRecMetric: f=0.686845, pre=0.62766, rec=0.758355\n",
- "\n",
- "\r\n",
- "In Epoch:9/Step:1017, got best dev performance:\n",
- "SpanFPreRecMetric: f=0.693878, pre=0.650901, rec=0.742931\n",
- "Reloaded the best model.\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=17.0), HTML(value='')), layout=Layout(dis…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\r",
- "Evaluate data in 1.96 seconds!\n",
- "[tester] \n",
- "SpanFPreRecMetric: f=0.626561, pre=0.596112, rec=0.660287\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'SpanFPreRecMetric': {'f': 0.626561, 'pre': 0.596112, 'rec': 0.660287}}"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "\n",
- "from fastNLP.io import WeiboNERPipe\n",
- "data_bundle = WeiboNERPipe().process_from_file()\n",
- "data_bundle.rename_field('chars', 'words')\n",
- "\n",
- "from fastNLP.embeddings import BertEmbedding\n",
- "embed = BertEmbedding(vocab=data_bundle.get_vocab('words'), model_dir_or_name='cn')\n",
- "model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5,\n",
- " target_vocab=data_bundle.get_vocab('target'))\n",
- "\n",
- "from fastNLP import SpanFPreRecMetric\n",
- "from torch.optim import Adam\n",
- "from fastNLP import LossInForward\n",
- "metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))\n",
- "optimizer = Adam(model.parameters(), lr=2e-5)\n",
- "loss = LossInForward()\n",
- "\n",
- "from fastNLP import Trainer\n",
- "import torch\n",
- "device= 5 if torch.cuda.is_available() else 'cpu'\n",
- "trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer, batch_size=12,\n",
- " dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device)\n",
- "trainer.train()\n",
- "\n",
- "from fastNLP import Tester\n",
- "tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)\n",
- "tester.test()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python Now",
- "language": "python",
- "name": "now"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_static/notebooks/文本分类.ipynb b/docs/source/_static/notebooks/文本分类.ipynb
deleted file mode 100644
index 66439a76..00000000
--- a/docs/source/_static/notebooks/文本分类.ipynb
+++ /dev/null
@@ -1,564 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 文本分类(Text classification)\n",
- "文本分类任务是将一句话或一段话划分到某个具体的类别。比如垃圾邮件识别,文本情绪分类等。\n",
- "\n",
- "Example:: \n",
- "1,商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!\n",
- "\n",
- "\n",
- "其中开头的1是只这条评论的标签,表示是正面的情绪。我们将使用到的数据可以通过http://dbcloud.irocn.cn:8989/api/public/dl/dataset/chn_senti_corp.zip 下载并解压,当然也可以通过fastNLP自动下载该数据。\n",
- "\n",
- "数据中的内容如下图所示。接下来,我们将用fastNLP在这个数据上训练一个分类网络。"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "![jupyter](./cn_cls_example.png)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 步骤\n",
- "一共有以下的几个步骤 \n",
- "(1) 读取数据 \n",
- "(2) 预处理数据 \n",
- "(3) 选择预训练词向量 \n",
- "(4) 创建模型 \n",
- "(5) 训练模型 "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### (1) 读取数据\n",
- "fastNLP提供多种数据的自动下载与自动加载功能,对于这里我们要用到的数据,我们可以用\\ref{Loader}自动下载并加载该数据。更多有关Loader的使用可以参考\\ref{Loader}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.io import ChnSentiCorpLoader\n",
- "\n",
- "loader = ChnSentiCorpLoader() # 初始化一个中文情感分类的loader\n",
- "data_dir = loader.download() # 这一行代码将自动下载数据到默认的缓存地址, 并将该地址返回\n",
- "data_bundle = loader.load(data_dir) # 这一行代码将从{data_dir}处读取数据至DataBundle"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "DataBundle的相关介绍,可以参考\\ref{}。我们可以打印该data_bundle的基本信息。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(data_bundle)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "可以看出,该data_bundle中一个含有三个\\ref{DataSet}。通过下面的代码,我们可以查看DataSet的基本情况"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(data_bundle.get_dataset('train')[:2]) # 查看Train集前两个sample"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### (2) 预处理数据\n",
- "在NLP任务中,预处理一般包括: (a)将一整句话切分成汉字或者词; (b)将文本转换为index \n",
- "\n",
- "fastNLP中也提供了多种数据集的处理类,这里我们直接使用fastNLP的ChnSentiCorpPipe。更多关于Pipe的说明可以参考\\ref{Pipe}。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.io import ChnSentiCorpPipe\n",
- "\n",
- "pipe = ChnSentiCorpPipe()\n",
- "data_bundle = pipe.process(data_bundle) # 所有的Pipe都实现了process()方法,且输入输出都为DataBundle类型"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(data_bundle) # 打印data_bundle,查看其变化"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "可以看到除了之前已经包含的3个\\ref{DataSet}, 还新增了两个\\ref{Vocabulary}。我们可以打印DataSet中的内容"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(data_bundle.get_dataset('train')[:2])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "新增了一列为数字列表的chars,以及变为数字的target列。可以看出这两列的名称和刚好与data_bundle中两个Vocabulary的名称是一致的,我们可以打印一下Vocabulary看一下里面的内容。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "char_vocab = data_bundle.get_vocab('chars')\n",
- "print(char_vocab)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Vocabulary是一个记录着词语与index之间映射关系的类,比如"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "index = char_vocab.to_index('选')\n",
- "print(\"'选'的index是{}\".format(index)) # 这个值与上面打印出来的第一个instance的chars的第一个index是一致的\n",
- "print(\"index:{}对应的汉字是{}\".format(index, char_vocab.to_word(index))) "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### (3) 选择预训练词向量 \n",
- "由于Word2vec, Glove, Elmo, Bert等预训练模型可以增强模型的性能,所以在训练具体任务前,选择合适的预训练词向量非常重要。在fastNLP中我们提供了多种Embedding使得加载这些预训练模型的过程变得更加便捷。更多关于Embedding的说明可以参考\\ref{Embedding}。这里我们先给出一个使用word2vec的中文汉字预训练的示例,之后再给出一个使用Bert的文本分类。这里使用的预训练词向量为'cn-fastnlp-100d',fastNLP将自动下载该embedding至本地缓存,fastNLP支持使用名字指定的Embedding以及相关说明可以参见\\ref{Embedding}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.embeddings import StaticEmbedding\n",
- "\n",
- "word2vec_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### (4) 创建模型\n",
- "这里我们使用到的模型结构如下所示,补图"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from torch import nn\n",
- "from fastNLP.modules import LSTM\n",
- "import torch\n",
- "\n",
- "# 定义模型\n",
- "class BiLSTMMaxPoolCls(nn.Module):\n",
- " def __init__(self, embed, num_classes, hidden_size=400, num_layers=1, dropout=0.3):\n",
- " super().__init__()\n",
- " self.embed = embed\n",
- " \n",
- " self.lstm = LSTM(self.embed.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers, \n",
- " batch_first=True, bidirectional=True)\n",
- " self.dropout_layer = nn.Dropout(dropout)\n",
- " self.fc = nn.Linear(hidden_size, num_classes)\n",
- " \n",
- " def forward(self, chars, seq_len): # 这里的名称必须和DataSet中相应的field对应,比如之前我们DataSet中有chars,这里就必须为chars\n",
- " # chars:[batch_size, max_len]\n",
- " # seq_len: [batch_size, ]\n",
- " chars = self.embed(chars)\n",
- " outputs, _ = self.lstm(chars, seq_len)\n",
- " outputs = self.dropout_layer(outputs)\n",
- " outputs, _ = torch.max(outputs, dim=1)\n",
- " outputs = self.fc(outputs)\n",
- " \n",
- " return {'pred':outputs} # [batch_size,], 返回值必须是dict类型,且预测值的key建议设为pred\n",
- "\n",
- "# 初始化模型\n",
- "model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### (5) 训练模型\n",
- "fastNLP提供了Trainer对象来组织训练过程,包括完成loss计算(所以在初始化Trainer的时候需要指定loss类型),梯度更新(所以在初始化Trainer的时候需要提供优化器optimizer)以及在验证集上的性能验证(所以在初始化时需要提供一个Metric)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP import Trainer\n",
- "from fastNLP import CrossEntropyLoss\n",
- "from torch.optim import Adam\n",
- "from fastNLP import AccuracyMetric\n",
- "\n",
- "loss = CrossEntropyLoss()\n",
- "optimizer = Adam(model.parameters(), lr=0.001)\n",
- "metric = AccuracyMetric()\n",
- "device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快\n",
- "\n",
- "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
- " optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
- " metrics=metric, device=device)\n",
- "trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型\n",
- "\n",
- "# 在测试集上测试一下模型的性能\n",
- "from fastNLP import Tester\n",
- "print(\"Performance on test is:\")\n",
- "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
- "tester.test()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 使用Bert进行文本分类"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 只需要切换一下Embedding即可\n",
- "from fastNLP.embeddings import BertEmbedding\n",
- "\n",
- "# 这里为了演示一下效果,所以默认Bert不更新权重\n",
- "bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False)\n",
- "model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')), )\n",
- "\n",
- "\n",
- "import torch\n",
- "from fastNLP import Trainer\n",
- "from fastNLP import CrossEntropyLoss\n",
- "from torch.optim import Adam\n",
- "from fastNLP import AccuracyMetric\n",
- "\n",
- "loss = CrossEntropyLoss()\n",
- "optimizer = Adam(model.parameters(), lr=2e-5)\n",
- "metric = AccuracyMetric()\n",
- "device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快\n",
- "\n",
- "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
- " optimizer=optimizer, batch_size=16, dev_data=data_bundle.get_dataset('test'),\n",
- " metrics=metric, device=device, n_epochs=3)\n",
- "trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型\n",
- "\n",
- "# 在测试集上测试一下模型的性能\n",
- "from fastNLP import Tester\n",
- "print(\"Performance on test is:\")\n",
- "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
- "tester.test()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 基于词进行文本分类"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "由于汉字中没有显示的字与字的边界,一般需要通过分词器先将句子进行分词操作。\n",
- "下面的例子演示了如何不基于fastNLP已有的数据读取、预处理代码进行文本分类。"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### (1) 读取数据"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "这里我们继续以之前的数据为例,但这次我们不使用fastNLP自带的数据读取代码 "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.io import ChnSentiCorpLoader\n",
- "\n",
- "loader = ChnSentiCorpLoader() # 初始化一个中文情感分类的loader\n",
- "data_dir = loader.download() # 这一行代码将自动下载数据到默认的缓存地址, 并将该地址返回"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "下面我们先定义一个read_file_to_dataset的函数, 即给定一个文件路径,读取其中的内容,并返回一个DataSet。然后我们将所有的DataSet放入到DataBundle对象中来方便接下来的预处理"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "from fastNLP import DataSet, Instance\n",
- "from fastNLP.io import DataBundle\n",
- "\n",
- "\n",
- "def read_file_to_dataset(fp):\n",
- " ds = DataSet()\n",
- " with open(fp, 'r') as f:\n",
- " f.readline() # 第一行是title名称,忽略掉\n",
- " for line in f:\n",
- " line = line.strip()\n",
- " target, chars = line.split('\\t')\n",
- " ins = Instance(target=target, raw_chars=chars)\n",
- " ds.append(ins)\n",
- " return ds\n",
- "\n",
- "data_bundle = DataBundle()\n",
- "for name in ['train.tsv', 'dev.tsv', 'test.tsv']:\n",
- " fp = os.path.join(data_dir, name)\n",
- " ds = read_file_to_dataset(fp)\n",
- " data_bundle.set_dataset(name=name.split('.')[0], dataset=ds)\n",
- "\n",
- "print(data_bundle) # 查看以下数据集的情况\n",
- "# In total 3 datasets:\n",
- "# train has 9600 instances.\n",
- "# dev has 1200 instances.\n",
- "# test has 1200 instances."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### (2) 数据预处理"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "在这里,我们首先把句子通过 [fastHan](http://gitee.com/fastnlp/fastHan) 进行分词操作,然后创建词表,并将词语转换为序号。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastHan import FastHan\n",
- "from fastNLP import Vocabulary\n",
- "\n",
- "model=FastHan()\n",
- "# model.set_device('cuda')\n",
- "\n",
- "# 定义分词处理操作\n",
- "def word_seg(ins):\n",
- " raw_chars = ins['raw_chars']\n",
- " # 由于有些句子比较长,我们只截取前128个汉字\n",
- " raw_words = model(raw_chars[:128], target='CWS')[0]\n",
- " return raw_words\n",
- "\n",
- "for name, ds in data_bundle.iter_datasets():\n",
- " # apply函数将对内部的instance依次执行word_seg操作,并把其返回值放入到raw_words这个field\n",
- " ds.apply(word_seg, new_field_name='raw_words')\n",
- " # 除了apply函数,fastNLP还支持apply_field, apply_more(可同时创建多个field)等操作\n",
- " # 同时我们增加一个seq_len的field\n",
- " ds.add_seq_len('raw_words')\n",
- "\n",
- "vocab = Vocabulary()\n",
- "\n",
- "# 对raw_words列创建词表, 建议把非训练集的dataset放在no_create_entry_dataset参数中\n",
- "# 也可以通过add_word(), add_word_lst()等建立词表,请参考http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html\n",
- "vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_words', \n",
- " no_create_entry_dataset=[data_bundle.get_dataset('dev'), \n",
- " data_bundle.get_dataset('test')]) \n",
- "\n",
- "# 将建立好词表的Vocabulary用于对raw_words列建立词表,并把转为序号的列存入到words列\n",
- "vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'), \n",
- " data_bundle.get_dataset('test'), field_name='raw_words', new_field_name='words')\n",
- "\n",
- "# 建立target的词表,target的词表一般不需要padding和unknown\n",
- "target_vocab = Vocabulary(padding=None, unknown=None) \n",
- "# 一般情况下我们可以只用训练集建立target的词表\n",
- "target_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='target') \n",
- "# 如果没有传递new_field_name, 则默认覆盖原词表\n",
- "target_vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'), \n",
- " data_bundle.get_dataset('test'), field_name='target')\n",
- "\n",
- "# 我们可以把词表保存到data_bundle中,方便之后使用\n",
- "data_bundle.set_vocab(field_name='words', vocab=vocab)\n",
- "data_bundle.set_vocab(field_name='target', vocab=target_vocab)\n",
- "\n",
- "# 我们把words和target分别设置为input和target,这样它们才会在训练循环中被取出并自动padding, 有关这部分更多的内容参考\n",
- "# http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html\n",
- "data_bundle.set_target('target')\n",
- "data_bundle.set_input('words', 'seq_len') # DataSet也有这两个接口\n",
- "# 如果某些field,您希望它被设置为target或者input,但是不希望fastNLP自动padding或需要使用特定的padding方式,请参考\n",
- "# http://www.fastnlp.top/docs/fastNLP/fastNLP.core.dataset.html\n",
- "\n",
- "print(data_bundle.get_dataset('train')[:2]) # 我们可以看一下当前dataset的内容\n",
- "\n",
- "# 由于之后需要使用之前定义的BiLSTMMaxPoolCls模型,所以需要将words这个field修改为chars(因为该模型的forward接受chars参数)\n",
- "data_bundle.rename_field('words', 'chars')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### (3) 选择预训练词向量"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "这里我们选择腾讯的预训练中文词向量,可以在 [腾讯词向量](https://ai.tencent.com/ailab/nlp/en/embedding.html) 处下载并解压。这里我们不能直接使用BERT,因为BERT是基于中文字进行预训练的。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP.embeddings import StaticEmbedding\n",
- "\n",
- "word2vec_embed = StaticEmbedding(data_bundle.get_vocab('words'), \n",
- " model_dir_or_name='/path/to/Tencent_AILab_ChineseEmbedding.txt')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP import Trainer\n",
- "from fastNLP import CrossEntropyLoss\n",
- "from torch.optim import Adam\n",
- "from fastNLP import AccuracyMetric\n",
- "\n",
- "# 初始化模型\n",
- "model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))\n",
- "\n",
- "# 开始训练\n",
- "loss = CrossEntropyLoss()\n",
- "optimizer = Adam(model.parameters(), lr=0.001)\n",
- "metric = AccuracyMetric()\n",
- "device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快\n",
- "\n",
- "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
- " optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
- " metrics=metric, device=device)\n",
- "trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型\n",
- "\n",
- "# 在测试集上测试一下模型的性能\n",
- "from fastNLP import Tester\n",
- "print(\"Performance on test is:\")\n",
- "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
- "tester.test()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.8"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/source/_templates/versions.html b/docs/source/_templates/versions.html
new file mode 100644
index 00000000..44a44bfc
--- /dev/null
+++ b/docs/source/_templates/versions.html
@@ -0,0 +1,27 @@
+{%- if current_version %}
+
+
+ Other Versions
+ {{ current_version.name }}
+
+
+
+ {%- if versions.tags %}
+
+ - Tags
+ {%- for item in versions.tags %}
+ - {{ item.name }}
+ {%- endfor %}
+
+ {%- endif %}
+ {%- if versions.branches %}
+
+ - Branches
+ {%- for item in versions.branches %}
+ - {{ item.name }}
+ {%- endfor %}
+
+ {%- endif %}
+
+
+{%- endif %}
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index d1db2330..812fb0ec 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -20,13 +20,13 @@ sys.path.insert(0, os.path.abspath('../../'))
# -- Project information -----------------------------------------------------
project = 'fastNLP'
-copyright = '2020, xpqiu'
-author = 'xpqiu'
+copyright = '2022, fastNLP'
+author = 'fastNLP'
# The short X.Y version
-version = '0.6.0'
+version = '1.0'
# The full version, including alpha/beta/rc tags
-release = '0.6.0'
+release = '1.0.0-alpha'
# -- General configuration ---------------------------------------------------
@@ -42,7 +42,10 @@ extensions = [
'sphinx.ext.viewcode',
'sphinx.ext.autosummary',
'sphinx.ext.mathjax',
- 'sphinx.ext.todo'
+ 'sphinx.ext.todo',
+ 'sphinx_autodoc_typehints',
+ 'sphinx_multiversion',
+ 'nbsphinx',
]
autodoc_default_options = {
@@ -51,7 +54,12 @@ autodoc_default_options = {
'undoc-members': False,
}
+add_module_names = False
+autosummary_ignore_module_all = False
+# autodoc_typehints = "description"
autoclass_content = "class"
+typehints_fully_qualified = False
+typehints_defaults = "comma"
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
@@ -110,12 +118,16 @@ html_static_path = ['_static']
# 'searchbox.html']``.
#
# html_sidebars = {}
-
+html_sidebars = {
+ '**': [
+ 'versions.html',
+ ],
+}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
-htmlhelp_basename = 'fastNLP doc'
+htmlhelp_basename = 'fastNLP'
# -- Options for LaTeX output ------------------------------------------------
@@ -140,17 +152,14 @@ latex_elements = {
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
-latex_documents = [
- (master_doc, 'fastNLP.tex', 'fastNLP Documentation',
- 'xpqiu', 'manual'),
-]
+latex_documents = []
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
- (master_doc, 'fastnlp', 'fastNLP Documentation',
+ (master_doc, 'fastNLP', 'fastNLP Documentation',
[author], 1)
]
@@ -161,10 +170,12 @@ man_pages = [
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'fastNLP', 'fastNLP Documentation',
- author, 'fastNLP', 'One line description of project.',
+ author, 'fastNLP', 'A fast NLP tool for programming.',
'Miscellaneous'),
]
+# -- Options for Multiversions ----------------------------------------------
+smv_latest_version = 'dev0.8.0'
# -- Extension configuration -------------------------------------------------
def maybe_skip_member(app, what, name, obj, skip, options):
@@ -174,7 +185,7 @@ def maybe_skip_member(app, what, name, obj, skip, options):
return False
if name.startswith("_"):
return True
- return False
+ return skip
def setup(app):
diff --git a/docs/source/fastNLP.core.batch.rst b/docs/source/fastNLP.core.batch.rst
deleted file mode 100644
index 50ad6fed..00000000
--- a/docs/source/fastNLP.core.batch.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.core.batch
-==================
-
-.. automodule:: fastNLP.core.batch
- :members: BatchIter, DataSetIter, TorchLoaderIter
- :inherited-members:
-
diff --git a/docs/source/fastNLP.core.callback.rst b/docs/source/fastNLP.core.callback.rst
deleted file mode 100644
index 5a508e03..00000000
--- a/docs/source/fastNLP.core.callback.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.core.callback
-=====================
-
-.. automodule:: fastNLP.core.callback
- :members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, EarlyStopError
- :inherited-members:
-
diff --git a/docs/source/fastNLP.core.callbacks.callback.rst b/docs/source/fastNLP.core.callbacks.callback.rst
new file mode 100644
index 00000000..15a31183
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.callback module
+======================================
+
+.. automodule:: fastNLP.core.callbacks.callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.callback_event.rst b/docs/source/fastNLP.core.callbacks.callback_event.rst
new file mode 100644
index 00000000..1945b597
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.callback_event.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.callback\_event module
+=============================================
+
+.. automodule:: fastNLP.core.callbacks.callback_event
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.callback_manager.rst b/docs/source/fastNLP.core.callbacks.callback_manager.rst
new file mode 100644
index 00000000..3f22d46f
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.callback_manager.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.callback\_manager module
+===============================================
+
+.. automodule:: fastNLP.core.callbacks.callback_manager
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.checkpoint_callback.rst b/docs/source/fastNLP.core.callbacks.checkpoint_callback.rst
new file mode 100644
index 00000000..297879df
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.checkpoint_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.checkpoint\_callback module
+==================================================
+
+.. automodule:: fastNLP.core.callbacks.checkpoint_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.early_stop_callback.rst b/docs/source/fastNLP.core.callbacks.early_stop_callback.rst
new file mode 100644
index 00000000..81356ed4
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.early_stop_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.early\_stop\_callback module
+===================================================
+
+.. automodule:: fastNLP.core.callbacks.early_stop_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.fitlog_callback.rst b/docs/source/fastNLP.core.callbacks.fitlog_callback.rst
new file mode 100644
index 00000000..020c3ff3
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.fitlog_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.fitlog\_callback module
+==============================================
+
+.. automodule:: fastNLP.core.callbacks.fitlog_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.has_monitor_callback.rst b/docs/source/fastNLP.core.callbacks.has_monitor_callback.rst
new file mode 100644
index 00000000..c1c6f93c
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.has_monitor_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.has\_monitor\_callback module
+====================================================
+
+.. automodule:: fastNLP.core.callbacks.has_monitor_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.load_best_model_callback.rst b/docs/source/fastNLP.core.callbacks.load_best_model_callback.rst
new file mode 100644
index 00000000..9d9b4b78
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.load_best_model_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.load\_best\_model\_callback module
+=========================================================
+
+.. automodule:: fastNLP.core.callbacks.load_best_model_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.lr_scheduler_callback.rst b/docs/source/fastNLP.core.callbacks.lr_scheduler_callback.rst
new file mode 100644
index 00000000..30abe617
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.lr_scheduler_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.lr\_scheduler\_callback module
+=====================================================
+
+.. automodule:: fastNLP.core.callbacks.lr_scheduler_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.more_evaluate_callback.rst b/docs/source/fastNLP.core.callbacks.more_evaluate_callback.rst
new file mode 100644
index 00000000..a44071e9
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.more_evaluate_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.more\_evaluate\_callback module
+======================================================
+
+.. automodule:: fastNLP.core.callbacks.more_evaluate_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.progress_callback.rst b/docs/source/fastNLP.core.callbacks.progress_callback.rst
new file mode 100644
index 00000000..b7cc7801
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.progress_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.progress\_callback module
+================================================
+
+.. automodule:: fastNLP.core.callbacks.progress_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.rst b/docs/source/fastNLP.core.callbacks.rst
new file mode 100644
index 00000000..d0f3d210
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.rst
@@ -0,0 +1,36 @@
+fastNLP.core.callbacks package
+==============================
+
+.. automodule:: fastNLP.core.callbacks
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.callbacks.torch_callbacks
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.callbacks.callback
+ fastNLP.core.callbacks.callback_event
+ fastNLP.core.callbacks.callback_manager
+ fastNLP.core.callbacks.checkpoint_callback
+ fastNLP.core.callbacks.early_stop_callback
+ fastNLP.core.callbacks.fitlog_callback
+ fastNLP.core.callbacks.has_monitor_callback
+ fastNLP.core.callbacks.load_best_model_callback
+ fastNLP.core.callbacks.lr_scheduler_callback
+ fastNLP.core.callbacks.more_evaluate_callback
+ fastNLP.core.callbacks.progress_callback
+ fastNLP.core.callbacks.timer_callback
+ fastNLP.core.callbacks.topk_saver
+ fastNLP.core.callbacks.utils
diff --git a/docs/source/fastNLP.core.callbacks.timer_callback.rst b/docs/source/fastNLP.core.callbacks.timer_callback.rst
new file mode 100644
index 00000000..884fa604
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.timer_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.timer\_callback module
+=============================================
+
+.. automodule:: fastNLP.core.callbacks.timer_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.topk_saver.rst b/docs/source/fastNLP.core.callbacks.topk_saver.rst
new file mode 100644
index 00000000..20a311ef
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.topk_saver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.topk\_saver module
+=========================================
+
+.. automodule:: fastNLP.core.callbacks.topk_saver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.torch_callbacks.rst b/docs/source/fastNLP.core.callbacks.torch_callbacks.rst
new file mode 100644
index 00000000..6f00f6f7
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.torch_callbacks.rst
@@ -0,0 +1,16 @@
+fastNLP.core.callbacks.torch\_callbacks package
+===============================================
+
+.. automodule:: fastNLP.core.callbacks.torch_callbacks
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.callbacks.torch_callbacks.torch_grad_clip_callback
+ fastNLP.core.callbacks.torch_callbacks.torch_lr_sched_callback
diff --git a/docs/source/fastNLP.core.callbacks.torch_callbacks.torch_grad_clip_callback.rst b/docs/source/fastNLP.core.callbacks.torch_callbacks.torch_grad_clip_callback.rst
new file mode 100644
index 00000000..a4ef03b4
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.torch_callbacks.torch_grad_clip_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.torch\_callbacks.torch\_grad\_clip\_callback module
+==========================================================================
+
+.. automodule:: fastNLP.core.callbacks.torch_callbacks.torch_grad_clip_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.torch_callbacks.torch_lr_sched_callback.rst b/docs/source/fastNLP.core.callbacks.torch_callbacks.torch_lr_sched_callback.rst
new file mode 100644
index 00000000..72c3e2bf
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.torch_callbacks.torch_lr_sched_callback.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.torch\_callbacks.torch\_lr\_sched\_callback module
+=========================================================================
+
+.. automodule:: fastNLP.core.callbacks.torch_callbacks.torch_lr_sched_callback
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.callbacks.utils.rst b/docs/source/fastNLP.core.callbacks.utils.rst
new file mode 100644
index 00000000..09e23a1e
--- /dev/null
+++ b/docs/source/fastNLP.core.callbacks.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.utils module
+===================================
+
+.. automodule:: fastNLP.core.callbacks.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.collator.rst b/docs/source/fastNLP.core.collators.collator.rst
new file mode 100644
index 00000000..f620fd8b
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.collator.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.collator module
+======================================
+
+.. automodule:: fastNLP.core.collators.collator
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.packer_unpacker.rst b/docs/source/fastNLP.core.collators.packer_unpacker.rst
new file mode 100644
index 00000000..9a207d07
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.packer_unpacker.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.packer\_unpacker module
+==============================================
+
+.. automodule:: fastNLP.core.collators.packer_unpacker
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.exceptions.rst b/docs/source/fastNLP.core.collators.padders.exceptions.rst
new file mode 100644
index 00000000..c2822970
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.exceptions.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.exceptions module
+================================================
+
+.. automodule:: fastNLP.core.collators.padders.exceptions
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.get_padder.rst b/docs/source/fastNLP.core.collators.padders.get_padder.rst
new file mode 100644
index 00000000..5ae56bef
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.get_padder.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.get\_padder module
+=================================================
+
+.. automodule:: fastNLP.core.collators.padders.get_padder
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.jittor_padder.rst b/docs/source/fastNLP.core.collators.padders.jittor_padder.rst
new file mode 100644
index 00000000..7e908090
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.jittor_padder.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.jittor\_padder module
+====================================================
+
+.. automodule:: fastNLP.core.collators.padders.jittor_padder
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.numpy_padder.rst b/docs/source/fastNLP.core.collators.padders.numpy_padder.rst
new file mode 100644
index 00000000..506473ea
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.numpy_padder.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.numpy\_padder module
+===================================================
+
+.. automodule:: fastNLP.core.collators.padders.numpy_padder
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.oneflow_padder.rst b/docs/source/fastNLP.core.collators.padders.oneflow_padder.rst
new file mode 100644
index 00000000..ced75ccb
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.oneflow_padder.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.oneflow\_padder module
+=====================================================
+
+.. automodule:: fastNLP.core.collators.padders.oneflow_padder
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.padder.rst b/docs/source/fastNLP.core.collators.padders.padder.rst
new file mode 100644
index 00000000..f56b6556
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.padder.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.padder module
+============================================
+
+.. automodule:: fastNLP.core.collators.padders.padder
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.paddle_padder.rst b/docs/source/fastNLP.core.collators.padders.paddle_padder.rst
new file mode 100644
index 00000000..93aac85a
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.paddle_padder.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.paddle\_padder module
+====================================================
+
+.. automodule:: fastNLP.core.collators.padders.paddle_padder
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.raw_padder.rst b/docs/source/fastNLP.core.collators.padders.raw_padder.rst
new file mode 100644
index 00000000..e8505480
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.raw_padder.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.raw\_padder module
+=================================================
+
+.. automodule:: fastNLP.core.collators.padders.raw_padder
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.rst b/docs/source/fastNLP.core.collators.padders.rst
new file mode 100644
index 00000000..0c50dd4c
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.rst
@@ -0,0 +1,25 @@
+fastNLP.core.collators.padders package
+======================================
+
+.. automodule:: fastNLP.core.collators.padders
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.collators.padders.exceptions
+ fastNLP.core.collators.padders.get_padder
+ fastNLP.core.collators.padders.jittor_padder
+ fastNLP.core.collators.padders.numpy_padder
+ fastNLP.core.collators.padders.oneflow_padder
+ fastNLP.core.collators.padders.padder
+ fastNLP.core.collators.padders.paddle_padder
+ fastNLP.core.collators.padders.raw_padder
+ fastNLP.core.collators.padders.torch_padder
+ fastNLP.core.collators.padders.torch_utils
+ fastNLP.core.collators.padders.utils
diff --git a/docs/source/fastNLP.core.collators.padders.torch_padder.rst b/docs/source/fastNLP.core.collators.padders.torch_padder.rst
new file mode 100644
index 00000000..a3dfd1a3
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.torch_padder.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.torch\_padder module
+===================================================
+
+.. automodule:: fastNLP.core.collators.padders.torch_padder
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.torch_utils.rst b/docs/source/fastNLP.core.collators.padders.torch_utils.rst
new file mode 100644
index 00000000..ac972b89
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.torch_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.torch\_utils module
+==================================================
+
+.. automodule:: fastNLP.core.collators.padders.torch_utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.padders.utils.rst b/docs/source/fastNLP.core.collators.padders.utils.rst
new file mode 100644
index 00000000..e71e2dfc
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.padders.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.collators.padders.utils module
+===========================================
+
+.. automodule:: fastNLP.core.collators.padders.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.collators.rst b/docs/source/fastNLP.core.collators.rst
new file mode 100644
index 00000000..22259c12
--- /dev/null
+++ b/docs/source/fastNLP.core.collators.rst
@@ -0,0 +1,24 @@
+fastNLP.core.collators package
+==============================
+
+.. automodule:: fastNLP.core.collators
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.collators.padders
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.collators.collator
+ fastNLP.core.collators.packer_unpacker
diff --git a/docs/source/fastNLP.core.const.rst b/docs/source/fastNLP.core.const.rst
deleted file mode 100644
index 82a1992e..00000000
--- a/docs/source/fastNLP.core.const.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.core.const
-==================
-
-.. automodule:: fastNLP.core.const
- :members: Const
- :inherited-members:
-
diff --git a/docs/source/fastNLP.core.controllers.evaluator.rst b/docs/source/fastNLP.core.controllers.evaluator.rst
new file mode 100644
index 00000000..d1b2aec3
--- /dev/null
+++ b/docs/source/fastNLP.core.controllers.evaluator.rst
@@ -0,0 +1,7 @@
+fastNLP.core.controllers.evaluator module
+=========================================
+
+.. automodule:: fastNLP.core.controllers.evaluator
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.controllers.loops.evaluate_batch_loop.rst b/docs/source/fastNLP.core.controllers.loops.evaluate_batch_loop.rst
new file mode 100644
index 00000000..a015313d
--- /dev/null
+++ b/docs/source/fastNLP.core.controllers.loops.evaluate_batch_loop.rst
@@ -0,0 +1,7 @@
+fastNLP.core.controllers.loops.evaluate\_batch\_loop module
+===========================================================
+
+.. automodule:: fastNLP.core.controllers.loops.evaluate_batch_loop
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.controllers.loops.loop.rst b/docs/source/fastNLP.core.controllers.loops.loop.rst
new file mode 100644
index 00000000..25351fe3
--- /dev/null
+++ b/docs/source/fastNLP.core.controllers.loops.loop.rst
@@ -0,0 +1,7 @@
+fastNLP.core.controllers.loops.loop module
+==========================================
+
+.. automodule:: fastNLP.core.controllers.loops.loop
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.controllers.loops.rst b/docs/source/fastNLP.core.controllers.loops.rst
new file mode 100644
index 00000000..39879148
--- /dev/null
+++ b/docs/source/fastNLP.core.controllers.loops.rst
@@ -0,0 +1,17 @@
+fastNLP.core.controllers.loops package
+======================================
+
+.. automodule:: fastNLP.core.controllers.loops
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.controllers.loops.evaluate_batch_loop
+ fastNLP.core.controllers.loops.loop
+ fastNLP.core.controllers.loops.train_batch_loop
diff --git a/docs/source/fastNLP.core.controllers.loops.train_batch_loop.rst b/docs/source/fastNLP.core.controllers.loops.train_batch_loop.rst
new file mode 100644
index 00000000..8b04da05
--- /dev/null
+++ b/docs/source/fastNLP.core.controllers.loops.train_batch_loop.rst
@@ -0,0 +1,7 @@
+fastNLP.core.controllers.loops.train\_batch\_loop module
+========================================================
+
+.. automodule:: fastNLP.core.controllers.loops.train_batch_loop
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.controllers.rst b/docs/source/fastNLP.core.controllers.rst
new file mode 100644
index 00000000..9440fbe4
--- /dev/null
+++ b/docs/source/fastNLP.core.controllers.rst
@@ -0,0 +1,25 @@
+fastNLP.core.controllers package
+================================
+
+.. automodule:: fastNLP.core.controllers
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.controllers.loops
+ fastNLP.core.controllers.utils
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.controllers.evaluator
+ fastNLP.core.controllers.trainer
diff --git a/docs/source/fastNLP.core.controllers.trainer.rst b/docs/source/fastNLP.core.controllers.trainer.rst
new file mode 100644
index 00000000..209a3a43
--- /dev/null
+++ b/docs/source/fastNLP.core.controllers.trainer.rst
@@ -0,0 +1,7 @@
+fastNLP.core.controllers.trainer module
+=======================================
+
+.. automodule:: fastNLP.core.controllers.trainer
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.controllers.utils.rst b/docs/source/fastNLP.core.controllers.utils.rst
new file mode 100644
index 00000000..f7bcc38c
--- /dev/null
+++ b/docs/source/fastNLP.core.controllers.utils.rst
@@ -0,0 +1,16 @@
+fastNLP.core.controllers.utils package
+======================================
+
+.. automodule:: fastNLP.core.controllers.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.controllers.utils.state
+ fastNLP.core.controllers.utils.utils
diff --git a/docs/source/fastNLP.core.controllers.utils.state.rst b/docs/source/fastNLP.core.controllers.utils.state.rst
new file mode 100644
index 00000000..5adcd921
--- /dev/null
+++ b/docs/source/fastNLP.core.controllers.utils.state.rst
@@ -0,0 +1,7 @@
+fastNLP.core.controllers.utils.state module
+===========================================
+
+.. automodule:: fastNLP.core.controllers.utils.state
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.controllers.utils.utils.rst b/docs/source/fastNLP.core.controllers.utils.utils.rst
new file mode 100644
index 00000000..ba864ab0
--- /dev/null
+++ b/docs/source/fastNLP.core.controllers.utils.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.controllers.utils.utils module
+===========================================
+
+.. automodule:: fastNLP.core.controllers.utils.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataloaders.jittor_dataloader.fdl.rst b/docs/source/fastNLP.core.dataloaders.jittor_dataloader.fdl.rst
new file mode 100644
index 00000000..08df6096
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.jittor_dataloader.fdl.rst
@@ -0,0 +1,7 @@
+fastNLP.core.dataloaders.jittor\_dataloader.fdl module
+======================================================
+
+.. automodule:: fastNLP.core.dataloaders.jittor_dataloader.fdl
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst b/docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst
new file mode 100644
index 00000000..78d90c46
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst
@@ -0,0 +1,15 @@
+fastNLP.core.dataloaders.jittor\_dataloader package
+===================================================
+
+.. automodule:: fastNLP.core.dataloaders.jittor_dataloader
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.dataloaders.jittor_dataloader.fdl
diff --git a/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.fdl.rst b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.fdl.rst
new file mode 100644
index 00000000..5a8939b0
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.fdl.rst
@@ -0,0 +1,7 @@
+fastNLP.core.dataloaders.oneflow\_dataloader.fdl module
+=======================================================
+
+.. automodule:: fastNLP.core.dataloaders.oneflow_dataloader.fdl
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst
new file mode 100644
index 00000000..2b2081e5
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst
@@ -0,0 +1,15 @@
+fastNLP.core.dataloaders.oneflow\_dataloader package
+====================================================
+
+.. automodule:: fastNLP.core.dataloaders.oneflow_dataloader
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.dataloaders.oneflow_dataloader.fdl
diff --git a/docs/source/fastNLP.core.dataloaders.paddle_dataloader.fdl.rst b/docs/source/fastNLP.core.dataloaders.paddle_dataloader.fdl.rst
new file mode 100644
index 00000000..5b40bec0
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.paddle_dataloader.fdl.rst
@@ -0,0 +1,7 @@
+fastNLP.core.dataloaders.paddle\_dataloader.fdl module
+======================================================
+
+.. automodule:: fastNLP.core.dataloaders.paddle_dataloader.fdl
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst b/docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst
new file mode 100644
index 00000000..dc4481d2
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst
@@ -0,0 +1,15 @@
+fastNLP.core.dataloaders.paddle\_dataloader package
+===================================================
+
+.. automodule:: fastNLP.core.dataloaders.paddle_dataloader
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.dataloaders.paddle_dataloader.fdl
diff --git a/docs/source/fastNLP.core.dataloaders.prepare_dataloader.rst b/docs/source/fastNLP.core.dataloaders.prepare_dataloader.rst
new file mode 100644
index 00000000..ac8c8c20
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.prepare_dataloader.rst
@@ -0,0 +1,7 @@
+fastNLP.core.dataloaders.prepare\_dataloader module
+===================================================
+
+.. automodule:: fastNLP.core.dataloaders.prepare_dataloader
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataloaders.rst b/docs/source/fastNLP.core.dataloaders.rst
new file mode 100644
index 00000000..db53dbe0
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.rst
@@ -0,0 +1,27 @@
+fastNLP.core.dataloaders package
+================================
+
+.. automodule:: fastNLP.core.dataloaders
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.dataloaders.jittor_dataloader
+ fastNLP.core.dataloaders.oneflow_dataloader
+ fastNLP.core.dataloaders.paddle_dataloader
+ fastNLP.core.dataloaders.torch_dataloader
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.dataloaders.prepare_dataloader
+ fastNLP.core.dataloaders.utils
diff --git a/docs/source/fastNLP.core.dataloaders.torch_dataloader.fdl.rst b/docs/source/fastNLP.core.dataloaders.torch_dataloader.fdl.rst
new file mode 100644
index 00000000..33db5bf9
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.torch_dataloader.fdl.rst
@@ -0,0 +1,7 @@
+fastNLP.core.dataloaders.torch\_dataloader.fdl module
+=====================================================
+
+.. automodule:: fastNLP.core.dataloaders.torch_dataloader.fdl
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataloaders.torch_dataloader.mix_dataloader.rst b/docs/source/fastNLP.core.dataloaders.torch_dataloader.mix_dataloader.rst
new file mode 100644
index 00000000..cd8bd865
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.torch_dataloader.mix_dataloader.rst
@@ -0,0 +1,7 @@
+fastNLP.core.dataloaders.torch\_dataloader.mix\_dataloader module
+=================================================================
+
+.. automodule:: fastNLP.core.dataloaders.torch_dataloader.mix_dataloader
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst b/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst
new file mode 100644
index 00000000..a3aeb1bf
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst
@@ -0,0 +1,16 @@
+fastNLP.core.dataloaders.torch\_dataloader package
+==================================================
+
+.. automodule:: fastNLP.core.dataloaders.torch_dataloader
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.dataloaders.torch_dataloader.fdl
+ fastNLP.core.dataloaders.torch_dataloader.mix_dataloader
diff --git a/docs/source/fastNLP.core.dataloaders.utils.rst b/docs/source/fastNLP.core.dataloaders.utils.rst
new file mode 100644
index 00000000..0b28dfe1
--- /dev/null
+++ b/docs/source/fastNLP.core.dataloaders.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.dataloaders.utils module
+=====================================
+
+.. automodule:: fastNLP.core.dataloaders.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataset.dataset.rst b/docs/source/fastNLP.core.dataset.dataset.rst
new file mode 100644
index 00000000..7bae1726
--- /dev/null
+++ b/docs/source/fastNLP.core.dataset.dataset.rst
@@ -0,0 +1,7 @@
+fastNLP.core.dataset.dataset module
+===================================
+
+.. automodule:: fastNLP.core.dataset.dataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataset.field.rst b/docs/source/fastNLP.core.dataset.field.rst
new file mode 100644
index 00000000..6d85fd71
--- /dev/null
+++ b/docs/source/fastNLP.core.dataset.field.rst
@@ -0,0 +1,7 @@
+fastNLP.core.dataset.field module
+=================================
+
+.. automodule:: fastNLP.core.dataset.field
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataset.instance.rst b/docs/source/fastNLP.core.dataset.instance.rst
new file mode 100644
index 00000000..f3382a9d
--- /dev/null
+++ b/docs/source/fastNLP.core.dataset.instance.rst
@@ -0,0 +1,7 @@
+fastNLP.core.dataset.instance module
+====================================
+
+.. automodule:: fastNLP.core.dataset.instance
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.dataset.rst b/docs/source/fastNLP.core.dataset.rst
index e13d7f1c..dc36250a 100644
--- a/docs/source/fastNLP.core.dataset.rst
+++ b/docs/source/fastNLP.core.dataset.rst
@@ -1,7 +1,17 @@
-fastNLP.core.dataset
-====================
+fastNLP.core.dataset package
+============================
.. automodule:: fastNLP.core.dataset
- :members: DataSet
- :inherited-members:
+ :members:
+ :undoc-members:
+ :show-inheritance:
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.dataset.dataset
+ fastNLP.core.dataset.field
+ fastNLP.core.dataset.instance
diff --git a/docs/source/fastNLP.core.drivers.choose_driver.rst b/docs/source/fastNLP.core.drivers.choose_driver.rst
new file mode 100644
index 00000000..68912754
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.choose_driver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.choose\_driver module
+==========================================
+
+.. automodule:: fastNLP.core.drivers.choose_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.driver.rst b/docs/source/fastNLP.core.drivers.driver.rst
new file mode 100644
index 00000000..c5b6be38
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.driver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.driver module
+==================================
+
+.. automodule:: fastNLP.core.drivers.driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.jittor_driver.initialize_jittor_driver.rst b/docs/source/fastNLP.core.drivers.jittor_driver.initialize_jittor_driver.rst
new file mode 100644
index 00000000..56057f75
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.jittor_driver.initialize_jittor_driver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.jittor\_driver.initialize\_jittor\_driver module
+=====================================================================
+
+.. automodule:: fastNLP.core.drivers.jittor_driver.initialize_jittor_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.jittor_driver.jittor_driver.rst b/docs/source/fastNLP.core.drivers.jittor_driver.jittor_driver.rst
new file mode 100644
index 00000000..c68f8bf5
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.jittor_driver.jittor_driver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.jittor\_driver.jittor\_driver module
+=========================================================
+
+.. automodule:: fastNLP.core.drivers.jittor_driver.jittor_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.jittor_driver.mpi.rst b/docs/source/fastNLP.core.drivers.jittor_driver.mpi.rst
new file mode 100644
index 00000000..7bcb2fd5
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.jittor_driver.mpi.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.jittor\_driver.mpi module
+==============================================
+
+.. automodule:: fastNLP.core.drivers.jittor_driver.mpi
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.jittor_driver.rst b/docs/source/fastNLP.core.drivers.jittor_driver.rst
new file mode 100644
index 00000000..7ec101c7
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.jittor_driver.rst
@@ -0,0 +1,19 @@
+fastNLP.core.drivers.jittor\_driver package
+===========================================
+
+.. automodule:: fastNLP.core.drivers.jittor_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.drivers.jittor_driver.initialize_jittor_driver
+ fastNLP.core.drivers.jittor_driver.jittor_driver
+ fastNLP.core.drivers.jittor_driver.mpi
+ fastNLP.core.drivers.jittor_driver.single_device
+ fastNLP.core.drivers.jittor_driver.utils
diff --git a/docs/source/fastNLP.core.drivers.jittor_driver.single_device.rst b/docs/source/fastNLP.core.drivers.jittor_driver.single_device.rst
new file mode 100644
index 00000000..d9dcd051
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.jittor_driver.single_device.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.jittor\_driver.single\_device module
+=========================================================
+
+.. automodule:: fastNLP.core.drivers.jittor_driver.single_device
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.jittor_driver.utils.rst b/docs/source/fastNLP.core.drivers.jittor_driver.utils.rst
new file mode 100644
index 00000000..92a75e85
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.jittor_driver.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.jittor\_driver.utils module
+================================================
+
+.. automodule:: fastNLP.core.drivers.jittor_driver.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst
new file mode 100644
index 00000000..c7618619
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.oneflow\_driver.ddp module
+===============================================
+
+.. automodule:: fastNLP.core.drivers.oneflow_driver.ddp
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.dist_utils.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.dist_utils.rst
new file mode 100644
index 00000000..9eae5d19
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.oneflow_driver.dist_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.oneflow\_driver.dist\_utils module
+=======================================================
+
+.. automodule:: fastNLP.core.drivers.oneflow_driver.dist_utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver.rst
new file mode 100644
index 00000000..d7272c8e
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.oneflow\_driver.initialize\_oneflow\_driver module
+=======================================================================
+
+.. automodule:: fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.oneflow_driver.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.oneflow_driver.rst
new file mode 100644
index 00000000..1f5d159e
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.oneflow_driver.oneflow_driver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.oneflow\_driver.oneflow\_driver module
+===========================================================
+
+.. automodule:: fastNLP.core.drivers.oneflow_driver.oneflow_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.rst
new file mode 100644
index 00000000..213dd24b
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.oneflow_driver.rst
@@ -0,0 +1,20 @@
+fastNLP.core.drivers.oneflow\_driver package
+============================================
+
+.. automodule:: fastNLP.core.drivers.oneflow_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.drivers.oneflow_driver.ddp
+ fastNLP.core.drivers.oneflow_driver.dist_utils
+ fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver
+ fastNLP.core.drivers.oneflow_driver.oneflow_driver
+ fastNLP.core.drivers.oneflow_driver.single_device
+ fastNLP.core.drivers.oneflow_driver.utils
diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.single_device.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.single_device.rst
new file mode 100644
index 00000000..a54e74ec
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.oneflow_driver.single_device.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.oneflow\_driver.single\_device module
+==========================================================
+
+.. automodule:: fastNLP.core.drivers.oneflow_driver.single_device
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.utils.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.utils.rst
new file mode 100644
index 00000000..1eda7794
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.oneflow_driver.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.oneflow\_driver.utils module
+=================================================
+
+.. automodule:: fastNLP.core.drivers.oneflow_driver.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.paddle_driver.dist_utils.rst b/docs/source/fastNLP.core.drivers.paddle_driver.dist_utils.rst
new file mode 100644
index 00000000..c7314d32
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.paddle_driver.dist_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.paddle\_driver.dist\_utils module
+======================================================
+
+.. automodule:: fastNLP.core.drivers.paddle_driver.dist_utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.paddle_driver.fleet.rst b/docs/source/fastNLP.core.drivers.paddle_driver.fleet.rst
new file mode 100644
index 00000000..ba6c34d4
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.paddle_driver.fleet.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.paddle\_driver.fleet module
+================================================
+
+.. automodule:: fastNLP.core.drivers.paddle_driver.fleet
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.paddle_driver.fleet_launcher.rst b/docs/source/fastNLP.core.drivers.paddle_driver.fleet_launcher.rst
new file mode 100644
index 00000000..f20d29fa
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.paddle_driver.fleet_launcher.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.paddle\_driver.fleet\_launcher module
+==========================================================
+
+.. automodule:: fastNLP.core.drivers.paddle_driver.fleet_launcher
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.paddle_driver.initialize_paddle_driver.rst b/docs/source/fastNLP.core.drivers.paddle_driver.initialize_paddle_driver.rst
new file mode 100644
index 00000000..248b6cb6
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.paddle_driver.initialize_paddle_driver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.paddle\_driver.initialize\_paddle\_driver module
+=====================================================================
+
+.. automodule:: fastNLP.core.drivers.paddle_driver.initialize_paddle_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.paddle_driver.paddle_driver.rst b/docs/source/fastNLP.core.drivers.paddle_driver.paddle_driver.rst
new file mode 100644
index 00000000..16603bb1
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.paddle_driver.paddle_driver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.paddle\_driver.paddle\_driver module
+=========================================================
+
+.. automodule:: fastNLP.core.drivers.paddle_driver.paddle_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.paddle_driver.rst b/docs/source/fastNLP.core.drivers.paddle_driver.rst
new file mode 100644
index 00000000..0f115eb5
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.paddle_driver.rst
@@ -0,0 +1,21 @@
+fastNLP.core.drivers.paddle\_driver package
+===========================================
+
+.. automodule:: fastNLP.core.drivers.paddle_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.drivers.paddle_driver.dist_utils
+ fastNLP.core.drivers.paddle_driver.fleet
+ fastNLP.core.drivers.paddle_driver.fleet_launcher
+ fastNLP.core.drivers.paddle_driver.initialize_paddle_driver
+ fastNLP.core.drivers.paddle_driver.paddle_driver
+ fastNLP.core.drivers.paddle_driver.single_device
+ fastNLP.core.drivers.paddle_driver.utils
diff --git a/docs/source/fastNLP.core.drivers.paddle_driver.single_device.rst b/docs/source/fastNLP.core.drivers.paddle_driver.single_device.rst
new file mode 100644
index 00000000..b87c836d
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.paddle_driver.single_device.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.paddle\_driver.single\_device module
+=========================================================
+
+.. automodule:: fastNLP.core.drivers.paddle_driver.single_device
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.paddle_driver.utils.rst b/docs/source/fastNLP.core.drivers.paddle_driver.utils.rst
new file mode 100644
index 00000000..8b9bd501
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.paddle_driver.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.paddle\_driver.utils module
+================================================
+
+.. automodule:: fastNLP.core.drivers.paddle_driver.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.rst b/docs/source/fastNLP.core.drivers.rst
new file mode 100644
index 00000000..30652fec
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.rst
@@ -0,0 +1,28 @@
+fastNLP.core.drivers package
+============================
+
+.. automodule:: fastNLP.core.drivers
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.drivers.jittor_driver
+ fastNLP.core.drivers.oneflow_driver
+ fastNLP.core.drivers.paddle_driver
+ fastNLP.core.drivers.torch_driver
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.drivers.choose_driver
+ fastNLP.core.drivers.driver
+ fastNLP.core.drivers.utils
diff --git a/docs/source/fastNLP.core.drivers.torch_driver.ddp.rst b/docs/source/fastNLP.core.drivers.torch_driver.ddp.rst
new file mode 100644
index 00000000..4d6cafff
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.torch_driver.ddp.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.torch\_driver.ddp module
+=============================================
+
+.. automodule:: fastNLP.core.drivers.torch_driver.ddp
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.torch_driver.deepspeed.rst b/docs/source/fastNLP.core.drivers.torch_driver.deepspeed.rst
new file mode 100644
index 00000000..2944ffec
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.torch_driver.deepspeed.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.torch\_driver.deepspeed module
+===================================================
+
+.. automodule:: fastNLP.core.drivers.torch_driver.deepspeed
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.torch_driver.dist_utils.rst b/docs/source/fastNLP.core.drivers.torch_driver.dist_utils.rst
new file mode 100644
index 00000000..30ba5381
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.torch_driver.dist_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.torch\_driver.dist\_utils module
+=====================================================
+
+.. automodule:: fastNLP.core.drivers.torch_driver.dist_utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.torch_driver.fairscale.rst b/docs/source/fastNLP.core.drivers.torch_driver.fairscale.rst
new file mode 100644
index 00000000..e68972a7
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.torch_driver.fairscale.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.torch\_driver.fairscale module
+===================================================
+
+.. automodule:: fastNLP.core.drivers.torch_driver.fairscale
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.torch_driver.initialize_torch_driver.rst b/docs/source/fastNLP.core.drivers.torch_driver.initialize_torch_driver.rst
new file mode 100644
index 00000000..989050ac
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.torch_driver.initialize_torch_driver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.torch\_driver.initialize\_torch\_driver module
+===================================================================
+
+.. automodule:: fastNLP.core.drivers.torch_driver.initialize_torch_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.torch_driver.rst b/docs/source/fastNLP.core.drivers.torch_driver.rst
new file mode 100644
index 00000000..c9080a86
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.torch_driver.rst
@@ -0,0 +1,23 @@
+fastNLP.core.drivers.torch\_driver package
+==========================================
+
+.. automodule:: fastNLP.core.drivers.torch_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.drivers.torch_driver.ddp
+ fastNLP.core.drivers.torch_driver.deepspeed
+ fastNLP.core.drivers.torch_driver.dist_utils
+ fastNLP.core.drivers.torch_driver.fairscale
+ fastNLP.core.drivers.torch_driver.initialize_torch_driver
+ fastNLP.core.drivers.torch_driver.single_device
+ fastNLP.core.drivers.torch_driver.torch_driver
+ fastNLP.core.drivers.torch_driver.torch_fsdp
+ fastNLP.core.drivers.torch_driver.utils
diff --git a/docs/source/fastNLP.core.drivers.torch_driver.single_device.rst b/docs/source/fastNLP.core.drivers.torch_driver.single_device.rst
new file mode 100644
index 00000000..c8f8a2d9
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.torch_driver.single_device.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.torch\_driver.single\_device module
+========================================================
+
+.. automodule:: fastNLP.core.drivers.torch_driver.single_device
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.torch_driver.torch_driver.rst b/docs/source/fastNLP.core.drivers.torch_driver.torch_driver.rst
new file mode 100644
index 00000000..da58a329
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.torch_driver.torch_driver.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.torch\_driver.torch\_driver module
+=======================================================
+
+.. automodule:: fastNLP.core.drivers.torch_driver.torch_driver
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.torch_driver.torch_fsdp.rst b/docs/source/fastNLP.core.drivers.torch_driver.torch_fsdp.rst
new file mode 100644
index 00000000..a799b7fc
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.torch_driver.torch_fsdp.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.torch\_driver.torch\_fsdp module
+=====================================================
+
+.. automodule:: fastNLP.core.drivers.torch_driver.torch_fsdp
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.torch_driver.utils.rst b/docs/source/fastNLP.core.drivers.torch_driver.utils.rst
new file mode 100644
index 00000000..2481377d
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.torch_driver.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.torch\_driver.utils module
+===============================================
+
+.. automodule:: fastNLP.core.drivers.torch_driver.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.drivers.utils.rst b/docs/source/fastNLP.core.drivers.utils.rst
new file mode 100644
index 00000000..7acb2588
--- /dev/null
+++ b/docs/source/fastNLP.core.drivers.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.drivers.utils module
+=================================
+
+.. automodule:: fastNLP.core.drivers.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.field.rst b/docs/source/fastNLP.core.field.rst
deleted file mode 100644
index 73dad8af..00000000
--- a/docs/source/fastNLP.core.field.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.core.field
-==================
-
-.. automodule:: fastNLP.core.field
- :members: Padder, AutoPadder, EngChar2DPadder
- :inherited-members:
-
diff --git a/docs/source/fastNLP.core.instance.rst b/docs/source/fastNLP.core.instance.rst
deleted file mode 100644
index 010567b9..00000000
--- a/docs/source/fastNLP.core.instance.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.core.instance
-=====================
-
-.. automodule:: fastNLP.core.instance
- :members: Instance
- :inherited-members:
-
diff --git a/docs/source/fastNLP.core.log.handler.rst b/docs/source/fastNLP.core.log.handler.rst
new file mode 100644
index 00000000..1e124987
--- /dev/null
+++ b/docs/source/fastNLP.core.log.handler.rst
@@ -0,0 +1,7 @@
+fastNLP.core.log.handler module
+===============================
+
+.. automodule:: fastNLP.core.log.handler
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.log.highlighter.rst b/docs/source/fastNLP.core.log.highlighter.rst
new file mode 100644
index 00000000..e62336d4
--- /dev/null
+++ b/docs/source/fastNLP.core.log.highlighter.rst
@@ -0,0 +1,7 @@
+fastNLP.core.log.highlighter module
+===================================
+
+.. automodule:: fastNLP.core.log.highlighter
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.log.logger.rst b/docs/source/fastNLP.core.log.logger.rst
new file mode 100644
index 00000000..caf78c02
--- /dev/null
+++ b/docs/source/fastNLP.core.log.logger.rst
@@ -0,0 +1,7 @@
+fastNLP.core.log.logger module
+==============================
+
+.. automodule:: fastNLP.core.log.logger
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.log.print.rst b/docs/source/fastNLP.core.log.print.rst
new file mode 100644
index 00000000..700ac5b1
--- /dev/null
+++ b/docs/source/fastNLP.core.log.print.rst
@@ -0,0 +1,7 @@
+fastNLP.core.log.print module
+=============================
+
+.. automodule:: fastNLP.core.log.print
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.log.rst b/docs/source/fastNLP.core.log.rst
new file mode 100644
index 00000000..6cd67753
--- /dev/null
+++ b/docs/source/fastNLP.core.log.rst
@@ -0,0 +1,18 @@
+fastNLP.core.log package
+========================
+
+.. automodule:: fastNLP.core.log
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.log.handler
+ fastNLP.core.log.highlighter
+ fastNLP.core.log.logger
+ fastNLP.core.log.print
diff --git a/docs/source/fastNLP.core.losses.rst b/docs/source/fastNLP.core.losses.rst
deleted file mode 100644
index daf246f8..00000000
--- a/docs/source/fastNLP.core.losses.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.core.losses
-===================
-
-.. automodule:: fastNLP.core.losses
- :members: LossBase, LossFunc, LossInForward, CrossEntropyLoss, BCELoss, L1Loss, NLLLoss
- :inherited-members:
-
diff --git a/docs/source/fastNLP.core.metrics.accuracy.rst b/docs/source/fastNLP.core.metrics.accuracy.rst
new file mode 100644
index 00000000..76d7bfa5
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.accuracy.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.accuracy module
+====================================
+
+.. automodule:: fastNLP.core.metrics.accuracy
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.backend.auto_backend.rst b/docs/source/fastNLP.core.metrics.backend.auto_backend.rst
new file mode 100644
index 00000000..9275a5e7
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.auto_backend.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.backend.auto\_backend module
+=================================================
+
+.. automodule:: fastNLP.core.metrics.backend.auto_backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.backend.backend.rst b/docs/source/fastNLP.core.metrics.backend.backend.rst
new file mode 100644
index 00000000..df29d243
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.backend.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.backend.backend module
+===========================================
+
+.. automodule:: fastNLP.core.metrics.backend.backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.backend.jittor_backend.backend.rst b/docs/source/fastNLP.core.metrics.backend.jittor_backend.backend.rst
new file mode 100644
index 00000000..d5ad28e5
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.jittor_backend.backend.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.backend.jittor\_backend.backend module
+===========================================================
+
+.. automodule:: fastNLP.core.metrics.backend.jittor_backend.backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.backend.jittor_backend.rst b/docs/source/fastNLP.core.metrics.backend.jittor_backend.rst
new file mode 100644
index 00000000..6ce8b0d4
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.jittor_backend.rst
@@ -0,0 +1,15 @@
+fastNLP.core.metrics.backend.jittor\_backend package
+====================================================
+
+.. automodule:: fastNLP.core.metrics.backend.jittor_backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.metrics.backend.jittor_backend.backend
diff --git a/docs/source/fastNLP.core.metrics.backend.oneflow_backend.backend.rst b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.backend.rst
new file mode 100644
index 00000000..2389250b
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.backend.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.backend.oneflow\_backend.backend module
+============================================================
+
+.. automodule:: fastNLP.core.metrics.backend.oneflow_backend.backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.backend.oneflow_backend.rst b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.rst
new file mode 100644
index 00000000..cb9e9653
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.rst
@@ -0,0 +1,15 @@
+fastNLP.core.metrics.backend.oneflow\_backend package
+=====================================================
+
+.. automodule:: fastNLP.core.metrics.backend.oneflow_backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.metrics.backend.oneflow_backend.backend
diff --git a/docs/source/fastNLP.core.metrics.backend.paddle_backend.backend.rst b/docs/source/fastNLP.core.metrics.backend.paddle_backend.backend.rst
new file mode 100644
index 00000000..fa0aef4d
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.paddle_backend.backend.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.backend.paddle\_backend.backend module
+===========================================================
+
+.. automodule:: fastNLP.core.metrics.backend.paddle_backend.backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.backend.paddle_backend.rst b/docs/source/fastNLP.core.metrics.backend.paddle_backend.rst
new file mode 100644
index 00000000..d932d4e5
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.paddle_backend.rst
@@ -0,0 +1,15 @@
+fastNLP.core.metrics.backend.paddle\_backend package
+====================================================
+
+.. automodule:: fastNLP.core.metrics.backend.paddle_backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.metrics.backend.paddle_backend.backend
diff --git a/docs/source/fastNLP.core.metrics.backend.rst b/docs/source/fastNLP.core.metrics.backend.rst
new file mode 100644
index 00000000..4466a54a
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.rst
@@ -0,0 +1,27 @@
+fastNLP.core.metrics.backend package
+====================================
+
+.. automodule:: fastNLP.core.metrics.backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.metrics.backend.jittor_backend
+ fastNLP.core.metrics.backend.oneflow_backend
+ fastNLP.core.metrics.backend.paddle_backend
+ fastNLP.core.metrics.backend.torch_backend
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.metrics.backend.auto_backend
+ fastNLP.core.metrics.backend.backend
diff --git a/docs/source/fastNLP.core.metrics.backend.torch_backend.backend.rst b/docs/source/fastNLP.core.metrics.backend.torch_backend.backend.rst
new file mode 100644
index 00000000..b7ecd71f
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.torch_backend.backend.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.backend.torch\_backend.backend module
+==========================================================
+
+.. automodule:: fastNLP.core.metrics.backend.torch_backend.backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.backend.torch_backend.rst b/docs/source/fastNLP.core.metrics.backend.torch_backend.rst
new file mode 100644
index 00000000..f01efe88
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.backend.torch_backend.rst
@@ -0,0 +1,15 @@
+fastNLP.core.metrics.backend.torch\_backend package
+===================================================
+
+.. automodule:: fastNLP.core.metrics.backend.torch_backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.metrics.backend.torch_backend.backend
diff --git a/docs/source/fastNLP.core.metrics.classify_f1_pre_rec_metric.rst b/docs/source/fastNLP.core.metrics.classify_f1_pre_rec_metric.rst
new file mode 100644
index 00000000..e0af9e1b
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.classify_f1_pre_rec_metric.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.classify\_f1\_pre\_rec\_metric module
+==========================================================
+
+.. automodule:: fastNLP.core.metrics.classify_f1_pre_rec_metric
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.element.rst b/docs/source/fastNLP.core.metrics.element.rst
new file mode 100644
index 00000000..880fb405
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.element.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.element module
+===================================
+
+.. automodule:: fastNLP.core.metrics.element
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.metric.rst b/docs/source/fastNLP.core.metrics.metric.rst
new file mode 100644
index 00000000..607fb232
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.metric.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.metric module
+==================================
+
+.. automodule:: fastNLP.core.metrics.metric
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.rst b/docs/source/fastNLP.core.metrics.rst
index fe304e78..8ad6f729 100644
--- a/docs/source/fastNLP.core.metrics.rst
+++ b/docs/source/fastNLP.core.metrics.rst
@@ -1,7 +1,28 @@
-fastNLP.core.metrics
-====================
+fastNLP.core.metrics package
+============================
.. automodule:: fastNLP.core.metrics
- :members: MetricBase, AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric
- :inherited-members:
+ :members:
+ :undoc-members:
+ :show-inheritance:
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.metrics.backend
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.metrics.accuracy
+ fastNLP.core.metrics.classify_f1_pre_rec_metric
+ fastNLP.core.metrics.element
+ fastNLP.core.metrics.metric
+ fastNLP.core.metrics.span_f1_pre_rec_metric
+ fastNLP.core.metrics.utils
diff --git a/docs/source/fastNLP.core.metrics.span_f1_pre_rec_metric.rst b/docs/source/fastNLP.core.metrics.span_f1_pre_rec_metric.rst
new file mode 100644
index 00000000..915507ea
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.span_f1_pre_rec_metric.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.span\_f1\_pre\_rec\_metric module
+======================================================
+
+.. automodule:: fastNLP.core.metrics.span_f1_pre_rec_metric
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.metrics.utils.rst b/docs/source/fastNLP.core.metrics.utils.rst
new file mode 100644
index 00000000..8d78e07b
--- /dev/null
+++ b/docs/source/fastNLP.core.metrics.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.metrics.utils module
+=================================
+
+.. automodule:: fastNLP.core.metrics.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.optimizer.rst b/docs/source/fastNLP.core.optimizer.rst
deleted file mode 100644
index 44e45c4f..00000000
--- a/docs/source/fastNLP.core.optimizer.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.core.optimizer
-======================
-
-.. automodule:: fastNLP.core.optimizer
- :members: Optimizer, SGD, Adam, AdamW
- :inherited-members:
-
diff --git a/docs/source/fastNLP.core.rst b/docs/source/fastNLP.core.rst
index 15fe29d5..57dac16a 100644
--- a/docs/source/fastNLP.core.rst
+++ b/docs/source/fastNLP.core.rst
@@ -1,25 +1,32 @@
-fastNLP.core
-============
+fastNLP.core package
+====================
.. automodule:: fastNLP.core
+ :members:
+ :undoc-members:
+ :show-inheritance:
-子模块
-------
+Subpackages
+-----------
.. toctree::
- :maxdepth: 1
+ :maxdepth: 4
- fastNLP.core.batch
- fastNLP.core.callback
- fastNLP.core.const
+ fastNLP.core.callbacks
+ fastNLP.core.collators
+ fastNLP.core.controllers
+ fastNLP.core.dataloaders
fastNLP.core.dataset
- fastNLP.core.field
- fastNLP.core.instance
- fastNLP.core.losses
+ fastNLP.core.drivers
+ fastNLP.core.log
fastNLP.core.metrics
- fastNLP.core.optimizer
- fastNLP.core.sampler
- fastNLP.core.tester
- fastNLP.core.trainer
+ fastNLP.core.samplers
fastNLP.core.utils
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
fastNLP.core.vocabulary
diff --git a/docs/source/fastNLP.core.sampler.rst b/docs/source/fastNLP.core.sampler.rst
deleted file mode 100644
index 56291894..00000000
--- a/docs/source/fastNLP.core.sampler.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.core.sampler
-====================
-
-.. automodule:: fastNLP.core.sampler
- :members: Sampler, BucketSampler, SequentialSampler, RandomSampler
- :inherited-members:
-
diff --git a/docs/source/fastNLP.core.samplers.conversion_utils.rst b/docs/source/fastNLP.core.samplers.conversion_utils.rst
new file mode 100644
index 00000000..855207cc
--- /dev/null
+++ b/docs/source/fastNLP.core.samplers.conversion_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.samplers.conversion\_utils module
+==============================================
+
+.. automodule:: fastNLP.core.samplers.conversion_utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.samplers.mix_sampler.rst b/docs/source/fastNLP.core.samplers.mix_sampler.rst
new file mode 100644
index 00000000..9a33cfdb
--- /dev/null
+++ b/docs/source/fastNLP.core.samplers.mix_sampler.rst
@@ -0,0 +1,7 @@
+fastNLP.core.samplers.mix\_sampler module
+=========================================
+
+.. automodule:: fastNLP.core.samplers.mix_sampler
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.samplers.reproducible_batch_sampler.rst b/docs/source/fastNLP.core.samplers.reproducible_batch_sampler.rst
new file mode 100644
index 00000000..00411ec8
--- /dev/null
+++ b/docs/source/fastNLP.core.samplers.reproducible_batch_sampler.rst
@@ -0,0 +1,7 @@
+fastNLP.core.samplers.reproducible\_batch\_sampler module
+=========================================================
+
+.. automodule:: fastNLP.core.samplers.reproducible_batch_sampler
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.samplers.reproducible_sampler.rst b/docs/source/fastNLP.core.samplers.reproducible_sampler.rst
new file mode 100644
index 00000000..e244e08e
--- /dev/null
+++ b/docs/source/fastNLP.core.samplers.reproducible_sampler.rst
@@ -0,0 +1,7 @@
+fastNLP.core.samplers.reproducible\_sampler module
+==================================================
+
+.. automodule:: fastNLP.core.samplers.reproducible_sampler
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.samplers.rst b/docs/source/fastNLP.core.samplers.rst
new file mode 100644
index 00000000..9ccd9b59
--- /dev/null
+++ b/docs/source/fastNLP.core.samplers.rst
@@ -0,0 +1,20 @@
+fastNLP.core.samplers package
+=============================
+
+.. automodule:: fastNLP.core.samplers
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.samplers.conversion_utils
+ fastNLP.core.samplers.mix_sampler
+ fastNLP.core.samplers.reproducible_batch_sampler
+ fastNLP.core.samplers.reproducible_sampler
+ fastNLP.core.samplers.unrepeated_sampler
+ fastNLP.core.samplers.utils
diff --git a/docs/source/fastNLP.core.samplers.unrepeated_sampler.rst b/docs/source/fastNLP.core.samplers.unrepeated_sampler.rst
new file mode 100644
index 00000000..c76ac5bc
--- /dev/null
+++ b/docs/source/fastNLP.core.samplers.unrepeated_sampler.rst
@@ -0,0 +1,7 @@
+fastNLP.core.samplers.unrepeated\_sampler module
+================================================
+
+.. automodule:: fastNLP.core.samplers.unrepeated_sampler
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.samplers.utils.rst b/docs/source/fastNLP.core.samplers.utils.rst
new file mode 100644
index 00000000..4f36cf04
--- /dev/null
+++ b/docs/source/fastNLP.core.samplers.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.samplers.utils module
+==================================
+
+.. automodule:: fastNLP.core.samplers.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.tester.rst b/docs/source/fastNLP.core.tester.rst
deleted file mode 100644
index 90ec2a88..00000000
--- a/docs/source/fastNLP.core.tester.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.core.tester
-===================
-
-.. automodule:: fastNLP.core.tester
- :members: Tester
- :inherited-members:
-
diff --git a/docs/source/fastNLP.core.trainer.rst b/docs/source/fastNLP.core.trainer.rst
deleted file mode 100644
index 92c08718..00000000
--- a/docs/source/fastNLP.core.trainer.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.core.trainer
-====================
-
-.. automodule:: fastNLP.core.trainer
- :members: Trainer
- :inherited-members:
-
diff --git a/docs/source/fastNLP.core.utils.cache_results.rst b/docs/source/fastNLP.core.utils.cache_results.rst
new file mode 100644
index 00000000..a23a56ee
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.cache_results.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.cache\_results module
+========================================
+
+.. automodule:: fastNLP.core.utils.cache_results
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.utils.dummy_class.rst b/docs/source/fastNLP.core.utils.dummy_class.rst
new file mode 100644
index 00000000..b4ba01cf
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.dummy_class.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.dummy\_class module
+======================================
+
+.. automodule:: fastNLP.core.utils.dummy_class
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.utils.exceptions.rst b/docs/source/fastNLP.core.utils.exceptions.rst
new file mode 100644
index 00000000..a99dc4eb
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.exceptions.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.exceptions module
+====================================
+
+.. automodule:: fastNLP.core.utils.exceptions
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.utils.jittor_utils.rst b/docs/source/fastNLP.core.utils.jittor_utils.rst
new file mode 100644
index 00000000..85241422
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.jittor_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.jittor\_utils module
+=======================================
+
+.. automodule:: fastNLP.core.utils.jittor_utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.utils.oneflow_utils.rst b/docs/source/fastNLP.core.utils.oneflow_utils.rst
new file mode 100644
index 00000000..f9d11510
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.oneflow_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.oneflow\_utils module
+========================================
+
+.. automodule:: fastNLP.core.utils.oneflow_utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.utils.paddle_utils.rst b/docs/source/fastNLP.core.utils.paddle_utils.rst
new file mode 100644
index 00000000..819dc3ca
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.paddle_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.paddle\_utils module
+=======================================
+
+.. automodule:: fastNLP.core.utils.paddle_utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.utils.rich_progress.rst b/docs/source/fastNLP.core.utils.rich_progress.rst
new file mode 100644
index 00000000..f4660381
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.rich_progress.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.rich\_progress module
+========================================
+
+.. automodule:: fastNLP.core.utils.rich_progress
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.utils.rst b/docs/source/fastNLP.core.utils.rst
index 027a43e9..9bf76a23 100644
--- a/docs/source/fastNLP.core.utils.rst
+++ b/docs/source/fastNLP.core.utils.rst
@@ -1,7 +1,25 @@
-fastNLP.core.utils
-==================
+fastNLP.core.utils package
+==========================
.. automodule:: fastNLP.core.utils
- :members: cache_results, seq_len_to_mask, get_seq_len
- :inherited-members:
+ :members:
+ :undoc-members:
+ :show-inheritance:
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.core.utils.cache_results
+ fastNLP.core.utils.dummy_class
+ fastNLP.core.utils.exceptions
+ fastNLP.core.utils.jittor_utils
+ fastNLP.core.utils.oneflow_utils
+ fastNLP.core.utils.paddle_utils
+ fastNLP.core.utils.rich_progress
+ fastNLP.core.utils.seq_len_to_mask
+ fastNLP.core.utils.torch_utils
+ fastNLP.core.utils.tqdm_progress
+ fastNLP.core.utils.utils
diff --git a/docs/source/fastNLP.core.utils.seq_len_to_mask.rst b/docs/source/fastNLP.core.utils.seq_len_to_mask.rst
new file mode 100644
index 00000000..55188a65
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.seq_len_to_mask.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.seq\_len\_to\_mask module
+============================================
+
+.. automodule:: fastNLP.core.utils.seq_len_to_mask
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.utils.torch_utils.rst b/docs/source/fastNLP.core.utils.torch_utils.rst
new file mode 100644
index 00000000..f09e3882
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.torch_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.torch\_utils module
+======================================
+
+.. automodule:: fastNLP.core.utils.torch_utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.utils.tqdm_progress.rst b/docs/source/fastNLP.core.utils.tqdm_progress.rst
new file mode 100644
index 00000000..cfcdc655
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.tqdm_progress.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.tqdm\_progress module
+========================================
+
+.. automodule:: fastNLP.core.utils.tqdm_progress
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.utils.utils.rst b/docs/source/fastNLP.core.utils.utils.rst
new file mode 100644
index 00000000..6ca577d6
--- /dev/null
+++ b/docs/source/fastNLP.core.utils.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.core.utils.utils module
+===============================
+
+.. automodule:: fastNLP.core.utils.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.core.vocabulary.rst b/docs/source/fastNLP.core.vocabulary.rst
index ac07a8c6..a4339b25 100644
--- a/docs/source/fastNLP.core.vocabulary.rst
+++ b/docs/source/fastNLP.core.vocabulary.rst
@@ -1,7 +1,7 @@
-fastNLP.core.vocabulary
-=======================
+fastNLP.core.vocabulary module
+==============================
.. automodule:: fastNLP.core.vocabulary
- :members: Vocabulary, VocabularyOption
- :inherited-members:
-
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.embeddings.bert_embedding.rst b/docs/source/fastNLP.embeddings.bert_embedding.rst
deleted file mode 100644
index 1b59dc35..00000000
--- a/docs/source/fastNLP.embeddings.bert_embedding.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.embeddings.bert_embedding
-=================================
-
-.. automodule:: fastNLP.embeddings.bert_embedding
- :members: BertEmbedding, BertWordPieceEncoder
-
diff --git a/docs/source/fastNLP.embeddings.char_embedding.rst b/docs/source/fastNLP.embeddings.char_embedding.rst
deleted file mode 100644
index bc8d64f9..00000000
--- a/docs/source/fastNLP.embeddings.char_embedding.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.embeddings.char_embedding
-=================================
-
-.. automodule:: fastNLP.embeddings.char_embedding
- :members: CNNCharEmbedding, LSTMCharEmbedding
-
diff --git a/docs/source/fastNLP.embeddings.contextual_embedding.rst b/docs/source/fastNLP.embeddings.contextual_embedding.rst
deleted file mode 100644
index 74e5f5be..00000000
--- a/docs/source/fastNLP.embeddings.contextual_embedding.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.embeddings.contextual_embedding
-=======================================
-
-.. automodule:: fastNLP.embeddings.contextual_embedding
- :members: ContextualEmbedding
-
diff --git a/docs/source/fastNLP.embeddings.elmo_embedding.rst b/docs/source/fastNLP.embeddings.elmo_embedding.rst
deleted file mode 100644
index b8c6d41c..00000000
--- a/docs/source/fastNLP.embeddings.elmo_embedding.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.embeddings.elmo_embedding
-=================================
-
-.. automodule:: fastNLP.embeddings.elmo_embedding
- :members: ElmoEmbedding
-
diff --git a/docs/source/fastNLP.embeddings.embedding.rst b/docs/source/fastNLP.embeddings.embedding.rst
deleted file mode 100644
index 6793446b..00000000
--- a/docs/source/fastNLP.embeddings.embedding.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.embeddings.embedding
-============================
-
-.. automodule:: fastNLP.embeddings.embedding
- :members: Embedding, TokenEmbedding
-
diff --git a/docs/source/fastNLP.embeddings.rst b/docs/source/fastNLP.embeddings.rst
index f4f4a3e0..1b220f59 100644
--- a/docs/source/fastNLP.embeddings.rst
+++ b/docs/source/fastNLP.embeddings.rst
@@ -1,20 +1,15 @@
-fastNLP.embeddings
-==================
+fastNLP.embeddings package
+==========================
.. automodule:: fastNLP.embeddings
- :members: Embedding, TokenEmbedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, BertWordPieceEncoder, StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding, get_embeddings
+ :members:
+ :undoc-members:
+ :show-inheritance:
-子模块
-------
+Subpackages
+-----------
.. toctree::
- :maxdepth: 1
+ :maxdepth: 4
- fastNLP.embeddings.bert_embedding
- fastNLP.embeddings.char_embedding
- fastNLP.embeddings.contextual_embedding
- fastNLP.embeddings.elmo_embedding
- fastNLP.embeddings.embedding
- fastNLP.embeddings.stack_embedding
- fastNLP.embeddings.static_embedding
- fastNLP.embeddings.utils
+ fastNLP.embeddings.torch
diff --git a/docs/source/fastNLP.embeddings.stack_embedding.rst b/docs/source/fastNLP.embeddings.stack_embedding.rst
deleted file mode 100644
index a07d1ef5..00000000
--- a/docs/source/fastNLP.embeddings.stack_embedding.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.embeddings.stack_embedding
-==================================
-
-.. automodule:: fastNLP.embeddings.stack_embedding
- :members: StackEmbedding
-
diff --git a/docs/source/fastNLP.embeddings.static_embedding.rst b/docs/source/fastNLP.embeddings.static_embedding.rst
deleted file mode 100644
index 219ce0e5..00000000
--- a/docs/source/fastNLP.embeddings.static_embedding.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.embeddings.static_embedding
-===================================
-
-.. automodule:: fastNLP.embeddings.static_embedding
- :members: StaticEmbedding
-
diff --git a/docs/source/fastNLP.embeddings.torch.char_embedding.rst b/docs/source/fastNLP.embeddings.torch.char_embedding.rst
new file mode 100644
index 00000000..f0d1dad7
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.torch.char_embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.torch.char\_embedding module
+===============================================
+
+.. automodule:: fastNLP.embeddings.torch.char_embedding
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.embeddings.torch.embedding.rst b/docs/source/fastNLP.embeddings.torch.embedding.rst
new file mode 100644
index 00000000..1804a70e
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.torch.embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.torch.embedding module
+=========================================
+
+.. automodule:: fastNLP.embeddings.torch.embedding
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.embeddings.torch.rst b/docs/source/fastNLP.embeddings.torch.rst
new file mode 100644
index 00000000..6294e8a2
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.torch.rst
@@ -0,0 +1,19 @@
+fastNLP.embeddings.torch package
+================================
+
+.. automodule:: fastNLP.embeddings.torch
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.embeddings.torch.char_embedding
+ fastNLP.embeddings.torch.embedding
+ fastNLP.embeddings.torch.stack_embedding
+ fastNLP.embeddings.torch.static_embedding
+ fastNLP.embeddings.torch.utils
diff --git a/docs/source/fastNLP.embeddings.torch.stack_embedding.rst b/docs/source/fastNLP.embeddings.torch.stack_embedding.rst
new file mode 100644
index 00000000..dab50088
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.torch.stack_embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.torch.stack\_embedding module
+================================================
+
+.. automodule:: fastNLP.embeddings.torch.stack_embedding
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.embeddings.torch.static_embedding.rst b/docs/source/fastNLP.embeddings.torch.static_embedding.rst
new file mode 100644
index 00000000..fc1a2bb9
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.torch.static_embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.torch.static\_embedding module
+=================================================
+
+.. automodule:: fastNLP.embeddings.torch.static_embedding
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.embeddings.torch.utils.rst b/docs/source/fastNLP.embeddings.torch.utils.rst
new file mode 100644
index 00000000..9d1fc5b5
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.torch.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.torch.utils module
+=====================================
+
+.. automodule:: fastNLP.embeddings.torch.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.embeddings.utils.rst b/docs/source/fastNLP.embeddings.utils.rst
deleted file mode 100644
index 077487c1..00000000
--- a/docs/source/fastNLP.embeddings.utils.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.embeddings.utils
-========================
-
-.. automodule:: fastNLP.embeddings.utils
- :members: get_embeddings
-
diff --git a/docs/source/fastNLP.envs.distributed.rst b/docs/source/fastNLP.envs.distributed.rst
new file mode 100644
index 00000000..bb14f01d
--- /dev/null
+++ b/docs/source/fastNLP.envs.distributed.rst
@@ -0,0 +1,7 @@
+fastNLP.envs.distributed module
+===============================
+
+.. automodule:: fastNLP.envs.distributed
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.envs.env.rst b/docs/source/fastNLP.envs.env.rst
new file mode 100644
index 00000000..8df50d92
--- /dev/null
+++ b/docs/source/fastNLP.envs.env.rst
@@ -0,0 +1,7 @@
+fastNLP.envs.env module
+=======================
+
+.. automodule:: fastNLP.envs.env
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.envs.imports.rst b/docs/source/fastNLP.envs.imports.rst
new file mode 100644
index 00000000..eaf8465d
--- /dev/null
+++ b/docs/source/fastNLP.envs.imports.rst
@@ -0,0 +1,7 @@
+fastNLP.envs.imports module
+===========================
+
+.. automodule:: fastNLP.envs.imports
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.envs.rst b/docs/source/fastNLP.envs.rst
new file mode 100644
index 00000000..2e642ff7
--- /dev/null
+++ b/docs/source/fastNLP.envs.rst
@@ -0,0 +1,20 @@
+fastNLP.envs package
+====================
+
+.. automodule:: fastNLP.envs
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.envs.distributed
+ fastNLP.envs.env
+ fastNLP.envs.imports
+ fastNLP.envs.set_backend
+ fastNLP.envs.set_env_on_import
+ fastNLP.envs.utils
diff --git a/docs/source/fastNLP.envs.set_backend.rst b/docs/source/fastNLP.envs.set_backend.rst
new file mode 100644
index 00000000..b0d7e3b3
--- /dev/null
+++ b/docs/source/fastNLP.envs.set_backend.rst
@@ -0,0 +1,7 @@
+fastNLP.envs.set\_backend module
+================================
+
+.. automodule:: fastNLP.envs.set_backend
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.envs.set_env_on_import.rst b/docs/source/fastNLP.envs.set_env_on_import.rst
new file mode 100644
index 00000000..60f5cbac
--- /dev/null
+++ b/docs/source/fastNLP.envs.set_env_on_import.rst
@@ -0,0 +1,7 @@
+fastNLP.envs.set\_env\_on\_import module
+========================================
+
+.. automodule:: fastNLP.envs.set_env_on_import
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.envs.utils.rst b/docs/source/fastNLP.envs.utils.rst
new file mode 100644
index 00000000..fd0256a6
--- /dev/null
+++ b/docs/source/fastNLP.envs.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.envs.utils module
+=========================
+
+.. automodule:: fastNLP.envs.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.data_bundle.rst b/docs/source/fastNLP.io.data_bundle.rst
index 71a921f1..05fcd9c7 100644
--- a/docs/source/fastNLP.io.data_bundle.rst
+++ b/docs/source/fastNLP.io.data_bundle.rst
@@ -1,7 +1,7 @@
-fastNLP.io.data_bundle
-======================
+fastNLP.io.data\_bundle module
+==============================
.. automodule:: fastNLP.io.data_bundle
- :members: DataBundle
- :inherited-members:
-
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.embed_loader.rst b/docs/source/fastNLP.io.embed_loader.rst
index 581f5c1b..acd67dfc 100644
--- a/docs/source/fastNLP.io.embed_loader.rst
+++ b/docs/source/fastNLP.io.embed_loader.rst
@@ -1,7 +1,7 @@
-fastNLP.io.embed_loader
-=======================
+fastNLP.io.embed\_loader module
+===============================
.. automodule:: fastNLP.io.embed_loader
- :members: EmbedLoader, EmbeddingOption
- :inherited-members:
-
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.file_reader.rst b/docs/source/fastNLP.io.file_reader.rst
new file mode 100644
index 00000000..4c2f8928
--- /dev/null
+++ b/docs/source/fastNLP.io.file_reader.rst
@@ -0,0 +1,7 @@
+fastNLP.io.file\_reader module
+==============================
+
+.. automodule:: fastNLP.io.file_reader
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.file_utils.rst b/docs/source/fastNLP.io.file_utils.rst
index 0815e068..f19d6e2e 100644
--- a/docs/source/fastNLP.io.file_utils.rst
+++ b/docs/source/fastNLP.io.file_utils.rst
@@ -1,7 +1,7 @@
-fastNLP.io.file_utils
-=====================
+fastNLP.io.file\_utils module
+=============================
.. automodule:: fastNLP.io.file_utils
- :members: cached_path, get_filepath, get_cache_path, split_filename_suffix, get_from_cache
- :inherited-members:
-
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.loader.classification.rst b/docs/source/fastNLP.io.loader.classification.rst
new file mode 100644
index 00000000..6bfd46c7
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.classification.rst
@@ -0,0 +1,7 @@
+fastNLP.io.loader.classification module
+=======================================
+
+.. automodule:: fastNLP.io.loader.classification
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.loader.conll.rst b/docs/source/fastNLP.io.loader.conll.rst
new file mode 100644
index 00000000..96123649
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.conll.rst
@@ -0,0 +1,7 @@
+fastNLP.io.loader.conll module
+==============================
+
+.. automodule:: fastNLP.io.loader.conll
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.loader.csv.rst b/docs/source/fastNLP.io.loader.csv.rst
new file mode 100644
index 00000000..f84f5b18
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.csv.rst
@@ -0,0 +1,7 @@
+fastNLP.io.loader.csv module
+============================
+
+.. automodule:: fastNLP.io.loader.csv
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.loader.cws.rst b/docs/source/fastNLP.io.loader.cws.rst
new file mode 100644
index 00000000..a42ff1f5
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.cws.rst
@@ -0,0 +1,7 @@
+fastNLP.io.loader.cws module
+============================
+
+.. automodule:: fastNLP.io.loader.cws
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.loader.json.rst b/docs/source/fastNLP.io.loader.json.rst
new file mode 100644
index 00000000..53f28586
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.json.rst
@@ -0,0 +1,7 @@
+fastNLP.io.loader.json module
+=============================
+
+.. automodule:: fastNLP.io.loader.json
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.loader.loader.rst b/docs/source/fastNLP.io.loader.loader.rst
new file mode 100644
index 00000000..4c437624
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.loader.rst
@@ -0,0 +1,7 @@
+fastNLP.io.loader.loader module
+===============================
+
+.. automodule:: fastNLP.io.loader.loader
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.loader.matching.rst b/docs/source/fastNLP.io.loader.matching.rst
new file mode 100644
index 00000000..5faa91cc
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.matching.rst
@@ -0,0 +1,7 @@
+fastNLP.io.loader.matching module
+=================================
+
+.. automodule:: fastNLP.io.loader.matching
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.loader.qa.rst b/docs/source/fastNLP.io.loader.qa.rst
new file mode 100644
index 00000000..e3684853
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.qa.rst
@@ -0,0 +1,7 @@
+fastNLP.io.loader.qa module
+===========================
+
+.. automodule:: fastNLP.io.loader.qa
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.loader.rst b/docs/source/fastNLP.io.loader.rst
index c6d0dc55..20be532a 100644
--- a/docs/source/fastNLP.io.loader.rst
+++ b/docs/source/fastNLP.io.loader.rst
@@ -1,7 +1,23 @@
-fastNLP.io.loader
-=================
+fastNLP.io.loader package
+=========================
.. automodule:: fastNLP.io.loader
- :members: Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, LCQMCLoader, CoReferenceLoader
- :inherited-members:
+ :members:
+ :undoc-members:
+ :show-inheritance:
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.io.loader.classification
+ fastNLP.io.loader.conll
+ fastNLP.io.loader.csv
+ fastNLP.io.loader.cws
+ fastNLP.io.loader.json
+ fastNLP.io.loader.loader
+ fastNLP.io.loader.matching
+ fastNLP.io.loader.qa
+ fastNLP.io.loader.summarization
diff --git a/docs/source/fastNLP.io.loader.summarization.rst b/docs/source/fastNLP.io.loader.summarization.rst
new file mode 100644
index 00000000..10338f94
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.summarization.rst
@@ -0,0 +1,7 @@
+fastNLP.io.loader.summarization module
+======================================
+
+.. automodule:: fastNLP.io.loader.summarization
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.model_io.rst b/docs/source/fastNLP.io.model_io.rst
deleted file mode 100644
index 183122b1..00000000
--- a/docs/source/fastNLP.io.model_io.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.io.model_io
-===================
-
-.. automodule:: fastNLP.io.model_io
- :members: ModelLoader, ModelSaver
- :inherited-members:
-
diff --git a/docs/source/fastNLP.io.pipe.classification.rst b/docs/source/fastNLP.io.pipe.classification.rst
new file mode 100644
index 00000000..4428e698
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.classification.rst
@@ -0,0 +1,7 @@
+fastNLP.io.pipe.classification module
+=====================================
+
+.. automodule:: fastNLP.io.pipe.classification
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.pipe.conll.rst b/docs/source/fastNLP.io.pipe.conll.rst
new file mode 100644
index 00000000..0d481811
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.conll.rst
@@ -0,0 +1,7 @@
+fastNLP.io.pipe.conll module
+============================
+
+.. automodule:: fastNLP.io.pipe.conll
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.pipe.construct_graph.rst b/docs/source/fastNLP.io.pipe.construct_graph.rst
new file mode 100644
index 00000000..4a0c0726
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.construct_graph.rst
@@ -0,0 +1,7 @@
+fastNLP.io.pipe.construct\_graph module
+=======================================
+
+.. automodule:: fastNLP.io.pipe.construct_graph
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.pipe.cws.rst b/docs/source/fastNLP.io.pipe.cws.rst
new file mode 100644
index 00000000..46990499
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.cws.rst
@@ -0,0 +1,7 @@
+fastNLP.io.pipe.cws module
+==========================
+
+.. automodule:: fastNLP.io.pipe.cws
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.pipe.matching.rst b/docs/source/fastNLP.io.pipe.matching.rst
new file mode 100644
index 00000000..08f98a78
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.matching.rst
@@ -0,0 +1,7 @@
+fastNLP.io.pipe.matching module
+===============================
+
+.. automodule:: fastNLP.io.pipe.matching
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.pipe.pipe.rst b/docs/source/fastNLP.io.pipe.pipe.rst
new file mode 100644
index 00000000..b39f56a4
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.pipe.rst
@@ -0,0 +1,7 @@
+fastNLP.io.pipe.pipe module
+===========================
+
+.. automodule:: fastNLP.io.pipe.pipe
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.pipe.qa.rst b/docs/source/fastNLP.io.pipe.qa.rst
new file mode 100644
index 00000000..38cd6476
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.qa.rst
@@ -0,0 +1,7 @@
+fastNLP.io.pipe.qa module
+=========================
+
+.. automodule:: fastNLP.io.pipe.qa
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.pipe.rst b/docs/source/fastNLP.io.pipe.rst
index 178d35a9..53a62918 100644
--- a/docs/source/fastNLP.io.pipe.rst
+++ b/docs/source/fastNLP.io.pipe.rst
@@ -1,7 +1,23 @@
-fastNLP.io.pipe
-===============
+fastNLP.io.pipe package
+=======================
.. automodule:: fastNLP.io.pipe
- :members: Pipe, CWSPipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe, Conll2003Pipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, CNXNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, LCQMCPipe, CNXNLIPipe, BQCorpusPipe, RenamePipe, GranularizePipe, MachingTruncatePipe, CoReferencePipe
- :inherited-members:
+ :members:
+ :undoc-members:
+ :show-inheritance:
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.io.pipe.classification
+ fastNLP.io.pipe.conll
+ fastNLP.io.pipe.construct_graph
+ fastNLP.io.pipe.cws
+ fastNLP.io.pipe.matching
+ fastNLP.io.pipe.pipe
+ fastNLP.io.pipe.qa
+ fastNLP.io.pipe.summarization
+ fastNLP.io.pipe.utils
diff --git a/docs/source/fastNLP.io.pipe.summarization.rst b/docs/source/fastNLP.io.pipe.summarization.rst
new file mode 100644
index 00000000..5163da8d
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.summarization.rst
@@ -0,0 +1,7 @@
+fastNLP.io.pipe.summarization module
+====================================
+
+.. automodule:: fastNLP.io.pipe.summarization
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.pipe.utils.rst b/docs/source/fastNLP.io.pipe.utils.rst
new file mode 100644
index 00000000..be5db9ab
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.io.pipe.utils module
+============================
+
+.. automodule:: fastNLP.io.pipe.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst
index 54373df4..7e1a5a67 100644
--- a/docs/source/fastNLP.io.rst
+++ b/docs/source/fastNLP.io.rst
@@ -1,20 +1,28 @@
-fastNLP.io
-==========
+fastNLP.io package
+==================
.. automodule:: fastNLP.io
- :members: DataBundle, EmbedLoader, Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, WeiboNERLoader, PeopleDailyNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, LCQMCLoader, Pipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe, Conll2003Pipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, CWSPipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, ModelLoader, ModelSaver
- :inherited-members:
+ :members:
+ :undoc-members:
+ :show-inheritance:
-子模块
-------
+Subpackages
+-----------
.. toctree::
- :maxdepth: 1
+ :maxdepth: 4
+
+ fastNLP.io.loader
+ fastNLP.io.pipe
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
fastNLP.io.data_bundle
fastNLP.io.embed_loader
+ fastNLP.io.file_reader
fastNLP.io.file_utils
- fastNLP.io.loader
- fastNLP.io.model_io
- fastNLP.io.pipe
fastNLP.io.utils
diff --git a/docs/source/fastNLP.io.utils.rst b/docs/source/fastNLP.io.utils.rst
index 3bff3c45..b19d8427 100644
--- a/docs/source/fastNLP.io.utils.rst
+++ b/docs/source/fastNLP.io.utils.rst
@@ -1,7 +1,7 @@
-fastNLP.io.utils
-================
+fastNLP.io.utils module
+=======================
.. automodule:: fastNLP.io.utils
- :members: check_loader_paths
- :inherited-members:
-
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.models.bert.rst b/docs/source/fastNLP.models.bert.rst
deleted file mode 100644
index b0c813f9..00000000
--- a/docs/source/fastNLP.models.bert.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.models.bert
-===================
-
-.. automodule:: fastNLP.models.bert
- :members: BertForSequenceClassification, BertForSentenceMatching, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering
-
diff --git a/docs/source/fastNLP.models.biaffine_parser.rst b/docs/source/fastNLP.models.biaffine_parser.rst
deleted file mode 100644
index 395638fe..00000000
--- a/docs/source/fastNLP.models.biaffine_parser.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.models.biaffine_parser
-==============================
-
-.. automodule:: fastNLP.models.biaffine_parser
- :members: BiaffineParser, GraphParser
-
diff --git a/docs/source/fastNLP.models.cnn_text_classification.rst b/docs/source/fastNLP.models.cnn_text_classification.rst
deleted file mode 100644
index e9ed7ee1..00000000
--- a/docs/source/fastNLP.models.cnn_text_classification.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.models.cnn_text_classification
-======================================
-
-.. automodule:: fastNLP.models.cnn_text_classification
- :members: CNNText
-
diff --git a/docs/source/fastNLP.models.rst b/docs/source/fastNLP.models.rst
index 21cf41a7..eaef5a5e 100644
--- a/docs/source/fastNLP.models.rst
+++ b/docs/source/fastNLP.models.rst
@@ -1,18 +1,15 @@
-fastNLP.models
-==============
+fastNLP.models package
+======================
.. automodule:: fastNLP.models
- :members: CNNText, SeqLabeling, AdvSeqLabel, ESIM, StarTransEnc, STSeqLabel, STNLICls, STSeqCls, BiaffineParser, GraphParser, BertForSequenceClassification, BertForSentenceMatching, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering
+ :members:
+ :undoc-members:
+ :show-inheritance:
-子模块
-------
+Subpackages
+-----------
.. toctree::
- :maxdepth: 1
+ :maxdepth: 4
- fastNLP.models.bert
- fastNLP.models.biaffine_parser
- fastNLP.models.cnn_text_classification
- fastNLP.models.sequence_labeling
- fastNLP.models.snli
- fastNLP.models.star_transformer
+ fastNLP.models.torch
diff --git a/docs/source/fastNLP.models.sequence_labeling.rst b/docs/source/fastNLP.models.sequence_labeling.rst
deleted file mode 100644
index dcd1300e..00000000
--- a/docs/source/fastNLP.models.sequence_labeling.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.models.sequence_labeling
-================================
-
-.. automodule:: fastNLP.models.sequence_labeling
- :members: SeqLabeling, AdvSeqLabel, BiLSTMCRF
-
diff --git a/docs/source/fastNLP.models.snli.rst b/docs/source/fastNLP.models.snli.rst
deleted file mode 100644
index eed02139..00000000
--- a/docs/source/fastNLP.models.snli.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.models.snli
-===================
-
-.. automodule:: fastNLP.models.snli
- :members: ESIM
-
diff --git a/docs/source/fastNLP.models.star_transformer.rst b/docs/source/fastNLP.models.star_transformer.rst
deleted file mode 100644
index 80ab5b33..00000000
--- a/docs/source/fastNLP.models.star_transformer.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.models.star_transformer
-===============================
-
-.. automodule:: fastNLP.models.star_transformer
- :members: StarTransEnc, STNLICls, STSeqCls, STSeqLabel
-
diff --git a/docs/source/fastNLP.models.torch.biaffine_parser.rst b/docs/source/fastNLP.models.torch.biaffine_parser.rst
new file mode 100644
index 00000000..c75d7079
--- /dev/null
+++ b/docs/source/fastNLP.models.torch.biaffine_parser.rst
@@ -0,0 +1,7 @@
+fastNLP.models.torch.biaffine\_parser module
+============================================
+
+.. automodule:: fastNLP.models.torch.biaffine_parser
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.models.torch.cnn_text_classification.rst b/docs/source/fastNLP.models.torch.cnn_text_classification.rst
new file mode 100644
index 00000000..a0b4e1bd
--- /dev/null
+++ b/docs/source/fastNLP.models.torch.cnn_text_classification.rst
@@ -0,0 +1,7 @@
+fastNLP.models.torch.cnn\_text\_classification module
+=====================================================
+
+.. automodule:: fastNLP.models.torch.cnn_text_classification
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.models.torch.rst b/docs/source/fastNLP.models.torch.rst
new file mode 100644
index 00000000..7196f3f7
--- /dev/null
+++ b/docs/source/fastNLP.models.torch.rst
@@ -0,0 +1,19 @@
+fastNLP.models.torch package
+============================
+
+.. automodule:: fastNLP.models.torch
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.models.torch.biaffine_parser
+ fastNLP.models.torch.cnn_text_classification
+ fastNLP.models.torch.seq2seq_generator
+ fastNLP.models.torch.seq2seq_model
+ fastNLP.models.torch.sequence_labeling
diff --git a/docs/source/fastNLP.models.torch.seq2seq_generator.rst b/docs/source/fastNLP.models.torch.seq2seq_generator.rst
new file mode 100644
index 00000000..bc1e4ca0
--- /dev/null
+++ b/docs/source/fastNLP.models.torch.seq2seq_generator.rst
@@ -0,0 +1,7 @@
+fastNLP.models.torch.seq2seq\_generator module
+==============================================
+
+.. automodule:: fastNLP.models.torch.seq2seq_generator
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.models.torch.seq2seq_model.rst b/docs/source/fastNLP.models.torch.seq2seq_model.rst
new file mode 100644
index 00000000..802b8793
--- /dev/null
+++ b/docs/source/fastNLP.models.torch.seq2seq_model.rst
@@ -0,0 +1,7 @@
+fastNLP.models.torch.seq2seq\_model module
+==========================================
+
+.. automodule:: fastNLP.models.torch.seq2seq_model
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.models.torch.sequence_labeling.rst b/docs/source/fastNLP.models.torch.sequence_labeling.rst
new file mode 100644
index 00000000..af834f53
--- /dev/null
+++ b/docs/source/fastNLP.models.torch.sequence_labeling.rst
@@ -0,0 +1,7 @@
+fastNLP.models.torch.sequence\_labeling module
+==============================================
+
+.. automodule:: fastNLP.models.torch.sequence_labeling
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.decoder.rst b/docs/source/fastNLP.modules.decoder.rst
deleted file mode 100644
index de6e0d9d..00000000
--- a/docs/source/fastNLP.modules.decoder.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.modules.decoder
-=======================
-
-.. automodule:: fastNLP.modules.decoder
- :members: MLP, ConditionalRandomField, viterbi_decode, allowed_transitions
-
diff --git a/docs/source/fastNLP.modules.encoder.rst b/docs/source/fastNLP.modules.encoder.rst
deleted file mode 100644
index a402cb67..00000000
--- a/docs/source/fastNLP.modules.encoder.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.modules.encoder
-=======================
-
-.. automodule:: fastNLP.modules.encoder
- :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask, MultiHeadAttention, BiAttention, SelfAttention
-
diff --git a/docs/source/fastNLP.modules.mix_modules.rst b/docs/source/fastNLP.modules.mix_modules.rst
new file mode 100644
index 00000000..5351c55a
--- /dev/null
+++ b/docs/source/fastNLP.modules.mix_modules.rst
@@ -0,0 +1,15 @@
+fastNLP.modules.mix\_modules package
+====================================
+
+.. automodule:: fastNLP.modules.mix_modules
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.modules.mix_modules.utils
diff --git a/docs/source/fastNLP.modules.mix_modules.utils.rst b/docs/source/fastNLP.modules.mix_modules.utils.rst
new file mode 100644
index 00000000..9dab336d
--- /dev/null
+++ b/docs/source/fastNLP.modules.mix_modules.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.mix\_modules.utils module
+=========================================
+
+.. automodule:: fastNLP.modules.mix_modules.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst
index 9c44e461..b686105d 100644
--- a/docs/source/fastNLP.modules.rst
+++ b/docs/source/fastNLP.modules.rst
@@ -1,15 +1,16 @@
-fastNLP.modules
-===============
+fastNLP.modules package
+=======================
.. automodule:: fastNLP.modules
- :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout
+ :members:
+ :undoc-members:
+ :show-inheritance:
-子模块
-------
+Subpackages
+-----------
.. toctree::
- :maxdepth: 1
+ :maxdepth: 4
- fastNLP.modules.decoder
- fastNLP.modules.encoder
- fastNLP.modules.utils
+ fastNLP.modules.mix_modules
+ fastNLP.modules.torch
diff --git a/docs/source/fastNLP.modules.torch.attention.rst b/docs/source/fastNLP.modules.torch.attention.rst
new file mode 100644
index 00000000..52b7bf8c
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.attention.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.attention module
+======================================
+
+.. automodule:: fastNLP.modules.torch.attention
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.decoder.crf.rst b/docs/source/fastNLP.modules.torch.decoder.crf.rst
new file mode 100644
index 00000000..2d9e3460
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.decoder.crf.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.decoder.crf module
+========================================
+
+.. automodule:: fastNLP.modules.torch.decoder.crf
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.decoder.mlp.rst b/docs/source/fastNLP.modules.torch.decoder.mlp.rst
new file mode 100644
index 00000000..6bb9cc5c
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.decoder.mlp.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.decoder.mlp module
+========================================
+
+.. automodule:: fastNLP.modules.torch.decoder.mlp
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.decoder.rst b/docs/source/fastNLP.modules.torch.decoder.rst
new file mode 100644
index 00000000..999ab01d
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.decoder.rst
@@ -0,0 +1,18 @@
+fastNLP.modules.torch.decoder package
+=====================================
+
+.. automodule:: fastNLP.modules.torch.decoder
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.modules.torch.decoder.crf
+ fastNLP.modules.torch.decoder.mlp
+ fastNLP.modules.torch.decoder.seq2seq_decoder
+ fastNLP.modules.torch.decoder.seq2seq_state
diff --git a/docs/source/fastNLP.modules.torch.decoder.seq2seq_decoder.rst b/docs/source/fastNLP.modules.torch.decoder.seq2seq_decoder.rst
new file mode 100644
index 00000000..43c77fea
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.decoder.seq2seq_decoder.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.decoder.seq2seq\_decoder module
+=====================================================
+
+.. automodule:: fastNLP.modules.torch.decoder.seq2seq_decoder
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.decoder.seq2seq_state.rst b/docs/source/fastNLP.modules.torch.decoder.seq2seq_state.rst
new file mode 100644
index 00000000..05f730e4
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.decoder.seq2seq_state.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.decoder.seq2seq\_state module
+===================================================
+
+.. automodule:: fastNLP.modules.torch.decoder.seq2seq_state
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.dropout.rst b/docs/source/fastNLP.modules.torch.dropout.rst
new file mode 100644
index 00000000..8e4b591b
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.dropout.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.dropout module
+====================================
+
+.. automodule:: fastNLP.modules.torch.dropout
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.encoder.conv_maxpool.rst b/docs/source/fastNLP.modules.torch.encoder.conv_maxpool.rst
new file mode 100644
index 00000000..438ec076
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.encoder.conv_maxpool.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.encoder.conv\_maxpool module
+==================================================
+
+.. automodule:: fastNLP.modules.torch.encoder.conv_maxpool
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.encoder.lstm.rst b/docs/source/fastNLP.modules.torch.encoder.lstm.rst
new file mode 100644
index 00000000..918e13cb
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.encoder.lstm.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.encoder.lstm module
+=========================================
+
+.. automodule:: fastNLP.modules.torch.encoder.lstm
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.encoder.rst b/docs/source/fastNLP.modules.torch.encoder.rst
new file mode 100644
index 00000000..14120ed1
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.encoder.rst
@@ -0,0 +1,20 @@
+fastNLP.modules.torch.encoder package
+=====================================
+
+.. automodule:: fastNLP.modules.torch.encoder
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.modules.torch.encoder.conv_maxpool
+ fastNLP.modules.torch.encoder.lstm
+ fastNLP.modules.torch.encoder.seq2seq_encoder
+ fastNLP.modules.torch.encoder.star_transformer
+ fastNLP.modules.torch.encoder.transformer
+ fastNLP.modules.torch.encoder.variational_rnn
diff --git a/docs/source/fastNLP.modules.torch.encoder.seq2seq_encoder.rst b/docs/source/fastNLP.modules.torch.encoder.seq2seq_encoder.rst
new file mode 100644
index 00000000..152fc091
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.encoder.seq2seq_encoder.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.encoder.seq2seq\_encoder module
+=====================================================
+
+.. automodule:: fastNLP.modules.torch.encoder.seq2seq_encoder
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.encoder.star_transformer.rst b/docs/source/fastNLP.modules.torch.encoder.star_transformer.rst
new file mode 100644
index 00000000..3257cf13
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.encoder.star_transformer.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.encoder.star\_transformer module
+======================================================
+
+.. automodule:: fastNLP.modules.torch.encoder.star_transformer
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.encoder.transformer.rst b/docs/source/fastNLP.modules.torch.encoder.transformer.rst
new file mode 100644
index 00000000..0a3c893f
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.encoder.transformer.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.encoder.transformer module
+================================================
+
+.. automodule:: fastNLP.modules.torch.encoder.transformer
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.encoder.variational_rnn.rst b/docs/source/fastNLP.modules.torch.encoder.variational_rnn.rst
new file mode 100644
index 00000000..71a70c3a
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.encoder.variational_rnn.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.encoder.variational\_rnn module
+=====================================================
+
+.. automodule:: fastNLP.modules.torch.encoder.variational_rnn
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.generator.rst b/docs/source/fastNLP.modules.torch.generator.rst
new file mode 100644
index 00000000..783db61d
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.generator.rst
@@ -0,0 +1,15 @@
+fastNLP.modules.torch.generator package
+=======================================
+
+.. automodule:: fastNLP.modules.torch.generator
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.modules.torch.generator.seq2seq_generator
diff --git a/docs/source/fastNLP.modules.torch.generator.seq2seq_generator.rst b/docs/source/fastNLP.modules.torch.generator.seq2seq_generator.rst
new file mode 100644
index 00000000..4abc102f
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.generator.seq2seq_generator.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.torch.generator.seq2seq\_generator module
+=========================================================
+
+.. automodule:: fastNLP.modules.torch.generator.seq2seq_generator
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.modules.torch.rst b/docs/source/fastNLP.modules.torch.rst
new file mode 100644
index 00000000..8e1fb0f5
--- /dev/null
+++ b/docs/source/fastNLP.modules.torch.rst
@@ -0,0 +1,26 @@
+fastNLP.modules.torch package
+=============================
+
+.. automodule:: fastNLP.modules.torch
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.modules.torch.decoder
+ fastNLP.modules.torch.encoder
+ fastNLP.modules.torch.generator
+
+Submodules
+----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.modules.torch.attention
+ fastNLP.modules.torch.dropout
diff --git a/docs/source/fastNLP.modules.utils.rst b/docs/source/fastNLP.modules.utils.rst
deleted file mode 100644
index 101a0f45..00000000
--- a/docs/source/fastNLP.modules.utils.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-fastNLP.modules.utils
-=====================
-
-.. automodule:: fastNLP.modules.utils
- :members: initial_parameter, summary
-
diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst
index 097ad0b2..f3e245fe 100644
--- a/docs/source/fastNLP.rst
+++ b/docs/source/fastNLP.rst
@@ -1,18 +1,21 @@
-fastNLP
-=======
+fastNLP package
+===============
.. automodule:: fastNLP
- :members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger
- :inherited-members:
+ :members:
+ :undoc-members:
+ :show-inheritance:
-子模块
-------
+Subpackages
+-----------
.. toctree::
- :maxdepth: 1
+ :maxdepth: 4
fastNLP.core
fastNLP.embeddings
+ fastNLP.envs
fastNLP.io
fastNLP.models
fastNLP.modules
+ fastNLP.transformers
\ No newline at end of file
diff --git a/docs/source/fastNLP.transformers.rst b/docs/source/fastNLP.transformers.rst
new file mode 100644
index 00000000..023da63d
--- /dev/null
+++ b/docs/source/fastNLP.transformers.rst
@@ -0,0 +1,14 @@
+fastNLP.transformers package
+============================
+.. automodule:: fastNLP.transformers
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ fastNLP.transformers.torch
diff --git a/docs/source/fastNLP.transformers.torch.rst b/docs/source/fastNLP.transformers.torch.rst
new file mode 100644
index 00000000..9d5f0d65
--- /dev/null
+++ b/docs/source/fastNLP.transformers.torch.rst
@@ -0,0 +1,7 @@
+fastNLP.transformers.torch package
+==================================
+
+.. automodule:: fastNLP.transformers.torch
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/figures/fitlogChart.png b/docs/source/figures/fitlogChart.png
deleted file mode 100644
index 57ae1683..00000000
Binary files a/docs/source/figures/fitlogChart.png and /dev/null differ
diff --git a/docs/source/figures/fitlogTable.png b/docs/source/figures/fitlogTable.png
deleted file mode 100644
index 37551634..00000000
Binary files a/docs/source/figures/fitlogTable.png and /dev/null differ
diff --git a/docs/source/figures/sequence_labeling.PNG b/docs/source/figures/sequence_labeling.PNG
deleted file mode 100644
index 397f0a24..00000000
Binary files a/docs/source/figures/sequence_labeling.PNG and /dev/null differ
diff --git a/docs/source/figures/text_classification.png b/docs/source/figures/text_classification.png
deleted file mode 100644
index 21502708..00000000
Binary files a/docs/source/figures/text_classification.png and /dev/null differ
diff --git a/docs/source/figures/workflow.png b/docs/source/figures/workflow.png
deleted file mode 100644
index 3cf4e70e..00000000
Binary files a/docs/source/figures/workflow.png and /dev/null differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
index ff77a6fc..abe73d8a 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,52 +1,26 @@
fastNLP 中文文档
=====================
-`fastNLP `_ 是一款轻量级的自然语言处理(NLP)工具包。你既可以用它来快速地完成一个NLP任务,
-也可以用它在研究中快速构建更复杂的模型。
-.. hint::
-
- 如果你是从 readthedocs 访问的该文档,请跳转到我们的 `最新网站 `_
-
-fastNLP具有如下的特性:
-
-- 统一的Tabular式数据容器,简化数据预处理过程;
-- 内置多种数据集的 :class:`~fastNLP.io.Loader` 和 :class:`~fastNLP.io.Pipe` ,省去预处理代码;
-- 各种方便的NLP工具,例如Embedding加载(包括 :class:`~fastNLP.embeddings.ElmoEmbedding` 和 :class:`~fastNLP.embeddings.BertEmbedding` )、中间数据cache等;
-- 部分 `数据集与预训练模型 `_ 的自动下载;
-- 提供多种神经网络组件以及复现模型(涵盖中文分词、命名实体识别、句法分析、文本分类、文本匹配、指代消解、摘要等任务);
-- :class:`~fastNLP.Trainer` 提供多种内置 :mod:`~fastNLP.core.callback` 函数,方便实验记录、异常捕获等.
-
-
-用户手册
+快速上手
----------------
.. toctree::
:maxdepth: 2
- 安装指南
- 快速入门
- 详细教程
+ tutorials
API 文档
-------------
-除了用户手册之外,你还可以通过查阅 API 文档来找到你所需要的工具。
+您可以通过查阅 API 文档来找到你所需要的工具。
.. toctree::
:titlesonly:
:maxdepth: 2
-
- fastNLP
-
-:doc:`API变动列表 `
-
-fitlog文档
-----------
+ fastNLP
-您可以 `点此 `_ 查看fitlog的文档。
-fitlog 是由我们团队开发的日志记录+代码管理的工具。
索引与搜索
==================
diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst
new file mode 100644
index 00000000..bc32d10f
--- /dev/null
+++ b/docs/source/tutorials.rst
@@ -0,0 +1,8 @@
+fastNLP 教程系列
+================
+
+.. toctree::
+ :maxdepth: 1
+ :glob:
+
+ tutorials/*
diff --git a/docs/source/tutorials/cn_cls_example.png b/docs/source/tutorials/cn_cls_example.png
deleted file mode 100644
index 5055bb02..00000000
Binary files a/docs/source/tutorials/cn_cls_example.png and /dev/null differ
diff --git a/docs/source/tutorials/extend_1_bert_embedding.rst b/docs/source/tutorials/extend_1_bert_embedding.rst
deleted file mode 100644
index b902b8ec..00000000
--- a/docs/source/tutorials/extend_1_bert_embedding.rst
+++ /dev/null
@@ -1,231 +0,0 @@
-==============================
-BertEmbedding的各种用法
-==============================
-
-Bert自从在 `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding `_
-中被提出后,因其性能卓越受到了极大的关注,在这里我们展示一下在fastNLP中如何使用Bert进行各类任务。其中中文Bert我们使用的模型的权重来自于
-`中文Bert预训练 `_ 。
-
-为了方便大家的使用,fastNLP提供了预训练的Embedding权重及数据集的自动下载,支持自动下载的Embedding和数据集见
-`数据集 `_ 。或您可从 :doc:`/tutorials/tutorial_3_embedding` 与
-:doc:`/tutorials/tutorial_4_load_dataset` 了解更多相关信息。
-
-----------------------------------
-中文任务
-----------------------------------
-下面我们将介绍通过使用Bert来进行文本分类, 中文命名实体识别, 文本匹配, 中文问答。
-
-.. note::
-
- 本教程必须使用 GPU 进行实验,并且会花费大量的时间
-
-1. 使用Bert进行文本分类
-----------------------------------
-文本分类是指给定一段文字,判定其所属的类别。例如下面的文本情感分类
-
-.. code-block:: text
-
- 1, 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!
-
-这里我们使用fastNLP提供自动下载的微博分类进行测试
-
-.. code-block:: python
-
- from fastNLP.io import WeiboSenti100kPipe
- from fastNLP.embeddings import BertEmbedding
- from fastNLP.models import BertForSequenceClassification
- from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
- import torch
-
- data_bundle =WeiboSenti100kPipe().process_from_file()
- data_bundle.rename_field('chars', 'words')
-
- # 载入BertEmbedding
- embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)
-
- # 载入模型
- model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))
-
- # 训练模型
- device = 0 if torch.cuda.is_available() else 'cpu'
- trainer = Trainer(data_bundle.get_dataset('train'), model,
- optimizer=Adam(model_params=model.parameters(), lr=2e-5),
- loss=CrossEntropyLoss(), device=device,
- batch_size=8, dev_data=data_bundle.get_dataset('dev'),
- metrics=AccuracyMetric(), n_epochs=2, print_every=1)
- trainer.train()
-
- # 测试结果
- from fastNLP import Tester
-
- tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())
- tester.test()
-
-输出结果::
-
- In Epoch:1/Step:12499, got best dev performance:
- AccuracyMetric: acc=0.9838
- Reloaded the best model.
- Evaluate data in 63.84 seconds!
- [tester]
- AccuracyMetric: acc=0.9815
-
-
-2. 使用Bert进行命名实体识别
-----------------------------------
-命名实体识别是给定一句话,标记出其中的实体。一般序列标注的任务都使用conll格式,conll格式是至一行中通过制表符分隔不同的内容,使用空行分隔
-两句话,例如下面的例子
-
-.. code-block:: text
-
- 中 B-ORG
- 共 I-ORG
- 中 I-ORG
- 央 I-ORG
- 致 O
- 中 B-ORG
- 国 I-ORG
- 致 I-ORG
- 公 I-ORG
- 党 I-ORG
- 十 I-ORG
- 一 I-ORG
- 大 I-ORG
- 的 O
- 贺 O
- 词 O
-
-这部分内容请参考 :doc:`/tutorials/序列标注`
-
-
-3. 使用Bert进行文本匹配
-----------------------------------
-文本匹配任务是指给定两句话判断他们的关系。比如,给定两句话判断前一句是否和后一句具有因果关系或是否是矛盾关系;或者给定两句话判断两句话是否
-具有相同的意思。这里我们使用
-
-.. code-block:: python
-
- from fastNLP.io import CNXNLIBertPipe
- from fastNLP.embeddings import BertEmbedding
- from fastNLP.models import BertForSentenceMatching
- from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
- from fastNLP.core.optimizer import AdamW
- from fastNLP.core.callback import WarmupCallback
- from fastNLP import Tester
- import torch
-
- data_bundle = CNXNLIBertPipe().process_from_file()
- data_bundle.rename_field('chars', 'words')
- print(data_bundle)
-
- # 载入BertEmbedding
- embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)
-
- # 载入模型
- model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))
-
- # 训练模型
- callbacks = [WarmupCallback(warmup=0.1, schedule='linear'), ]
- device = 0 if torch.cuda.is_available() else 'cpu'
- trainer = Trainer(data_bundle.get_dataset('train'), model,
- optimizer=AdamW(params=model.parameters(), lr=4e-5),
- loss=CrossEntropyLoss(), device=device,
- batch_size=8, dev_data=data_bundle.get_dataset('dev'),
- metrics=AccuracyMetric(), n_epochs=5, print_every=1,
- update_every=8, callbacks=callbacks)
- trainer.train()
-
- tester = Tester(data_bundle.get_dataset('test'), model, batch_size=8, metrics=AccuracyMetric())
- tester.test()
-
-运行结果::
-
- In Epoch:3/Step:73632, got best dev performance:
- AccuracyMetric: acc=0.781928
- Reloaded the best model.
- Evaluate data in 18.54 seconds!
- [tester]
- AccuracyMetric: acc=0.783633
-
-
-4. 使用Bert进行中文问答
-----------------------------------
-问答任务是给定一段内容,以及一个问题,需要从这段内容中找到答案。
-例如::
-
- "context": "锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法,以中文字的声音模拟敲击乐的声音,纪录打击乐的各种不同的演奏方法。常
- 用的节奏型称为「锣鼓点」。而锣鼓是戏曲节奏的支柱,除了加强演员身段动作的节奏感,也作为音乐的引子和尾声,提示音乐的板式和速度,以及
- 作为唱腔和念白的伴奏,令诗句的韵律更加抑扬顿锉,段落分明。锣鼓的运用有约定俗成的程式,依照角色行当的身份、性格、情绪以及环境,配合
- 相应的锣鼓点。锣鼓亦可以模仿大自然的音响效果,如雷电、波浪等等。戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型:鼓类包括有单
- 皮鼓(板鼓)、大鼓、大堂鼓(唐鼓)、小堂鼓、怀鼓、花盆鼓等;锣类有大锣、小锣(手锣)、钲锣、筛锣、马锣、镗锣、云锣;钹类有铙钹、大
- 钹、小钹、水钹、齐钹、镲钹、铰子、碰钟等;打拍子用的檀板、木鱼、梆子等。因为京剧的锣鼓通常由四位乐师负责,又称为四大件,领奏的师
- 傅称为:「鼓佬」,其职责有如西方乐队的指挥,负责控制速度以及利用各种手势提示乐师演奏不同的锣鼓点。粤剧吸收了部份京剧的锣鼓,但以木鱼
- 和沙的代替了京剧的板和鼓,作为打拍子的主要乐器。以下是京剧、昆剧和粤剧锣鼓中乐器对应的口诀用字:",
- "question": "锣鼓经是什么?",
- "answers": [
- {
- "text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法",
- "answer_start": 4
- },
- {
- "text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法",
- "answer_start": 4
- },
- {
- "text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法",
- "answer_start": 4
- }
- ]
-
-您可以通过以下的代码训练 (原文代码:`CMRC2018 `_)
-
-.. code-block:: python
-
- from fastNLP.embeddings import BertEmbedding
- from fastNLP.models import BertForQuestionAnswering
- from fastNLP.core.losses import CMRC2018Loss
- from fastNLP.core.metrics import CMRC2018Metric
- from fastNLP.io.pipe.qa import CMRC2018BertPipe
- from fastNLP import Trainer, BucketSampler
- from fastNLP import WarmupCallback, GradientClipCallback
- from fastNLP.core.optimizer import AdamW
- import torch
-
- data_bundle = CMRC2018BertPipe().process_from_file()
- data_bundle.rename_field('chars', 'words')
-
- print(data_bundle)
-
- embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn', requires_grad=True, include_cls_sep=False, auto_truncate=True,
- dropout=0.5, word_dropout=0.01)
- model = BertForQuestionAnswering(embed)
- loss = CMRC2018Loss()
- metric = CMRC2018Metric()
-
- wm_callback = WarmupCallback(schedule='linear')
- gc_callback = GradientClipCallback(clip_value=1, clip_type='norm')
- callbacks = [wm_callback, gc_callback]
-
- optimizer = AdamW(model.parameters(), lr=5e-5)
-
- device = 0 if torch.cuda.is_available() else 'cpu'
- trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,
- sampler=BucketSampler(seq_len_field_name='context_len'),
- dev_data=data_bundle.get_dataset('dev'), metrics=metric,
- callbacks=callbacks, device=device, batch_size=6, num_workers=2, n_epochs=2, print_every=1,
- test_use_tqdm=False, update_every=10)
- trainer.train(load_best_model=False)
-
-训练结果(和原论文中报道的基本一致)::
-
- In Epoch:2/Step:1692, got best dev performance:
- CMRC2018Metric: f1=85.61, em=66.08
-
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/extend_2_dist.rst b/docs/source/tutorials/extend_2_dist.rst
deleted file mode 100644
index b8175306..00000000
--- a/docs/source/tutorials/extend_2_dist.rst
+++ /dev/null
@@ -1,223 +0,0 @@
-Distributed Parallel Training
-=============================
-
-原理
-----
-
-随着深度学习模型越来越复杂,单个GPU可能已经无法满足正常的训练。比如BERT等预训练模型,更是在多个GPU上训练得到的。为了使用多GPU训练,Pytorch框架已经提供了
-`nn.DataParallel `_ 以及
-`nn.DistributedDataParallel `_ 两种方式的支持。
-`nn.DataParallel `_
-很容易使用,但是却有着GPU负载不均衡,单进程速度慢等缺点,无法发挥出多GPU的全部性能。因此,分布式的多GPU训练方式
-`nn.DistributedDataParallel `_
-是更好的选择。然而,因为分布式训练的特点,
-`nn.DistributedDataParallel `_
-常常难以理解和使用,也很难debug。所以,在使用分布式训练之前,需要理解它的原理。
-
-在使用
-`nn.DistributedDataParallel `_
-时,模型会被复制到所有使用的GPU,通常每个GPU上存有一个模型,并被一个单独的进程控制。这样有N块GPU,就会产生N个进程。当训练一个batch时,这一batch会被分为N份,每个进程会使用batch的一部分进行训练,然后在必要时进行同步,并通过网络传输需要同步的数据。这时,只有模型的梯度会被同步,而模型的参数不会,所以能缓解大部分的网络传输压力,网络传输不再是训练速度的瓶颈之一。你可能会好奇,不同步模型的参数,怎么保证不同进程所训练的模型相同?只要每个进程初始的模型是同一个,具有相同的参数,而之后每次更新,都使用相同的梯度,就能保证梯度更新后的模型也具有相同的参数了。
-
-为了让每个进程的模型初始化完全相同,通常这N个进程都是由单个进程复制而来的,这时需要对分布式的进程进行初始化,建立相互通信的机制。在
-Pytorch 中,我们用
-`distributed.init_process_group `_
-函数来完成,需要在程序开头就加入这一步骤。初始化完成后,每一个进程用唯一的编号
-``rank`` 进行区分,从 0 到 N-1递增,一般地,我们将 ``rank`` 为 0
-的进程当作主进程,而其他 ``rank`` 的进程为子进程。每个进程还要知道
-``world_size`` ,即分布式训练的总进程数
-N。训练时,每个进程使用batch的一部分,互相不能重复,这里通过
-`nn.utils.data.DistributedSampler `_
-来实现。
-
-使用方式
---------
-
-Pytorch的分布式训练使用起来非常麻烦,难以理解,可以从给出的\ `官方教程 `_ \ 中看到。而\ ``fastNLP``
-提供了
-``DistTrainer``\ ,将大部分的分布式训练的细节进行了封装,只需简单的改动训练代码,就能直接用上分布式训练。那么,具体怎么将普通的训练代码改成支持分布式训练的代码呢。下面我们来讲一讲分布式训练的完整流程。通常,分布式程序的多个进程是单个进程的复制。假设我们用N个GPU进行分布式训练,我们需要启动N个进程,这时,在命令行使用:
-
-.. code:: shell
-
- python -m torch.distributed.launch --nproc_per_node=N train_script.py --args
-
-其中\ ``N``\ 是需要启动的进程数,\ ``train_script.py``\ 为训练代码,\ ``--args``\ 是自定义的命令行参数。在启动了N个进程之后,如果我们在\ ``train_script.py``\ 的训练代码中正常配置,分布式训练就能正常进行。
-
-此外,还可以使用环境变量\ ``CUDA_VISIBLE_DEVICES``\ 设置指定的GPU,比如在8卡机器上使用编号为4,5,6,7的4块GPU:
-
-.. code:: shell
-
- CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=N train_script.py --args
-
-在 ``train_script.py``
-训练代码中,有一些必须的配置。为了清晰的叙述,这里放一个简单的分布式训练代码,省去多余细节:
-
-.. code:: python
-
- import torch.distributed as dist
- from fastNLP import DistTrainer, get_local_rank
- import fastNLP as fnlp
-
- def main(options):
- # options为训练所需的参数,batch_size等
-
- set_seed(options.seed)
-
- # 初始化分布式进程
- dist.init_process_group('nccl')
-
- ######## 读取数据
- if get_local_rank() != 0:
- dist.barrier() # 先让主进程(rank==0)先执行,进行数据处理,预训模型参数下载等操作,然后保存cache
- data = get_processed_data()
- model = get_model(data.get_vocab("words"), data.get_vocab("target"))
- if get_local_rank() == 0:
- dist.barrier() # 主进程执行完后,其余进程开始读取cache
- ########
-
- # 初始化Trainer,训练等,与普通训练差别不大
- def get_trainer(model, data):
- # 注意设置的callback有两种,一种只在主进程执行,一种在所有进程都执行
- callbacks_master = [fnlp.FitlogCallback()]
- callbacks_all = [fnlp.WarmupCallback(warmup=options.warmup)]
- trainer = DistTrainer(
- save_path='save',
- train_data=data.get_dataset("train"),
- dev_data=data.get_dataset("dev"),
- model=model,
- loss=fnlp.CrossEntropyLoss(),
- metrics=fnlp.AccuracyMetric(),
- metric_key="acc",
- optimizer=fnlp.AdamW(model.parameters(), lr=options.lr),
- callbacks_master=callbacks_master, # 仅在主进程执行(如模型保存,日志记录)
- callbacks_all=callbacks_all, # 在所有进程都执行(如梯度裁剪,学习率衰减)
- batch_size_per_gpu=options.batch_size, # 指定每个GPU的batch大小
- update_every=options.update,
- n_epochs=options.epochs,
- use_tqdm=True,
- )
- return trainer
-
- trainer = get_trainer(model, data)
- trainer.train()
-
-指定进程编号
-^^^^^^^^^^^^
-
-首先,为了区分不同的进程,初始时需要对每个进程传入\ ``rank``\ 。这里一般分为\ ``node_rank``\ 和\ ``local_rank``\ ,分别表示进程处于哪一机器以及同机器上处于第几进程。如果在单一机器上,\ ``node_rank``\ 可以省略。\ ``local_rank``\ 一般通过命令行参数\ ``--local_rank``\ 传入,为\ ``int``\ 类型。也可以通过环境变量传入\ ``local_rank``\ ,只需在\ ``torch.distributed.launch``\ 时,使用\ ``--use_env``\ 参数。无论哪种方式,在训练脚本中,都要获取到\ ``local_rank``\ ,用于初始化分布式通信,以及区分进程。如果你使用\ ``fastNLP``\ ,可以通过\ ``fastNLP.get_local_rank``\ 来得到\ ``local_rank``\ 。
-
-初始化进程
-^^^^^^^^^^
-
-在获取了\ ``local_rank``\ 等重要参数后,在开始训练前,我们需要建立不同进程的通信和同步机制。这时我们使用\ `torch.distributed.init_process_group `_
-来完成。通常,我们只需要 ``torch.distributed.init_process_group('nccl')``
-来指定使用\ ``nccl``\ 后端来进行同步即可。其他参数程序将读取环境变量自动设置。如果想手动设置这些参数,比如,使用TCP进行通信,可以设置:
-
-.. code:: python
-
- init_process_group('nccl', init_method='tcp://localhost:55678',
- rank=args.rank, world_size=N)
-
-或者使用文件进行通信:
-
-.. code:: python
-
- init_process_group('nccl', init_method='file:///mnt/nfs/sharedfile',
- world_size=N, rank=args.rank)
-
-注意,此时必须显式指定\ ``world_size``\ 和\ ``rank``\ ,具体可以参考
-`torch.distributed.init_process_group `_
-的使用文档。
-
-在初始化分布式通信后,再初始化\ ``DistTrainer``\ ,传入数据和模型,就完成了分布式训练的代码。代码修改完成后,使用上面给出的命令行启动脚本,就能成功运行分布式训练。但是,如果数据处理,训练中的自定义操作比较复杂,则可能需要额外的代码修改。下面列出一些需要特别注意的地方,在使用分布式训练前,请仔细检查这些事项。
-
-注意事项
---------
-
-在执行完
-`torch.distributed.init_process_group `_
-后,我们就可以在不同进程间完成传输数据,进行同步等操作。这些操作都可以在\ `torch.distributed `_
-中找到。其中,最重要的是
-`barrier `_
-以及
-`get_rank `_
-操作。对于训练而言,我们关心的是读入数据,记录日志,模型初始化,模型参数更新,模型保存等操作。这些操作大多是读写操作,在多进程状态下,这些操作都必须小心进行,否则可能出现难以预料的bug。而在\ ``fastNLP``\ 中,大部分操作都封装在
-``DistTrainer`` 中,只需保证数据读入和模型初始化正确即可完成训练。
-
-写操作
-^^^^^^
-
-一般而言,读入操作需要在每一个进程都执行,因为每个进程都要使用读入的数据和模型参数进行训练。而写出操作只需在其中一个进程(通常为主进程)执行,因为每一个进程保存的模型都相同,都处于同一训练状态。所以,通常单进程的训练脚本中,只需要修改写出操作的部分,通过加入对进程\ ``rank``\ 的判断,仅让其中一个进程执行写操作:
-
-.. code:: python
-
- import torch.distributed as dist
-
- # 仅在主进程才执行
- if dist.get_rank() == 0:
- do_wirte_op() # 一些写操作
- dist.barrier() # 确保写完成后,所有进程再执行(若进程无需读入写出的数据,可以省去)
-
-若使用\ ``fastNLP``\ 中的\ ``DistTrainer``\ ,也可以这样写:
-
-.. code:: python
-
- # 判断是否是主进程的trainer
- if trainer.is_master:
- do_wirte_op()
- dist.barrier()
-
-读操作
-^^^^^^
-
-然而有些时候,我们需要其中一个进程先执行某些操作,等这一进程执行完后,其它进程再执行这一操作。比如,在读入数据时,我们有时需要从网上下载,再处理,将处理好的数据保存,供反复使用。这时,我们不需要所有进程都去下载和处理数据,只需要主进程进行这些操作,其它进程等待。直到处理好的数据被保存后,其他进程再从保存位置直接读入数据。这里可以参考范例代码中的读取数据:
-
-.. code:: python
-
- if dist.get_rank() != 0:
- dist.barrier() # 先让主进程(rank==0)先执行,进行数据处理,预训模型参数下载等操作,然后保存cache
-
- # 这里会自动处理数据,或直接读取保存的cache
- data = get_processed_data()
- model = get_model(data.get_vocab("words"), data.get_vocab("target"))
-
- if dist.get_rank() == 0:
- dist.barrier() # 主进程执行完后,其余进程开始读取cache
-
-也可以显式的将主进程和其它进程的操作分开:
-
-.. code:: python
-
- if dist.get_rank() == 0:
- data = do_data_processing() # 数据处理
- dist.barrier()
- else:
- dist.barrier()
- data = load_processed_data() # 读取cache
-
-日志操作
-^^^^^^^^
-
-通常,我们需要知道训练的状态,如当前在第几个epoch,模型当前的loss等等。单进程训练时,我们可以直接使用\ ``print``\ 将这些信息输出到命令行或日志文件。然而,在多进程时,\ ``print``\ 会导致同样的信息在每一进程都输出,造成问题。这一问题和写操作类似,也可以通过判断进程的编号之后再输出。问题是,日志通常在训练的很多地方都有输出,逐一加上判断代码是非常繁琐的。这里,建议统一修改为:
-
-.. code:: python
-
- from fastNLP import logger
- logger.info('....') # 替换print
-
-在\ ``DistTrainer``\ 中,主进程的\ ``logger``\ 级别为\ ``INFO``\ ,而其它进程为\ ``WARNING``\ 。这样级别为\ ``INFO``\ 的信息只会在主进程输出,不会造成日志重复问题。若需要其它进程中的信息,可以使用\ ``logger.warning``\ 。
-
-注意,\ ``logger``\ 的级别设置只有初始化了\ ``DistTrainer``\ 后才能生效。如果想要在初始化进程后就生效,需要在分布式通信初始化后,执行\ ``init_logger_dist``\ 。
-
-Callback
-^^^^^^^^
-
-``fastNLP``\ 的一个特色是可以使用\ ``Callback``\ 在训练时完成各种自定义操作。而这一特色在\ ``DistTrainer``\ 中得以保留。但是,这时需要特别注意\ ``Callback``\ 是否只需要在主进程执行。一些\ ``Callback``\ ,比如调整学习率,梯度裁剪等,会改变模型的状态,因此需要在所有进程上都执行,将它们通过\ ``callback_all``\ 参数传入\ ``DistTrainer``\ 。而另一些\ ``Callback``\ ,比如\ ``fitlog``\ ,保存模型,不会改变模型的状态,而是进行数据写操作,因此仅在主进程上执行,将它们通过\ ``callback_master``\ 传入。
-
-在自定义\ ``Callback``\ 时,请遵循一个原则,改变训练或模型状态的操作在所有进程中执行,而数据写到硬盘请在主进程单独进行。这样就能避免进程间失去同步,或者磁盘写操作的冲突。
-
-Debug
-^^^^^
-
-多进程的程序很难进行debug,如果出现问题,可以先参考报错信息进行处理。也可以在程序中多输出日志,定位问题。具体情况,具体分析。在debug时,要多考虑进程同步和异步的操作,判断问题是程序本身导致的,还是由进程间没有同步而产生。
-
-其中,有一个常见问题是程序卡住不动。具体表现为训练暂停,程序没有输出,但是GPU利用率保持100%。这一问题是由进程失去同步导致的。这时只能手动\ ``kill``\ GPU上残留的进程,再检查代码。需要检查进程同步的位置,比如模型\ ``backward()``\ 时,\ ``barrier()``\ 时等。同时,也要检查主进程与其它进程操作不同的位置,比如存储模型,evaluate模型时等。注意,失去同步的位置可能并不是程序卡住的位置,所以需要细致的检查。
diff --git a/docs/source/tutorials/extend_3_fitlog.rst b/docs/source/tutorials/extend_3_fitlog.rst
deleted file mode 100644
index 152e18fe..00000000
--- a/docs/source/tutorials/extend_3_fitlog.rst
+++ /dev/null
@@ -1,122 +0,0 @@
-============================================
-使用fitlog 辅助 fastNLP 进行科研
-============================================
-
-本文介绍结合使用 fastNLP 和 fitlog 进行科研的方法。
-
-首先,我们需要安装 `fitlog `_ 。你需要确认你的电脑中没有其它名为 `fitlog` 的命令。
-
-我们从命令行中进入到一个文件夹,现在我们要在文件夹中创建我们的 fastNLP 项目。你可以在命令行输入 `fitlog init test1` ,
-然后你会看到如下提示::
-
- Initialized empty Git repository in /Users/fdujyn/workspaces/test1/.git/
- Auto commit by fitlog
- Initialized empty Git repository in /Users/fdujyn/workspaces/test1/.git/
- Fitlog project test1 is initialized.
-
-这表明你已经创建成功了项目文件夹,并且在项目文件夹中已经初始化了 Git。如果你不想初始化 Git,
-可以参考文档 `命令行工具 `_
-
-现在我们进入你创建的项目文件夹 test1 中,可以看到有一个名为 logs 的文件夹,后面我们将会在里面存放你的实验记录。
-同时也有一个名为 main.py 的文件,这是我们推荐你使用的训练入口文件。文件的内容如下::
-
- import fitlog
-
- fitlog.commit(__file__) # auto commit your codes
- fitlog.add_hyper_in_file (__file__) # record your hyperparameters
-
- """
- Your training code here, you may use these functions to log your result:
- fitlog.add_hyper()
- fitlog.add_loss()
- fitlog.add_metric()
- fitlog.add_best_metric()
- ......
- """
-
- fitlog.finish() # finish the logging
-
-我们推荐你保留除注释外的四行代码,它们有助于你的实验,
-他们的具体用处参见文档 `用户 API `_
-
-我们假定你要进行前两个教程中的实验,并已经把数据复制到了项目根目录下的 tutorial_sample_dataset.csv 文件中。
-现在我们编写如下的训练代码,使用 :class:`~fastNLP.core.callback.FitlogCallback` 进行实验记录保存::
-
- import fitlog
- from fastNLP import Vocabulary, Trainer, CrossEntropyLoss, AccuracyMetric
- from fastNLP.io import CSVLoader
- from fastNLP.models import CNNText
- from fastNLP.core.callback import FitlogCallback
-
- fitlog.commit(__file__) # auto commit your codes
- fitlog.add_hyper_in_file (__file__) # record your hyperparameters
-
- ############hyper
- word_embed = 50
- dropout = 0.1
- ############hyper
-
- loader = CSVLoader(headers=('raw_sentence', 'label'), sep='\t')
- dataset = loader.load("tutorial_sample_dataset.csv")
-
- dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='sentence')
- dataset.apply(lambda x: x['sentence'].split(), new_field_name='words', is_input=True)
- dataset.apply(lambda x: int(x['label']), new_field_name='target', is_target=True)
- vocab = Vocabulary(min_freq=2).from_dataset(dataset, field_name='words')
- vocab.index_dataset(dataset, field_name='words',new_field_name='words')
-
- model = CNNText((len(vocab),word_embed), num_classes=5, padding=2, dropout=dropout)
-
- train_dev_data, test_data = dataset.split(0.1)
- train_data, dev_data = train_dev_data.split(0.1)
-
- trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data,
- loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
- callbacks=[FitlogCallback(test_data)])
- trainer.train()
-
- fitlog.finish() # finish the logging
-
-用命令行在项目目录下执行 `python main.py` 之后,输出结果如下::
-
- Auto commit by fitlog
- input fields after batch(if batch size is 2):
- words: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 11])
- target fields after batch(if batch size is 2):
- target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
-
- training epochs started 2019-05-23-21-11-51
- Evaluation at Epoch 1/10. Step:2/20. AccuracyMetric: acc=0.285714
-
- Evaluation at Epoch 2/10. Step:4/20. AccuracyMetric: acc=0.285714
-
- Evaluation at Epoch 3/10. Step:6/20. AccuracyMetric: acc=0.285714
-
- Evaluation at Epoch 4/10. Step:8/20. AccuracyMetric: acc=0.428571
-
- Evaluation at Epoch 5/10. Step:10/20. AccuracyMetric: acc=0.571429
-
- Evaluation at Epoch 6/10. Step:12/20. AccuracyMetric: acc=0.571429
-
- Evaluation at Epoch 7/10. Step:14/20. AccuracyMetric: acc=0.285714
-
- Evaluation at Epoch 8/10. Step:16/20. AccuracyMetric: acc=0.142857
-
- Evaluation at Epoch 9/10. Step:18/20. AccuracyMetric: acc=0.285714
-
- Evaluation at Epoch 10/10. Step:20/20. AccuracyMetric: acc=0.571429
-
-
- In Epoch:5/Step:10, got best dev performance:AccuracyMetric: acc=0.571429
- Reloaded the best model.
-
-现在,我们在项目目录下输入 `fitlog log logs` ,命令行会启动一个网页,默认 url 为 ``0.0.0.0:5000`` 。
-我们在浏览器中打开网页,可以看到如下的统计表格:
-
-.. image:: ../figures/fitlogTable.png
-
-如果我们点击action中的最后一个键钮,可以看到详细的 loss 图:
-
-.. image:: ../figures/fitlogChart.png
-
-更多的教程还在编写中,敬请期待~
\ No newline at end of file
diff --git a/docs/source/tutorials/fastnlp_torch_tutorial.ipynb b/docs/source/tutorials/fastnlp_torch_tutorial.ipynb
new file mode 100644
index 00000000..9633ac7f
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_torch_tutorial.ipynb
@@ -0,0 +1,869 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "6011adf8",
+ "metadata": {},
+ "source": [
+ "# 10 分钟快速上手 fastNLP torch\n",
+ "\n",
+ "在这个例子中,我们将使用BERT来解决conll2003数据集中的命名实体识别任务。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "e166c051",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--2022-07-07 10:12:29-- https://data.deepai.org/conll2003.zip\n",
+ "Resolving data.deepai.org (data.deepai.org)... 138.201.36.183\n",
+ "Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.\n",
+ "WARNING: cannot verify data.deepai.org's certificate, issued by ‘CN=R3,O=Let's Encrypt,C=US’:\n",
+ " Issued certificate has expired.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 982975 (960K) [application/x-zip-compressed]\n",
+ "Saving to: ‘conll2003.zip’\n",
+ "\n",
+ "conll2003.zip 100%[===================>] 959.94K 653KB/s in 1.5s \n",
+ "\n",
+ "2022-07-07 10:12:32 (653 KB/s) - ‘conll2003.zip’ saved [982975/982975]\n",
+ "\n",
+ "Archive: conll2003.zip\n",
+ " inflating: conll2003/metadata \n",
+ " inflating: conll2003/test.txt \n",
+ " inflating: conll2003/train.txt \n",
+ " inflating: conll2003/valid.txt \n"
+ ]
+ }
+ ],
+ "source": [
+ "# Linux/Mac 下载数据,并解压\n",
+ "import platform\n",
+ "if platform.system() != \"Windows\":\n",
+ " !wget https://data.deepai.org/conll2003.zip --no-check-certificate -O conll2003.zip\n",
+ " !unzip conll2003.zip -d conll2003\n",
+ "# Windows用户请通过复制该url到浏览器下载该数据并解压"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f7acbf1f",
+ "metadata": {},
+ "source": [
+ "## 目录\n",
+    "接下来我们将按照以下的内容介绍如何通过fastNLP减少工程性代码的撰写 \n",
+ "- 1. 数据加载\n",
+ "- 2. 数据预处理、数据缓存\n",
+ "- 3. DataLoader\n",
+ "- 4. 模型准备\n",
+ "- 5. Trainer的使用\n",
+ "- 6. Evaluator的使用\n",
+ "- 7. 其它【待补充】\n",
+ " - 7.1 使用多卡进行训练、评测\n",
+ " - 7.2 使用ZeRO优化\n",
+ " - 7.3 通过overfit测试快速验证模型\n",
+ " - 7.4 复杂Monitor的使用\n",
+ " - 7.5 训练过程中,使用不同的测试函数\n",
+ " - 7.6 更有效率的Sampler\n",
+ " - 7.7 保存模型\n",
+ " - 7.8 断点重训\n",
+ " - 7.9 使用huggingface datasets\n",
+ " - 7.10 使用torchmetrics来作为metric\n",
+ " - 7.11 将预测结果写出到文件\n",
+ " - 7.12 混合 dataset 训练\n",
+ " - 7.13 logger的使用\n",
+ " - 7.14 自定义分布式 Metric 。\n",
+ " - 7.15 通过batch_step_fn实现R-Drop"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0657dfba",
+ "metadata": {},
+ "source": [
+ "#### 1. 数据加载\n",
+    "目前在``conll2003``目录下有``train.txt``, ``test.txt``与``valid.txt``三个文件,文件的格式为[conll格式](https://universaldependencies.org/format.html),其编码格式为 [BIO](https://blog.csdn.net/HappyRocking/article/details/79716212) 类型。可以通过继承 fastNLP.io.Loader 来简化加载过程,继承了 Loader 类后,只需要实现读取单个文件的 _load() 函数即可。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "c557f0ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append('../..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "6f59e438",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "In total 3 datasets:\n",
+ "\ttrain has 14987 instances.\n",
+ "\ttest has 3684 instances.\n",
+ "\tdev has 3466 instances.\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from fastNLP import DataSet, Instance\n",
+ "from fastNLP.io import Loader\n",
+ "\n",
+ "\n",
+ "# 继承Loader之后,我们只需要实现其中_load()方法,_load()方法传入一个文件路径,返回一个fastNLP DataSet对象,其目的是读取一个文件。\n",
+ "class ConllLoader(Loader):\n",
+ " def _load(self, path):\n",
+ " ds = DataSet()\n",
+ " with open(path, 'r') as f:\n",
+ " segments = []\n",
+ " for line in f:\n",
+ " line = line.strip()\n",
+ " if line == '': # 如果为空行,说明需要切换到下一句了。\n",
+ " if segments:\n",
+ " raw_words = [s[0] for s in segments]\n",
+ " raw_target = [s[1] for s in segments]\n",
+ " # 将一个 sample 插入到 DataSet中\n",
+ " ds.append(Instance(raw_words=raw_words, raw_target=raw_target)) \n",
+ " segments = []\n",
+ " else:\n",
+ " parts = line.split()\n",
+ " assert len(parts)==4\n",
+ " segments.append([parts[0], parts[-1]])\n",
+ " return ds\n",
+ " \n",
+ "\n",
+ "# 直接使用 load() 方法加载数据集, 返回的 data_bundle 是一个 fastNLP.io.DataBundle 对象,该对象相当于将多个 dataset 放置在一起,\n",
+ "# 可以方便之后的预处理,DataBundle 支持的接口可以在 !!! 查看。\n",
+ "data_bundle = ConllLoader().load({\n",
+ " 'train': 'conll2003/train.txt',\n",
+ " 'test': 'conll2003/test.txt',\n",
+ " 'dev': 'conll2003/valid.txt'\n",
+ "})\n",
+ "\"\"\"\n",
+ "也可以通过 ConllLoader().load('conll2003/') 来读取,其原理是load()函数将尝试从'conll2003/'文件夹下寻找文件名称中包含了\n",
+ "'train'、'test'和'dev'的文件,并分别读取将其命名为'train'、'test'和'dev'(如文件夹中同一个关键字出现在了多个文件名中将导致报错,\n",
+ "此时请通过dict的方式传入路径信息)。但在我们这里的数据里,没有文件包含dev,所以无法直接使用文件夹读取,转而通过dict的方式传入读取的路径,\n",
+ "该dict的key也将作为读取的数据集的名称,value即对应的文件路径。\n",
+ "\"\"\"\n",
+ "\n",
+ "print(data_bundle) # 打印 data_bundle 可以查看包含的 DataSet \n",
+ "# data_bundle.get_dataset('train') # 可以获取单个 dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "57ae314d",
+ "metadata": {},
+ "source": [
+ "#### 2. 数据预处理\n",
+ "接下来,我们将演示如何通过fastNLP提供的apply函数方便快捷地进行预处理。我们需要进行的预处理操作有: \n",
+ "(1)使用BertTokenizer将文本转换为index;同时记录每个word被bpe之后第一个bpe的index,用于得到word的hidden state; \n",
+ "(2)使用[Vocabulary](../fastNLP)来将raw_target转换为序号。 "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "96389988",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c3bd41a323c94a41b409d29a5d4079b6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "IOPub message rate exceeded.\n",
+ "The notebook server will temporarily stop sending output\n",
+ "to the client in order to avoid crashing it.\n",
+ "To change this limit, set the config variable\n",
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
+ "\n",
+ "Current values:\n",
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "[10:48:13] INFO Save cache to /remote-home/hyan01/exps/fastNLP/fastN cache_results.py:332\n",
+ " LP/demo/torch_tutorial/caches/c7f74559_cache.pkl. \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[10:48:13]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Save cache to \u001b[35m/remote-home/hyan01/exps/fastNLP/fastN\u001b[0m \u001b]8;id=831330;file://../../fastNLP/core/utils/cache_results.py\u001b\\\u001b[2mcache_results.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=609545;file://../../fastNLP/core/utils/cache_results.py#332\u001b\\\u001b[2m332\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m \u001b[35mLP/demo/torch_tutorial/caches/\u001b[0m\u001b[95mc7f74559_cache.pkl.\u001b[0m \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# fastNLP 中提供了BERT, RoBERTa, GPT, BART 模型,更多的预训练模型请直接使用transformers\n",
+ "from fastNLP.transformers.torch import BertTokenizer\n",
+ "from fastNLP import cache_results, Vocabulary\n",
+ "\n",
+ "# 使用cache_results来装饰函数,会将函数的返回结果缓存到'caches/{param_hash_id}_cache.pkl'路径中(其中{param_hash_id}是根据\n",
+ "# 传递给 process_data 函数参数决定的,因此当函数的参数变化时,会再生成新的缓存文件。如果需要重新生成新的缓存,(a) 可以在调用process_data\n",
+ "# 函数时,额外传入一个_refresh=True的参数; 或者(b)删除相应的缓存文件。此外,保存结果时,cache_results默认还会\n",
+ "# 记录 process_data 函数源码的hash值,当其源码发生了变动,直接读取缓存会发出警告,以防止在修改预处理代码之后,忘记刷新缓存。)\n",
+ "@cache_results('caches/cache.pkl')\n",
+ "def process_data(data_bundle, model_name):\n",
+ " tokenizer = BertTokenizer.from_pretrained(model_name)\n",
+ " def bpe(raw_words):\n",
+ " bpes = [tokenizer.cls_token_id]\n",
+ " first = [0]\n",
+ " first_index = 1 # 记录第一个bpe的位置\n",
+ " for word in raw_words:\n",
+ " bpe = tokenizer.encode(word, add_special_tokens=False)\n",
+ " bpes.extend(bpe)\n",
+ " first.append(first_index)\n",
+ " first_index += len(bpe)\n",
+ " bpes.append(tokenizer.sep_token_id)\n",
+ " first.append(first_index)\n",
+ " return {'input_ids': bpes, 'input_len': len(bpes), 'first': first, 'first_len': len(raw_words)}\n",
+ " # 对data_bundle中每个dataset的每一条数据中的raw_words使用bpe函数,并且将返回的结果加入到每条数据中。\n",
+ " data_bundle.apply_field_more(bpe, field_name='raw_words', num_proc=4)\n",
+ " # 对应我们还有 apply_field() 函数,该函数和 apply_field_more() 的区别在于传入到 apply_field() 中的函数应该返回一个 field 的\n",
+ " # 内容(即不需要用dict包裹了)。此外,我们还提供了 data_bundle.apply() ,传入 apply() 的函数需要支持传入一个Instance对象,\n",
+ " # 更多信息可以参考对应的文档。\n",
+ " \n",
+ " # tag的词表,由于这是词表,所以不需要有padding和unk\n",
+ " tag_vocab = Vocabulary(padding=None, unknown=None)\n",
+ " # 从 train 数据的 raw_target 中获取建立词表\n",
+ " tag_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_target')\n",
+ " # 使用词表将每个 dataset 中的raw_target转为数字,并且将写入到target这个field中\n",
+ " tag_vocab.index_dataset(data_bundle.datasets.values(), field_name='raw_target', new_field_name='target')\n",
+ " \n",
+ " # 可以将 vocabulary 绑定到 data_bundle 上,方便之后使用。\n",
+ " data_bundle.set_vocab(tag_vocab, field_name='target')\n",
+ " \n",
+ " return data_bundle, tokenizer\n",
+ "\n",
+ "data_bundle, tokenizer = process_data(data_bundle, 'bert-base-cased', _refresh=True) # 第一次调用耗时较长,第二次调用则会直接读取缓存的文件\n",
+ "# data_bundle = process_data(data_bundle, 'bert-base-uncased') # 由于参数变化,fastNLP 会再次生成新的缓存文件。 "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "80036fcd",
+ "metadata": {},
+ "source": [
+ "### 3. DataLoader \n",
+ "由于现在的深度学习算法大都基于 mini-batch 进行优化,因此需要将多个 sample 组合成一个 batch 再输入到模型之中。在自然语言处理中,不同的 sample 往往长度不一致,需要进行 padding 操作。在fastNLP中,我们使用 fastNLP.TorchDataLoader 帮助用户快速进行 padding ,我们使用了 !!!fastNLP.Collator!!! 对象来进行 pad ,Collator 会在迭代过程中根据第一个 batch 的数据自动判定每个 field 是否可以进行 pad ,可以通过 Collator.set_pad() 函数修改某个 field 的 pad 行为。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "09494695",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import prepare_dataloader\n",
+ "\n",
+ "# 将 data_bundle 中每个 dataset 取出并构造出相应的 DataLoader 对象。返回的 dls 是一个 dict ,包含了 'train', 'test', 'dev' 三个\n",
+ "# fastNLP.TorchDataLoader 对象。\n",
+ "dls = prepare_dataloader(data_bundle, batch_size=24) \n",
+ "\n",
+ "\n",
+ "# fastNLP 将默认尝试对所有 field 都进行 pad ,如果当前 field 是不可 pad 的类型,则不进行pad;如果是可以 pad 的类型\n",
+ "# 默认使用 0 进行 pad 。\n",
+ "for dl in dls.values():\n",
+ " # 可以通过 set_pad 修改 padding 的行为。\n",
+ " dl.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
+ " # 如果希望忽略某个 field ,可以通过 set_ignore 方法。\n",
+ " dl.set_ignore('raw_target')\n",
+ " dl.set_pad('target', pad_val=-100)\n",
+ "# 另一种设置的方法是,可以在 dls = prepare_dataloader(data_bundle, batch_size=32) 之前直接调用 \n",
+ "# data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id); data_bundle.set_ignore('raw_target')来进行设置。\n",
+ "# DataSet 也支持这两个方法。\n",
+ "# 若此时调用 batch = next(dls['train']),则 batch 是一个 dict ,其中包含了\n",
+ "# 'input_ids': torch.LongTensor([batch_size, max_len])\n",
+ "# 'input_len': torch.LongTensor([batch_size])\n",
+ "# 'first': torch.LongTensor([batch_size, max_len'])\n",
+ "# 'first_len': torch.LongTensor([batch_size])\n",
+ "# 'target': torch.LongTensor([batch_size, max_len'-2])\n",
+ "# 'raw_words': List[List[str]] # 因为无法判断,所以 Collator 不会做任何处理"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3583df6d",
+ "metadata": {},
+ "source": [
+ "### 4. 模型准备\n",
+ "传入给fastNLP的模型,需要有两个特殊的方法``train_step``、``evaluate_step``,前者默认在 fastNLP.Trainer 中进行调用,后者默认在 fastNLP.Evaluator 中调用。如果模型中没有``train_step``方法,则Trainer会直接使用模型的``forward``函数;如果模型没有``evaluate_step``方法,则Evaluator会直接使用模型的``forward``函数。``train_step``方法(或当其不存在时,``forward``方法)的返回值必须为 dict 类型,并且必须包含``loss``这个 key 。\n",
+ "\n",
+ "此外fastNLP会使用形参名匹配的方式进行参数传递,例如以下模型\n",
+ "```python\n",
+ "class Model(nn.Module):\n",
+ " def train_step(self, x, y):\n",
+ " return {'loss': (x-y).abs().mean()}\n",
+ "```\n",
+ "fastNLP将尝试从 DataLoader 返回的 batch(假设包含的 key 为 input_ids, target) 中寻找 'x' 和 'y' 这两个 key ,如果没有找到则会报错。有以下的方法可以解决报错\n",
+ "- 修改 train_step 的参数为(input_ids, target),以保证和 DataLoader 返回的 batch 中的 key 匹配\n",
+ "- 修改 DataLoader 中返回 batch 的 key 的名字为 (x, y)\n",
+ "- 在 Trainer 中传入参数 train_input_mapping={'input_ids': 'x', 'target': 'y'} 将输入进行映射,train_input_mapping 也可以是一个函数,更多 train_input_mapping 的介绍可以参考文档。\n",
+ "\n",
+ "``evaluate_step``也是使用同样的匹配方式,前两条解决方法是一致的,第三种解决方案中,需要在 Evaluator 中传入 evaluate_input_mapping={'input_ids': 'x', 'target': 'y'}。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "f131c1a3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[10:48:21] WARNING Some weights of the model checkpoint at modeling_utils.py:1490\n",
+ " bert-base-uncased were not used when initializing \n",
+ " BertModel: ['cls.predictions.bias', \n",
+ " 'cls.predictions.transform.LayerNorm.weight', \n",
+ " 'cls.seq_relationship.weight', \n",
+ " 'cls.predictions.decoder.weight', \n",
+ " 'cls.predictions.transform.dense.weight', \n",
+ " 'cls.predictions.transform.LayerNorm.bias', \n",
+ " 'cls.predictions.transform.dense.bias', \n",
+ " 'cls.seq_relationship.bias'] \n",
+ " - This IS expected if you are initializing \n",
+ " BertModel from the checkpoint of a model trained \n",
+ " on another task or with another architecture (e.g. \n",
+ " initializing a BertForSequenceClassification model \n",
+ " from a BertForPreTraining model). \n",
+ " - This IS NOT expected if you are initializing \n",
+ " BertModel from the checkpoint of a model that you \n",
+ " expect to be exactly identical (initializing a \n",
+ " BertForSequenceClassification model from a \n",
+ " BertForSequenceClassification model). \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[10:48:21]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m Some weights of the model checkpoint at \u001b]8;id=387614;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=648168;file://../../fastNLP/transformers/torch/modeling_utils.py#1490\u001b\\\u001b[2m1490\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m bert-base-uncased were not used when initializing \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m BertModel: \u001b[1m[\u001b[0m\u001b[32m'cls.predictions.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.decoder.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.bias'\u001b[0m\u001b[1m]\u001b[0m \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m - This IS expected if you are initializing \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model trained \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m on another task or with another architecture \u001b[1m(\u001b[0me.g. \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m initializing a BertForSequenceClassification model \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m from a BertForPreTraining model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m - This IS NOT expected if you are initializing \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model that you \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m expect to be exactly identical \u001b[1m(\u001b[0minitializing a \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m BertForSequenceClassification model from a \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m BertForSequenceClassification model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " INFO All the weights of BertModel were initialized from modeling_utils.py:1507\n",
+ " the model checkpoint at bert-base-uncased. \n",
+ " If your task is similar to the task the model of \n",
+ " the checkpoint was trained on, you can already use \n",
+ " BertModel for predictions without further \n",
+ " training. \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m All the weights of BertModel were initialized from \u001b]8;id=544687;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=934505;file://../../fastNLP/transformers/torch/modeling_utils.py#1507\u001b\\\u001b[2m1507\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m the model checkpoint at bert-base-uncased. \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m If your task is similar to the task the model of \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m the checkpoint was trained on, you can already use \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m BertModel for predictions without further \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m training. \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from torch import nn\n",
+ "from torch.nn.utils.rnn import pad_sequence\n",
+ "from fastNLP.transformers.torch import BertModel\n",
+ "from fastNLP import seq_len_to_mask\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "\n",
+ "class BertNER(nn.Module):\n",
+ " def __init__(self, model_name, num_class, tag_vocab=None):\n",
+ " super().__init__()\n",
+ " self.bert = BertModel.from_pretrained(model_name)\n",
+ " self.mlp = nn.Sequential(nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),\n",
+ " nn.Dropout(0.3),\n",
+ " nn.Linear(self.bert.config.hidden_size, num_class))\n",
+ " self.tag_vocab = tag_vocab # 这里传入 tag_vocab 的目的是为了演示 constrined_decode \n",
+ " if tag_vocab is not None:\n",
+ " self._init_constrained_transition()\n",
+ " \n",
+ " def forward(self, input_ids, input_len, first):\n",
+ " attention_mask = seq_len_to_mask(input_len)\n",
+ " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
+ " last_hidden_state = outputs.last_hidden_state\n",
+ " first = first.unsqueeze(-1).repeat(1, 1, last_hidden_state.size(-1))\n",
+ " first_bpe_state = last_hidden_state.gather(dim=1, index=first)\n",
+ " first_bpe_state = first_bpe_state[:, 1:-1] # 删除 cls 和 sep\n",
+ " \n",
+ " pred = self.mlp(first_bpe_state)\n",
+ " return {'pred': pred}\n",
+ " \n",
+ " def train_step(self, input_ids, input_len, first, target):\n",
+ " pred = self(input_ids, input_len, first)['pred']\n",
+ " loss = F.cross_entropy(pred.transpose(1, 2), target)\n",
+ " return {'loss': loss}\n",
+ " \n",
+ " def evaluate_step(self, input_ids, input_len, first):\n",
+ " pred = self(input_ids, input_len, first)['pred'].argmax(dim=-1)\n",
+ " return {'pred': pred}\n",
+ " \n",
+ " def constrained_decode(self, input_ids, input_len, first, first_len):\n",
+ " # 这个函数在推理时,将保证解码出来的 tag 一定不与前一个 tag 矛盾【例如一定不会出现 B-person 后面接着 I-Location 的情况】\n",
+ " # 本身这个需求可以在 Metric 中实现,这里在模型中实现的目的是为了方便演示:如何在fastNLP中使用不同的评测函数\n",
+ " pred = self(input_ids, input_len, first)['pred']\n",
+ " cons_pred = []\n",
+ " for _pred, _len in zip(pred, first_len):\n",
+ " _pred = _pred[:_len]\n",
+ " tags = [_pred[0].argmax(dim=-1).item()] # 这里就不考虑第一个位置非法的情况了\n",
+ " for i in range(1, _len):\n",
+ " tags.append((_pred[i] + self.transition[tags[-1]]).argmax().item())\n",
+ " cons_pred.append(torch.LongTensor(tags))\n",
+ " cons_pred = pad_sequence(cons_pred, batch_first=True)\n",
+ " return {'pred': cons_pred}\n",
+ " \n",
+ " def _init_constrained_transition(self):\n",
+ " from fastNLP.modules.torch import allowed_transitions\n",
+ " allowed_trans = allowed_transitions(self.tag_vocab)\n",
+ " transition = torch.ones((len(self.tag_vocab), len(self.tag_vocab)))*-100000.0\n",
+ " for s, e in allowed_trans:\n",
+ " transition[s, e] = 0\n",
+ " self.register_buffer('transition', transition)\n",
+ "\n",
+ "model = BertNER('bert-base-uncased', len(data_bundle.get_vocab('target')), data_bundle.get_vocab('target'))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5aeee1e9",
+ "metadata": {},
+ "source": [
+ "### Trainer 的使用\n",
+ "fastNLP 的 Trainer 是用于对模型进行训练的部件。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "f4250f0b",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[10:49:22] INFO Running evaluator sanity check for 2 batches. trainer.py:661\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[10:49:22]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=246773;file://../../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=639347;file://../../fastNLP/core/controllers/trainer.py#661\u001b\\\u001b[2m661\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "+++++++++++++++++++++++++++++ Eval. results on Epoch:1, Batch:0 +++++++++++++++++++++++++++++\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[38;5;41m+++++++++++++++++++++++++++++ \u001b[0m\u001b[1mEval. results on Epoch:\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m, Batch:\u001b[0m\u001b[1;36m0\u001b[0m\u001b[38;5;41m +++++++++++++++++++++++++++++\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#f\": 0.402447,\n",
+ " \"pre#f\": 0.447906,\n",
+ " \"rec#f\": 0.365365\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#f\"\u001b[0m: \u001b[1;36m0.402447\u001b[0m,\n",
+ " \u001b[1;34m\"pre#f\"\u001b[0m: \u001b[1;36m0.447906\u001b[0m,\n",
+ " \u001b[1;34m\"rec#f\"\u001b[0m: \u001b[1;36m0.365365\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "[10:51:15] INFO The best performance for monitor f#f:0.402447 was progress_callback.py:37\n",
+ " achieved in Epoch:1, Global Batch:625. The \n",
+ " evaluation result: \n",
+ " {'f#f': 0.402447, 'pre#f': 0.447906, 'rec#f': \n",
+ " 0.365365} \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[10:51:15]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m The best performance for monitor f#\u001b[1;92mf:0\u001b[0m.\u001b[1;36m402447\u001b[0m was \u001b]8;id=192029;file://../../fastNLP/core/callbacks/progress_callback.py\u001b\\\u001b[2mprogress_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=994998;file://../../fastNLP/core/callbacks/progress_callback.py#37\u001b\\\u001b[2m37\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m achieved in Epoch:\u001b[1;36m1\u001b[0m, Global Batch:\u001b[1;36m625\u001b[0m. The \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m evaluation result: \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m \u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.402447\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.447906\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m \u001b[1;36m0.365365\u001b[0m\u001b[1m}\u001b[0m \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " INFO Loading best model from buffer with f#f: load_best_model_callback.py:115\n",
+ " 0.402447... \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from buffer with f#f: \u001b]8;id=654516;file://../../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96586;file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\u001b\\\u001b[2m115\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m \u001b[1;36m0.402447\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from torch import optim\n",
+ "from fastNLP import Trainer, LoadBestModelCallback, TorchWarmupCallback\n",
+ "from fastNLP import SpanFPreRecMetric\n",
+ "\n",
+ "optimizer = optim.AdamW(model.parameters(), lr=2e-5)\n",
+ "callbacks = [\n",
+ " LoadBestModelCallback(), # 用于在训练结束之后加载性能最好的model的权重\n",
+ " TorchWarmupCallback()\n",
+ "] \n",
+ "\n",
+ "trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer, \n",
+ " evaluate_dataloaders=dls['dev'], \n",
+ " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n",
+ " n_epochs=1, callbacks=callbacks, \n",
+ " # 在评测时将 dataloader 中的 first_len 映射 seq_len, 因为 Accuracy.update 接口需要输入一个名为 seq_len 的参数\n",
+ " evaluate_input_mapping={'first_len': 'seq_len'}, overfit_batches=0,\n",
+ " device=0, monitor='f#f', fp16=False) # fp16 为 True 的话,将使用 float16 进行训练。\n",
+ "trainer.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c600a450",
+ "metadata": {},
+ "source": [
+ "### Evaluator的使用\n",
+ "fastNLP中用于评测数据的对象。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "1b19f0ba",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.390326\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.414741\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[1;36m0.368626\u001b[0m\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from fastNLP import Evaluator\n",
+ "from fastNLP import SpanFPreRecMetric\n",
+ "\n",
+ "evaluator = Evaluator(model=model, dataloaders=dls['test'], \n",
+ " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n",
+ " evaluate_input_mapping={'first_len': 'seq_len'}, \n",
+ " device=0)\n",
+ "evaluator.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "52f87770",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f723fe399df34917875ad74c2542508c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# 如果想评测一下使用 constrained decoding的性能,则可以通过传入 evaluate_fn 指定使用的函数\n",
+ "def input_mapping(x):\n",
+ " x['seq_len'] = x['first_len']\n",
+ " return x\n",
+ "evaluator = Evaluator(model=model, dataloaders=dls['test'], device=0,\n",
+ " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))},\n",
+ " evaluate_fn='constrained_decode',\n",
+ " # 如果将 first_len 重新命名为了 seq_len, 将导致 constrained_decode 的输入缺少 first_len 参数,因此\n",
+ " # 额外重复一下 'first_len': 'first_len',使得这个参数不会消失。\n",
+ " evaluate_input_mapping=input_mapping)\n",
+ "evaluator.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "419e718b",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_0.ipynb b/docs/source/tutorials/fastnlp_tutorial_0.ipynb
new file mode 100644
index 00000000..09667794
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_0.ipynb
@@ -0,0 +1,1352 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "aec0fde7",
+ "metadata": {},
+ "source": [
+ "# T0. trainer 和 evaluator 的基本使用\n",
+ "\n",
+ " 1 trainer 和 evaluator 的基本关系\n",
+ " \n",
+ " 1.1 trainer 和 evaluater 的初始化\n",
+ "\n",
+ " 1.2 driver 的含义与使用要求\n",
+ "\n",
+ " 1.3 trainer 内部初始化 evaluater\n",
+ "\n",
+ " 2 使用 fastNLP 搭建 argmax 模型\n",
+ "\n",
+ " 2.1 trainer_step 和 evaluator_step\n",
+ "\n",
+ " 2.2 trainer 和 evaluator 的参数匹配\n",
+ "\n",
+ " 2.3 示例:argmax 模型的搭建\n",
+ "\n",
+ " 3 使用 fastNLP 训练 argmax 模型\n",
+ " \n",
+ " 3.1 trainer 外部初始化的 evaluator\n",
+ "\n",
+ " 3.2 trainer 内部初始化的 evaluator "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "09ea669a",
+ "metadata": {},
+ "source": [
+ "## 1. trainer 和 evaluator 的基本关系\n",
+ "\n",
+ "### 1.1 trainer 和 evaluator 的初始化\n",
+ "\n",
+ "在`fastNLP 1.0`中,`Trainer`模块和`Evaluator`模块分别表示 **“训练器”和“评测器”**\n",
+ "\n",
+ " 对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n",
+ "\n",
+ "在`fastNLP 1.0`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n",
+ "\n",
+ " 非常关键的问题在于**如何正确设置二者的 driver**。这就引入了另一个问题:什么是 `driver`?\n",
+ "\n",
+ "\n",
+ "```python\n",
+ "trainer = Trainer(\n",
+ " model=model, # 模型基于 torch.nn.Module\n",
+ " train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n",
+ " optimizers=optimizer, # 优化模块基于 torch.optim.*\n",
+ " ...\n",
+ " driver=\"torch\", # 使用 pytorch 模块进行训练 \n",
+ " device='cuda', # 使用 GPU:0 显卡执行训练\n",
+ " ...\n",
+ " )\n",
+ "...\n",
+ "evaluator = Evaluator(\n",
+ " model=model, # 模型基于 torch.nn.Module\n",
+ " dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n",
+ " metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n",
+ " ...\n",
+ " driver=trainer.driver, # 保持同 trainer 的 driver 一致\n",
+ " device=None,\n",
+ " ...\n",
+ " )\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3c11fe1a",
+ "metadata": {},
+ "source": [
+ "### 1.2 driver 的含义与使用要求\n",
+ "\n",
+ "在`fastNLP 1.0`中,**driver**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n",
+ "\n",
+ " 例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n",
+ "\n",
+ "在`fastNLP 1.0`中,**Trainer 和 Evaluator 都依赖于具体的 driver 来完成整体的工作流程**\n",
+ "\n",
+ " 具体`driver`与`Trainer`以及`Evaluator`之间的关系之后`tutorial 4`中的详细介绍\n",
+ "\n",
+ "注:这里给出一条建议:**在同一脚本中**,**所有的** Trainer **和** Evaluator **使用的** driver **应当保持一致**\n",
+ "\n",
+ " 尽量不出现,之前使用单卡的`driver`,后面又使用多卡的`driver`,这是因为,当脚本执行至\n",
+ "\n",
+ " 多卡`driver`处时,会重启一个进程执行之前所有内容,如此一来可能会造成一些意想不到的麻烦"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2cac4a1a",
+ "metadata": {},
+ "source": [
+ "### 1.3 Trainer 内部初始化 Evaluator\n",
+ "\n",
+ "在`fastNLP 1.0`中,如果在**初始化 Trainer 时**,**传入参数 evaluator_dataloaders 和 metrics **\n",
+ "\n",
+ " 则在`Trainer`内部,也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n",
+ "\n",
+ "```python\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " train_dataloader=train_dataloader,\n",
+ " optimizers=optimizer,\n",
+ " ...\n",
+ " driver=\"torch\",\n",
+ " device='cuda',\n",
+ " ...\n",
+ " evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n",
+ " metrics={'acc': Accuracy()}, # 传入参数 metrics\n",
+ " ...\n",
+ " )\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0c9c7dda",
+ "metadata": {},
+ "source": [
+ "## 2. argmax 模型的搭建实例"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "524ac200",
+ "metadata": {},
+ "source": [
+ "### 2.1 trainer_step 和 evaluator_step\n",
+ "\n",
+ "在`fastNLP 1.0`中,使用`pytorch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n",
+ "\n",
+ " 添加`pytorch`要求的`forward`方法外,还需要添加 `train_step` 和 `evaluate_step` 这两个方法\n",
+ "\n",
+ "```python\n",
+ "class Model(torch.nn.Module):\n",
+ " def __init__(self):\n",
+ " super(Model, self).__init__()\n",
+ " self.loss_fn = torch.nn.CrossEntropyLoss()\n",
+ " pass\n",
+ "\n",
+ " def forward(self, x):\n",
+ " pass\n",
+ "\n",
+ " def train_step(self, x, y):\n",
+ " pred = self(x)\n",
+ " return {\"loss\": self.loss_fn(pred, y)}\n",
+ "\n",
+ " def evaluate_step(self, x, y):\n",
+ " pred = self(x)\n",
+ " pred = torch.max(pred, dim=-1)[1]\n",
+ " return {\"pred\": pred, \"target\": y}\n",
+ "```\n",
+ "***\n",
+ "在`fastNLP 1.0`中,**函数 train_step 是 Trainer 中参数 train_fn 的默认值**\n",
+ "\n",
+ " 由于,在`Trainer`训练时,**Trainer 通过参数 train_fn 对应的模型方法获得当前数据批次的损失值**\n",
+ "\n",
+ " 因此,在`Trainer`训练时,`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n",
+ "\n",
+ " 如果没有找到,那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n",
+ "\n",
+ "注:在`fastNLP 1.0`中,**Trainer 要求模型通过 train_step 来返回一个字典**,**满足如 {\"loss\": loss} 的形式**\n",
+ "\n",
+ " 此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现输出的转换,详见(trainer的详细讲解,待补充)\n",
+ "\n",
+ "同样,在`fastNLP 1.0`中,**函数 evaluate_step 是 Evaluator 中参数 evaluate_fn 的默认值**\n",
+ "\n",
+ " 在`Evaluator`测试时,**Evaluator 通过参数 evaluate_fn 对应的模型方法获得当前数据批次的评测结果**\n",
+ "\n",
+ " 从用户角度,模型通过`evaluate_step`方法来返回一个字典,内容与传入`Evaluator`的`metrics`一致\n",
+ "\n",
+ " 从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fb3272eb",
+ "metadata": {},
+ "source": [
+ "### 2.2 trainer 和 evaluator 的参数匹配\n",
+ "\n",
+ "在`fastNLP 1.0`中,参数匹配涉及到两个方面,分别是在\n",
+ "\n",
+ " 一方面,**在模型的前向传播中**,**dataloader 向 train_step 或 evaluate_step 函数传递 batch**\n",
+ "\n",
+ " 另方面,**在模型的评测过程中**,**evaluate_dataloader 向 metric 的 update 函数传递 batch**\n",
+ "\n",
+ "对于前者,在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`False`时\n",
+ "\n",
+ " **fastNLP 1.0 要求 dataloader 生成的每个 batch **,**满足如 {\"x\": x, \"y\": y} 的形式**\n",
+ "\n",
+ " 同时,`fastNLP 1.0`会查看模型的`train_step`和`evaluate_step`方法的参数签名,并为对应参数传入对应数值\n",
+ "\n",
+ " **字典形式的定义**,**对应在 Dataset 定义的 \\_\\_getitem\\_\\_ 方法中**,例如下方的`ArgMaxDatset`\n",
+ "\n",
+ " 而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n",
+ "\n",
+ " `fastNLP 1.0`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n",
+ "\n",
+ "```python\n",
+ "class Dataset(torch.utils.data.Dataset):\n",
+ " def __init__(self, x, y):\n",
+ " self.x = x\n",
+ " self.y = y\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.x)\n",
+ "\n",
+ " def __getitem__(self, item):\n",
+ " return {\"x\": self.x[item], \"y\": self.y[item]}\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f5f1a6aa",
+ "metadata": {},
+ "source": [
+ "对于后者,首先要明确,在`Trainer`和`Evaluator`中,`metrics`的计算分为`update`和`get_metric`两步\n",
+ "\n",
+ " **update 函数**,**针对一个 batch 的预测结果**,计算其累计的评价指标\n",
+ "\n",
+ " **get_metric 函数**,**统计 update 函数累计的评价指标**,来计算最终的评价结果\n",
+ "\n",
+ " 例如对于`Accuracy`来说,`update`函数会更新一个`batch`的正例数量`right_num`和负例数量`total_num`\n",
+ "\n",
+ " 而`get_metric`函数则会返回所有`batch`的评测值`right_num / total_num`\n",
+ "\n",
+ " 在此基础上,**fastNLP 1.0 要求 evaluate_dataloader 生成的每个 batch 传递给对应的 metric**\n",
+ "\n",
+ " **以 {\"pred\": y_pred, \"target\": y_true} 的形式**,对应其`update`函数的函数签名\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f62b7bb1",
+ "metadata": {},
+ "source": [
+ "### 2.3 示例:argmax 模型的搭建\n",
+ "\n",
+ "下文将通过训练`argmax`模型,简单介绍如何`Trainer`模块的使用方式\n",
+ "\n",
+ " 首先,使用`pytorch.nn.Module`定义`argmax`模型,目标是输入一组固定维度的向量,输出其中数值最大的数的索引"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "5314482b",
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "class ArgMaxModel(nn.Module):\n",
+ " def __init__(self, num_labels, feature_dimension):\n",
+ " nn.Module.__init__(self)\n",
+ " self.num_labels = num_labels\n",
+ "\n",
+ " self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n",
+ " self.ac1 = nn.ReLU()\n",
+ " self.linear2 = nn.Linear(in_features=10, out_features=10)\n",
+ " self.ac2 = nn.ReLU()\n",
+ " self.output = nn.Linear(in_features=10, out_features=num_labels)\n",
+ " self.loss_fn = nn.CrossEntropyLoss()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " pred = self.ac1(self.linear1(x))\n",
+ " pred = self.ac2(self.linear2(pred))\n",
+ " pred = self.output(pred)\n",
+ " return pred\n",
+ "\n",
+ " def train_step(self, x, y):\n",
+ " pred = self(x)\n",
+ " return {\"loss\": self.loss_fn(pred, y)}\n",
+ "\n",
+ " def evaluate_step(self, x, y):\n",
+ " pred = self(x)\n",
+ " pred = torch.max(pred, dim=-1)[1]\n",
+ " return {\"pred\": pred, \"target\": y}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "71f3fa6b",
+ "metadata": {},
+ "source": [
+ " 接着,使用`torch.utils.data.Dataset`定义`ArgMaxDataset`数据集\n",
+ "\n",
+ " 数据集包含三个参数:维度`feature_dimension`、数据量`data_num`和随机种子`seed`\n",
+ "\n",
+ " 数据及初始化是,自动生成指定维度的向量,并为每个向量标注出其中最大值的索引作为预测标签"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "fe612e61",
+ "metadata": {
+ "pycharm": {
+ "is_executing": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import Dataset\n",
+ "\n",
+ "class ArgMaxDataset(Dataset):\n",
+ " def __init__(self, feature_dimension, data_num=1000, seed=0):\n",
+ " self.num_labels = feature_dimension\n",
+ " self.feature_dimension = feature_dimension\n",
+ " self.data_num = data_num\n",
+ " self.seed = seed\n",
+ "\n",
+ " g = torch.Generator()\n",
+ " g.manual_seed(1000)\n",
+ " self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n",
+ " self.y = torch.max(self.x, dim=-1)[1]\n",
+ "\n",
+ " def __len__(self):\n",
+ " return self.data_num\n",
+ "\n",
+ " def __getitem__(self, item):\n",
+ " return {\"x\": self.x[item], \"y\": self.y[item]}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2cb96332",
+ "metadata": {},
+ "source": [
+ " 然后,根据`ArgMaxModel`类初始化模型实例,保持输入维度`feature_dimension`和输出标签数量`num_labels`一致\n",
+ "\n",
+ " 再根据`ArgMaxDataset`类初始化两个数据集实例,分别用来模型测试和模型评测,数据量各1000笔"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "76172ef8",
+ "metadata": {
+ "pycharm": {
+ "is_executing": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model = ArgMaxModel(num_labels=10, feature_dimension=10)\n",
+ "\n",
+ "train_dataset = ArgMaxDataset(feature_dimension=10, data_num=1000)\n",
+ "evaluate_dataset = ArgMaxDataset(feature_dimension=10, data_num=100)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4e7d25ee",
+ "metadata": {},
+ "source": [
+ " 此外,使用`torch.utils.data.DataLoader`初始化两个数据加载模块,批量大小同为8,分别用于训练和测评"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "363b5b09",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
+ "evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c8d4443f",
+ "metadata": {},
+ "source": [
+ " 最后,使用`torch.optim.SGD`初始化一个优化模块,基于随机梯度下降法"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "dc28a2d9",
+ "metadata": {
+ "pycharm": {
+ "is_executing": false
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from torch.optim import SGD\n",
+ "\n",
+ "optimizer = SGD(model.parameters(), lr=0.001)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "eb8ca6cf",
+ "metadata": {},
+ "source": [
+ "## 3. 使用 fastNLP 1.0 训练 argmax 模型\n",
+ "\n",
+ "### 3.1 trainer 外部初始化的 evaluator"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "55145553",
+ "metadata": {},
+ "source": [
+ "通过从`fastNLP`库中导入`Trainer`类,初始化`trainer`实例,对模型进行训练\n",
+ "\n",
+ " 需要导入预先定义好的模型`model`、对应的数据加载模块`train_dataloader`、优化模块`optimizer`\n",
+ "\n",
+ " 通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n",
+ "\n",
+ " 但对于`\"auto\"`和`\"rich\"`格式,在`jupyter`中,进度条会在训练结束后会被丢弃\n",
+ "\n",
+ " 通过`n_epochs`设定优化迭代轮数,默认为20;全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "b51b7a2d",
+ "metadata": {
+ "pycharm": {
+ "is_executing": false
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append('..')\n",
+ "\n",
+ "from fastNLP import Trainer\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver=\"torch\",\n",
+ " device='cuda',\n",
+ " train_dataloader=train_dataloader,\n",
+ " optimizers=optimizer,\n",
+ " n_epochs=10, # 设定迭代轮数 \n",
+ " progress_bar=\"auto\" # 设定进度条格式\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6e202d6e",
+ "metadata": {},
+ "source": [
+ "通过使用`Trainer`类的`run`函数,进行训练\n",
+ "\n",
+ " 其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n",
+ "\n",
+ " `run`函数完成后在`jupyter`中没有输出保留,此外,通过`help(trainer.run)`可以查询`run`函数的详细内容"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "ba047ead",
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c16c5fa4",
+ "metadata": {},
+ "source": [
+ "通过从`fastNLP`库中导入`Evaluator`类,初始化`evaluator`实例,对模型进行评测\n",
+ "\n",
+ " 需要导入预先定义好的模型`model`、对应的数据加载模块`evaluate_dataloader`\n",
+ "\n",
+ " 需要注意的是评测方法`metrics`,设定为形如`{'acc': fastNLP.core.metrics.Accuracy()}`的字典\n",
+ "\n",
+ " 类似地,也可以通过`progress_bar`限定进度条格式,默认为`\"auto\"`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "1c6b6b36",
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from fastNLP import Evaluator\n",
+ "from fastNLP import Accuracy\n",
+ "\n",
+ "evaluator = Evaluator(\n",
+ " model=model,\n",
+ " driver=trainer.driver, # 需要使用 trainer 已经启动的 driver\n",
+ " device=None,\n",
+ " dataloaders=evaluate_dataloader,\n",
+ " metrics={'acc': Accuracy()} # 需要严格使用此种形式的字典\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8157bb9b",
+ "metadata": {},
+ "source": [
+ "通过使用`Evaluator`类的`run`函数,进行训练\n",
+ "\n",
+ " 其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n",
+ "\n",
+ " 最终,输出形如`{'acc#acc': acc}`的字典,在`jupyter`中,进度条会在评测结束后会被丢弃"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "f7cb0165",
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.31\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m31.0\u001b[0m\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "evaluator.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dd9f68fa",
+ "metadata": {},
+ "source": [
+ "### 3.2 trainer 内部初始化的 evaluator \n",
+ "\n",
+ "通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n",
+ "\n",
+ " 通过`progress_bar`同时设定训练和评估进度条格式,在`jupyter`中,在进度条训练结束后会被丢弃\n",
+ "\n",
+ " 但是中间的评估结果仍会保留;**通过 evaluate_every 设定评估频率**,可以为负数、正数或者函数:\n",
+ "\n",
+ " **为负数时**,**表示每隔几个 epoch 评估一次**;**为正数时**,**则表示每隔几个 batch 评估一次**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "183c7d19",
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver=trainer.driver, # 因为是在同个脚本中,这里的 driver 同样需要重用\n",
+ " train_dataloader=train_dataloader,\n",
+ " evaluate_dataloaders=evaluate_dataloader,\n",
+ " metrics={'acc': Accuracy()},\n",
+ " optimizers=optimizer,\n",
+ " n_epochs=10, \n",
+ " evaluate_every=-1, # 表示每个 epoch 的结束进行评估\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "714cc404",
+ "metadata": {},
+ "source": [
+ "通过使用`Trainer`类的`run`函数,进行训练\n",
+ "\n",
+ " 还可以通过**参数 num_eval_sanity_batch 决定每次训练前运行多少个 evaluate_batch 进行评测**,**默认为 2 **\n",
+ "\n",
+ " 之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此**试探性评测**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "2e4daa2c",
+ "metadata": {
+ "pycharm": {
+ "is_executing": true
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[18:28:25] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[18:28:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=549287;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=645362;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.31,\n",
+ " \"total#acc\": 100.0,\n",
+ " \"correct#acc\": 31.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.31\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m31.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.33,\n",
+ " \"total#acc\": 100.0,\n",
+ " \"correct#acc\": 33.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.33\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m33.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.34,\n",
+ " \"total#acc\": 100.0,\n",
+ " \"correct#acc\": 34.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.34\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m34.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.36,\n",
+ " \"total#acc\": 100.0,\n",
+ " \"correct#acc\": 36.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.36,\n",
+ " \"total#acc\": 100.0,\n",
+ " \"correct#acc\": 36.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.36,\n",
+ " \"total#acc\": 100.0,\n",
+ " \"correct#acc\": 36.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.36,\n",
+ " \"total#acc\": 100.0,\n",
+ " \"correct#acc\": 36.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.36,\n",
+ " \"total#acc\": 100.0,\n",
+ " \"correct#acc\": 36.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.37,\n",
+ " \"total#acc\": 100.0,\n",
+ " \"correct#acc\": 37.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.37\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m37.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.4,\n",
+ " \"total#acc\": 100.0,\n",
+ " \"correct#acc\": 40.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.4\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m40.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "c4e9c619",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'acc#acc': 0.4, 'total#acc': 100.0, 'correct#acc': 40.0}"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.evaluator.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1bc7cb4a",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ },
+ "pycharm": {
+ "stem_cell": {
+ "cell_type": "raw",
+ "metadata": {
+ "collapsed": false
+ },
+ "source": []
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_1.ipynb b/docs/source/tutorials/fastnlp_tutorial_1.ipynb
new file mode 100644
index 00000000..cff81a21
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_1.ipynb
@@ -0,0 +1,1333 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "cdc25fcd",
+ "metadata": {},
+ "source": [
+ "# T1. dataset 和 vocabulary 的基本使用\n",
+ "\n",
+ " 1 dataset 的使用与结构\n",
+ " \n",
+ " 1.1 dataset 的结构与创建\n",
+ "\n",
+ " 1.2 dataset 的数据预处理\n",
+ "\n",
+ " 1.3 延伸:instance 和 field\n",
+ "\n",
+ " 2 vocabulary 的结构与使用\n",
+ "\n",
+ " 2.1 vocabulary 的创建与修改\n",
+ "\n",
+ " 2.2 vocabulary 与 OOV 问题\n",
+ "\n",
+ " 3 dataset 和 vocabulary 的组合使用\n",
+ " \n",
+ " 3.1 从 dataframe 中加载 dataset\n",
+ "\n",
+ " 3.2 从 dataset 中获取 vocabulary"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0eb18a22",
+ "metadata": {},
+ "source": [
+ "## 1. dataset 的基本使用\n",
+ "\n",
+ "### 1.1 dataset 的结构与创建\n",
+ "\n",
+ "在`fastNLP 1.0`中,使用`DataSet`模块表示数据集,**dataset 类似于关系型数据库中的数据表**(下文统一为小写 `dataset`)\n",
+ "\n",
+ " **主要包含 field 字段和 instance 实例两个元素**,对应 table 中的 field 字段和`record`记录\n",
+ "\n",
+ "在`fastNLP 1.0`中,`DataSet`模块被定义在`fastNLP.core.dataset`路径下,导入该模块后,最简单的\n",
+ "\n",
+ " 初始化方法,即将字典形式的表格 **{'field1': column1, 'field2': column2, ...}** 传入构造函数"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "a1d69ad2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+-----+------------------------+------------------------+-----+\n",
+ "| idx | sentence | words | num |\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n",
+ "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n",
+ "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n",
+ "+-----+------------------------+------------------------+-----+\n"
+ ]
+ }
+ ],
+ "source": [
+ "from fastNLP import DataSet\n",
+ "\n",
+ "data = {'idx': [0, 1, 2], \n",
+ " 'sentence':[\"This is an apple .\", \"I like apples .\", \"Apples are good for our health .\"],\n",
+ " 'words': [['This', 'is', 'an', 'apple', '.'], \n",
+ " ['I', 'like', 'apples', '.'], \n",
+ " ['Apples', 'are', 'good', 'for', 'our', 'health', '.']],\n",
+ " 'num': [5, 4, 7]}\n",
+ "\n",
+ "dataset = DataSet(data)\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9260fdc6",
+ "metadata": {},
+ "source": [
+ " 在`dataset`的实例中,字段`field`的名称和实例`instance`中的字符串也可以中文"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "3d72ef00",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+------+--------------------+------------------------+------+\n",
+ "| 序号 | 句子 | 字符 | 长度 |\n",
+ "+------+--------------------+------------------------+------+\n",
+ "| 0 | 生活就像海洋, | ['生', '活', '就', ... | 7 |\n",
+ "| 1 | 只有意志坚强的人, | ['只', '有', '意', ... | 9 |\n",
+ "| 2 | 才能到达彼岸。 | ['才', '能', '到', ... | 7 |\n",
+ "+------+--------------------+------------------------+------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "temp = {'序号': [0, 1, 2], \n",
+ " '句子':[\"生活就像海洋,\", \"只有意志坚强的人,\", \"才能到达彼岸。\"],\n",
+ " '字符': [['生', '活', '就', '像', '海', '洋', ','], \n",
+ " ['只', '有', '意', '志', '坚', '强', '的', '人', ','], \n",
+ " ['才', '能', '到', '达', '彼', '岸', '。']],\n",
+ " '长度': [7, 9, 7]}\n",
+ "\n",
+ "chinese = DataSet(temp)\n",
+ "print(chinese)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "202e5490",
+ "metadata": {},
+ "source": [
+ "在`dataset`中,使用`drop`方法可以删除满足条件的实例,这里使用了python中的`lambda`表达式\n",
+ "\n",
+ " 注一:在`drop`方法中,通过设置`inplace`参数将删除对应实例后的`dataset`作为一个新的实例生成"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "09b478f8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2492313174344 2491986424200\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| idx | sentence | words | num |\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n",
+ "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| idx | sentence | words | num |\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n",
+ "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n",
+ "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n",
+ "+-----+------------------------+------------------------+-----+\n"
+ ]
+ }
+ ],
+ "source": [
+ "dropped = dataset\n",
+ "dropped = dropped.drop(lambda ins:ins['num'] < 5, inplace=False)\n",
+ "print(id(dropped), id(dataset))\n",
+ "print(dropped)\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "aa277674",
+ "metadata": {},
+ "source": [
+ " 注二:**对对象使用等号一般表示传引用**,所以对`dataset`使用等号,是传引用而不是赋值\n",
+ "\n",
+ " 如下所示,**dropped 和 dataset 具有相同 id**,**对 dropped 执行删除操作 dataset 同时会被修改**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "77c8583a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2491986424200 2491986424200\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| idx | sentence | words | num |\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n",
+ "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| idx | sentence | words | num |\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n",
+ "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n",
+ "+-----+------------------------+------------------------+-----+\n"
+ ]
+ }
+ ],
+ "source": [
+ "dropped = dataset\n",
+ "dropped.drop(lambda ins:ins['num'] < 5)\n",
+ "print(id(dropped), id(dataset))\n",
+ "print(dropped)\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a76199dc",
+ "metadata": {},
+ "source": [
+    "在`dataset`中,使用`delete_instance`方法可以删除对应序号的`instance`实例,序号从0开始"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "d8824b40",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+-----+--------------------+------------------------+-----+\n",
+ "| idx | sentence | words | num |\n",
+ "+-----+--------------------+------------------------+-----+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n",
+ "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n",
+ "+-----+--------------------+------------------------+-----+\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = DataSet(data)\n",
+ "dataset.delete_instance(2)\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f4fa9f33",
+ "metadata": {},
+ "source": [
+    "在`dataset`中,使用`delete_field`方法可以删除对应名称的`field`字段"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "f68ddb40",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+-----+--------------------+------------------------------+\n",
+ "| idx | sentence | words |\n",
+ "+-----+--------------------+------------------------------+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n",
+ "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n",
+ "+-----+--------------------+------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset.delete_field('num')\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1e9d42c",
+ "metadata": {},
+ "source": [
+ "### 1.2 dataset 的数据预处理\n",
+ "\n",
+ "在`dataset`模块中,`apply`、`apply_field`、`apply_more`和`apply_field_more`函数可以进行简单的数据预处理\n",
+ "\n",
+ " **apply 和 apply_more 输入整条实例**,**apply_field 和 apply_field_more 仅输入实例的部分字段**\n",
+ "\n",
+ " **apply 和 apply_field 仅输出单个字段**,**apply_more 和 apply_field_more 则是输出多个字段**\n",
+ "\n",
+ " **apply 和 apply_field 返回的是个列表**,**apply_more 和 apply_field_more 返回的是个字典**\n",
+ "\n",
+ " 预处理过程中,通过`progress_bar`参数设置显示进度条类型,通过`num_proc`设置多进程\n",
+ "***\n",
+ "\n",
+ "`apply`的参数包括一个函数`func`和一个新字段名`new_field_name`,函数`func`的处理对象是`dataset`模块中\n",
+ "\n",
+ " 的每个`instance`实例,函数`func`的处理结果存放在`new_field_name`对应的新建字段内"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "72a0b5f9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+-----+------------------------------+------------------------------+\n",
+ "| idx | sentence | words |\n",
+ "+-----+------------------------------+------------------------------+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n",
+ "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n",
+ "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... |\n",
+ "+-----+------------------------------+------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "from fastNLP import DataSet\n",
+ "\n",
+ "data = {'idx': [0, 1, 2], \n",
+ " 'sentence':[\"This is an apple .\", \"I like apples .\", \"Apples are good for our health .\"], }\n",
+ "dataset = DataSet(data)\n",
+ "dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words', progress_bar=\"tqdm\") #\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c10275ee",
+ "metadata": {},
+ "source": [
+ " **apply 使用的函数可以是一个基于 lambda 表达式的匿名函数**,**也可以是一个自定义的函数**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "b1a8631f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+-----+------------------------------+------------------------------+\n",
+ "| idx | sentence | words |\n",
+ "+-----+------------------------------+------------------------------+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n",
+ "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n",
+ "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... |\n",
+ "+-----+------------------------------+------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = DataSet(data)\n",
+ "\n",
+ "def get_words(instance):\n",
+ " sentence = instance['sentence']\n",
+ " words = sentence.split()\n",
+ " return words\n",
+ "\n",
+ "dataset.apply(get_words, new_field_name='words', progress_bar=\"tqdm\")\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "64abf745",
+ "metadata": {},
+ "source": [
+ "`apply_field`的参数,除了函数`func`外还有`field_name`和`new_field_name`,该函数`func`的处理对象仅\n",
+ "\n",
+ " 是`dataset`模块中的每个`field_name`对应的字段内容,处理结果存放在`new_field_name`对应的新建字段内"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "057c1d2c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+-----+------------------------------+------------------------------+\n",
+ "| idx | sentence | words |\n",
+ "+-----+------------------------------+------------------------------+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n",
+ "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n",
+ "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... |\n",
+ "+-----+------------------------------+------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = DataSet(data)\n",
+ "dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words', \n",
+ " progress_bar=\"tqdm\")\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5a9cc8b2",
+ "metadata": {},
+ "source": [
+ "`apply_more`的参数只有函数`func`,函数`func`的处理对象是`dataset`模块中的每个`instance`实例\n",
+ "\n",
+ " 要求函数`func`返回一个字典,根据字典的`key-value`确定存储在`dataset`中的字段名称与内容"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "51e2f02c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+-----+------------------------+------------------------+-----+\n",
+ "| idx | sentence | words | num |\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n",
+ "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n",
+ "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n",
+ "+-----+------------------------+------------------------+-----+\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = DataSet(data)\n",
+ "dataset.apply_more(lambda ins:{'words': ins['sentence'].split(), 'num': len(ins['sentence'].split())}, \n",
+ " progress_bar=\"tqdm\")\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "02d2b7ef",
+ "metadata": {},
+ "source": [
+    "`apply_field_more`的参数,除了函数`func`外还有`field_name`,函数`func`的处理对象仅是每个`field_name`对应的字段内容\n",
+    "\n",
+    "    要求函数`func`返回一个字典,根据字典的`key-value`确定存储在`dataset`中的字段名称与内容"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "db4295d5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+-----+------------------------+------------------------+-----+\n",
+ "| idx | sentence | words | num |\n",
+ "+-----+------------------------+------------------------+-----+\n",
+ "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n",
+ "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n",
+ "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n",
+ "+-----+------------------------+------------------------+-----+\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = DataSet(data)\n",
+ "dataset.apply_field_more(lambda sent:{'words': sent.split(), 'num': len(sent.split())}, \n",
+ " field_name='sentence', progress_bar=\"tqdm\")\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9c09e592",
+ "metadata": {},
+ "source": [
+ "### 1.3 延伸:instance 和 field\n",
+ "\n",
+ "在`fastNLP 1.0`中,使用`Instance`模块表示数据集`dataset`中的每条数据,被称为实例\n",
+ "\n",
+ " 构造方式类似于构造一个字典,通过键值相同的`Instance`列表,也可以初始化一个`dataset`,代码如下"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "012f537c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import DataSet\n",
+ "from fastNLP import Instance\n",
+ "\n",
+ "dataset = DataSet([\n",
+ " Instance(sentence=\"This is an apple .\",\n",
+ " words=['This', 'is', 'an', 'apple', '.'],\n",
+ " num=5),\n",
+ " Instance(sentence=\"I like apples .\",\n",
+ " words=['I', 'like', 'apples', '.'],\n",
+ " num=4),\n",
+ " Instance(sentence=\"Apples are good for our health .\",\n",
+ " words=['Apples', 'are', 'good', 'for', 'our', 'health', '.'],\n",
+ " num=7),\n",
+ " ])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2fafb1ef",
+ "metadata": {},
+ "source": [
+ " 通过`items`、`keys`和`values`方法,可以分别获得`dataset`的`item`列表、`key`列表、`value`列表"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "a4c1c10d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "dict_items([('sentence', 'This is an apple .'), ('words', ['This', 'is', 'an', 'apple', '.']), ('num', 5)])\n",
+ "dict_keys(['sentence', 'words', 'num'])\n",
+ "dict_values(['This is an apple .', ['This', 'is', 'an', 'apple', '.'], 5])\n"
+ ]
+ }
+ ],
+ "source": [
+ "ins = Instance(sentence=\"This is an apple .\", words=['This', 'is', 'an', 'apple', '.'], num=5)\n",
+ "\n",
+ "print(ins.items())\n",
+ "print(ins.keys())\n",
+ "print(ins.values())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b5459a2d",
+ "metadata": {},
+ "source": [
+ " 通过`add_field`方法,可以在`Instance`实例中,通过参数`field_name`添加字段,通过参数`field`赋值"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "55376402",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+--------------------+------------------------+-----+-----+\n",
+ "| sentence | words | num | idx |\n",
+ "+--------------------+------------------------+-----+-----+\n",
+ "| This is an apple . | ['This', 'is', 'an'... | 5 | 0 |\n",
+ "+--------------------+------------------------+-----+-----+\n"
+ ]
+ }
+ ],
+ "source": [
+ "ins.add_field(field_name='idx', field=0)\n",
+ "print(ins)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "49caaa9c",
+ "metadata": {},
+ "source": [
+ "在`fastNLP 1.0`中,使用`FieldArray`模块表示数据集`dataset`中的每条字段名(注:没有`field`类)\n",
+ "\n",
+ " 通过`get_all_fields`方法可以获取`dataset`的字段列表\n",
+ "\n",
+ " 通过`get_field_names`方法可以获取`dataset`的字段名称列表,代码如下"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "fe15f4c1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'sentence': ,\n",
+ " 'words': ,\n",
+ " 'num': }"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset.get_all_fields()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "5433815c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['num', 'sentence', 'words']"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset.get_field_names()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4964eeed",
+ "metadata": {},
+ "source": [
+ "其他`dataset`的基本使用:通过`in`或者`has_field`方法可以判断`dataset`的是否包含某种字段\n",
+ "\n",
+ " 通过`rename_field`方法可以更改`dataset`中的字段名称;通过`concat`方法可以实现两个`dataset`中的拼接\n",
+ "\n",
+ " 通过`len`可以统计`dataset`中的实例数目;`dataset`的全部变量与函数可以通过`dir(dataset)`查询"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "25ce5488",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3 False\n",
+ "6 True\n",
+ "+------------------------------+------------------------------+--------+\n",
+ "| sentence | words | length |\n",
+ "+------------------------------+------------------------------+--------+\n",
+ "| This is an apple . | ['This', 'is', 'an', 'app... | 5 |\n",
+ "| I like apples . | ['I', 'like', 'apples', '... | 4 |\n",
+ "| Apples are good for our h... | ['Apples', 'are', 'good',... | 7 |\n",
+ "| This is an apple . | ['This', 'is', 'an', 'app... | 5 |\n",
+ "| I like apples . | ['I', 'like', 'apples', '... | 4 |\n",
+ "| Apples are good for our h... | ['Apples', 'are', 'good',... | 7 |\n",
+ "+------------------------------+------------------------------+--------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(len(dataset), dataset.has_field('length')) \n",
+ "if 'num' in dataset:\n",
+ " dataset.rename_field('num', 'length')\n",
+ "elif 'length' in dataset:\n",
+ " dataset.rename_field('length', 'num')\n",
+ "dataset.concat(dataset)\n",
+ "print(len(dataset), dataset.has_field('length')) \n",
+ "print(dataset) "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e30a6cd7",
+ "metadata": {},
+ "source": [
+ "## 2. vocabulary 的结构与使用\n",
+ "\n",
+ "### 2.1 vocabulary 的创建与修改\n",
+ "\n",
+ "在`fastNLP 1.0`中,使用`Vocabulary`模块表示词汇表,**vocabulary 的核心是从单词到序号的映射**\n",
+ "\n",
+ " 可以直接通过构造函数实例化,通过查找`word2idx`属性,可以找到`vocabulary`映射对应的字典实现\n",
+ "\n",
+    " **默认补零 padding 用 \<pad\> 表示**,**对应序号为0**;**未知单词 unknown 用 \<unk\> 表示**,**对应序号1**\n",
+ "\n",
+ " 通过打印`vocabulary`可以看到词汇表中的单词列表,其中,`padding`和`unknown`不会显示"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "3515e096",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Vocabulary([]...)\n",
+ "{'': 0, '': 1}\n",
+ " 0\n",
+ " 1\n"
+ ]
+ }
+ ],
+ "source": [
+ "from fastNLP import Vocabulary\n",
+ "\n",
+ "vocab = Vocabulary()\n",
+ "print(vocab)\n",
+ "print(vocab.word2idx)\n",
+ "print(vocab.padding, vocab.padding_idx)\n",
+ "print(vocab.unknown, vocab.unknown_idx)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "640be126",
+ "metadata": {},
+ "source": [
+ "在`vocabulary`中,通过`add_word`方法或`add_word_lst`方法,可以单独或批量添加单词\n",
+ "\n",
+ " 通过`len`或`word_count`属性,可以显示`vocabulary`的单词量和每个单词添加的次数"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "88c7472a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5 Counter({'生活': 1, '就像': 1, '海洋': 1})\n",
+ "6 Counter({'生活': 1, '就像': 1, '海洋': 1, '只有': 1})\n",
+ "6 {'': 0, '': 1, '生活': 2, '就像': 3, '海洋': 4, '只有': 5}\n"
+ ]
+ }
+ ],
+ "source": [
+ "vocab.add_word_lst(['生活', '就像', '海洋'])\n",
+ "print(len(vocab), vocab.word_count)\n",
+ "vocab.add_word('只有')\n",
+ "print(len(vocab), vocab.word_count)\n",
+ "print(len(vocab), vocab.word2idx)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f9ec8b28",
+ "metadata": {},
+ "source": [
+    " **通过 to_word 方法可以找到序号对应的单词**,**通过 to_index 方法可以找到单词对应的序号**\n",
+ "\n",
+ " 由于序号0和序号1已经被占用,所以**新加入的词的序号从2开始计数**,如`'生活'`对应2\n",
+ "\n",
+    " 通过`has_word`方法可以判断单词是否在词汇表中,没有的单词被判做`<unk>`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "3447acde",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " 0\n",
+ " 1\n",
+ "生活 2\n",
+ "彼岸 1 False\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(vocab.to_word(0), vocab.to_index(''))\n",
+ "print(vocab.to_word(1), vocab.to_index(''))\n",
+ "print(vocab.to_word(2), vocab.to_index('生活'))\n",
+ "print('彼岸', vocab.to_index('彼岸'), vocab.has_word('彼岸'))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b4e36850",
+ "metadata": {},
+ "source": [
+ "**vocabulary 允许反复添加相同单词**,**可以通过 word_count 方法看到相应单词被添加的次数**\n",
+ "\n",
+    " 但其中没有`<pad>`和`<unk>`,`vocabulary`的全部变量与函数可以通过`dir(vocabulary)`查询\n",
+ "\n",
+ " 注:**使用 add_word_lst 添加单词**,**单词对应序号不会动态调整**,**使用 dataset 添加单词的情况不同**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "490b101c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "生活 2\n",
+ "彼岸 12 True\n",
+ "13 Counter({'人': 4, '生活': 2, '就像': 2, '海洋': 2, '只有': 2, '意志': 1, '坚强的': 1, '才': 1, '能': 1, '到达': 1, '彼岸': 1})\n",
+ "13 {'': 0, '': 1, '生活': 2, '就像': 3, '海洋': 4, '只有': 5, '人': 6, '意志': 7, '坚强的': 8, '才': 9, '能': 10, '到达': 11, '彼岸': 12}\n"
+ ]
+ }
+ ],
+ "source": [
+ "vocab.add_word_lst(['生活', '就像', '海洋', '只有', '意志', '坚强的', '人', '人', '人', '人', '才', '能', '到达', '彼岸'])\n",
+ "print(vocab.to_word(2), vocab.to_index('生活'))\n",
+ "print('彼岸', vocab.to_index('彼岸'), vocab.has_word('彼岸'))\n",
+ "print(len(vocab), vocab.word_count)\n",
+ "print(len(vocab), vocab.word2idx)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23e32a63",
+ "metadata": {},
+ "source": [
+ "### 2.2 vocabulary 与 OOV 问题\n",
+ "\n",
+ "在`vocabulary`模块初始化的时候,可以通过指定`unknown`和`padding`为`None`,限制其存在\n",
+ "\n",
+ " 此时添加单词直接从0开始标号,如果遇到未知单词会直接报错,即 out of vocabulary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "a99ff909",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'positive': 0, 'negative': 1}\n",
+ "ValueError: word `neutral` not in vocabulary\n"
+ ]
+ }
+ ],
+ "source": [
+ "vocab = Vocabulary(unknown=None, padding=None)\n",
+ "\n",
+ "vocab.add_word_lst(['positive', 'negative'])\n",
+ "print(vocab.word2idx)\n",
+ "\n",
+ "try:\n",
+ " print(vocab.to_index('neutral'))\n",
+ "except ValueError:\n",
+ " print(\"ValueError: word `neutral` not in vocabulary\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "618da6bd",
+ "metadata": {},
+ "source": [
+ " 相应的,如果只指定其中的`unknown`,则编号会后移一个,同时遇到未知单词全部当做``"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "432f74c1",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'': 0, 'positive': 1, 'negative': 2}\n",
+ "0 \n"
+ ]
+ }
+ ],
+ "source": [
+ "vocab = Vocabulary(unknown='', padding=None)\n",
+ "\n",
+ "vocab.add_word_lst(['positive', 'negative'])\n",
+ "print(vocab.word2idx)\n",
+ "\n",
+ "print(vocab.to_index('neutral'), vocab.to_word(vocab.to_index('neutral')))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b6263f73",
+ "metadata": {},
+ "source": [
+ "## 3 dataset 和 vocabulary 的组合使用\n",
+ " \n",
+ "### 3.1 从 dataframe 中加载 dataset\n",
+ "\n",
+ "以下通过 [NLP-beginner](https://github.com/FudanNLP/nlp-beginner) 实践一中 [Rotten Tomatoes 影评数据集](https://www.kaggle.com/c/sentiment-analysis-on-movie-reviews) 的部分训练数据组成`test4dataset.tsv`文件\n",
+ "\n",
+ " 介绍如何使用`dataset`、`vocabulary`简单加载并处理数据集,首先使用`pandas`模块,读取原始数据的`dataframe`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "3dbd985d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " SentenceId | \n",
+ " Sentence | \n",
+ " Sentiment | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " A series of escapades demonstrating the adage ... | \n",
+ " negative | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " This quiet , introspective and entertaining in... | \n",
+ " positive | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 3 | \n",
+ " Even fans of Ismail Merchant 's work , I suspe... | \n",
+ " negative | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " A positively thrilling combination of ethnogra... | \n",
+ " neutral | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 5 | \n",
+ " A comedy-drama of nearly epic proportions root... | \n",
+ " positive | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 6 | \n",
+ " The Importance of Being Earnest , so thick wit... | \n",
+ " neutral | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " SentenceId Sentence Sentiment\n",
+ "0 1 A series of escapades demonstrating the adage ... negative\n",
+ "1 2 This quiet , introspective and entertaining in... positive\n",
+ "2 3 Even fans of Ismail Merchant 's work , I suspe... negative\n",
+ "3 4 A positively thrilling combination of ethnogra... neutral\n",
+ "4 5 A comedy-drama of nearly epic proportions root... positive\n",
+ "5 6 The Importance of Being Earnest , so thick wit... neutral"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "df = pd.read_csv('./data/test4dataset.tsv', sep='\\t')\n",
+ "df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "919ab350",
+ "metadata": {},
+ "source": [
+ "接着,通过`dataset`中的`from_pandas`方法填充数据集,并使用`apply_more`方法对文本进行分词操作"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "4f634586",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/6 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+------------+------------------------------+-----------+\n",
+ "| SentenceId | Sentence | Sentiment |\n",
+ "+------------+------------------------------+-----------+\n",
+ "| 1 | ['a', 'series', 'of', 'es... | negative |\n",
+ "| 2 | ['this', 'quiet', ',', 'i... | positive |\n",
+ "| 3 | ['even', 'fans', 'of', 'i... | negative |\n",
+ "| 4 | ['a', 'positively', 'thri... | neutral |\n",
+ "| 5 | ['a', 'comedy-drama', 'of... | positive |\n",
+ "| 6 | ['the', 'importance', 'of... | neutral |\n",
+ "+------------+------------------------------+-----------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "from fastNLP import DataSet\n",
+ "\n",
+ "dataset = DataSet()\n",
+ "dataset = dataset.from_pandas(df)\n",
+ "dataset.apply_more(lambda ins:{'SentenceId': ins['SentenceId'], \n",
+ " 'Sentence': ins['Sentence'].lower().split(), 'Sentiment': ins['Sentiment']}, \n",
+ " progress_bar=\"tqdm\")\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5c1ae192",
+ "metadata": {},
+ "source": [
+ " 如果需要保存中间结果,也可以使用`dataset`的`to_csv`方法,生成`.csv`或`.tsv`文件"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "46722efc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset.to_csv('./data/test4dataset.csv')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5ba13989",
+ "metadata": {},
+ "source": [
+ "### 3.2 从 dataset 中获取 vocabulary\n",
+ "\n",
+ "然后,初始化`vocabulary`,使用`vocabulary`中的`from_dataset`方法,从`dataset`的指定字段中\n",
+ "\n",
+ " 获取字段中的所有元素,然后编号;如果指定字段是个列表,则针对字段中所有列表包含的元素编号\n",
+ "\n",
+ " 注:**使用 dataset 添加单词**,**不同于 add_word_list**,**单词被添加次数越多**,**序号越靠前**,例如案例中的`a`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "a2de615b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Counter({'a': 9, 'of': 9, ',': 7, 'the': 6, '.': 5, 'is': 3, 'and': 3, 'good': 2, 'for': 2, 'which': 2, 'this': 2, \"'s\": 2, 'series': 1, 'escapades': 1, 'demonstrating': 1, 'adage': 1, 'that': 1, 'what': 1, 'goose': 1, 'also': 1, 'gander': 1, 'some': 1, 'occasionally': 1, 'amuses': 1, 'but': 1, 'none': 1, 'amounts': 1, 'to': 1, 'much': 1, 'story': 1, 'quiet': 1, 'introspective': 1, 'entertaining': 1, 'independent': 1, 'worth': 1, 'seeking': 1, 'even': 1, 'fans': 1, 'ismail': 1, 'merchant': 1, 'work': 1, 'i': 1, 'suspect': 1, 'would': 1, 'have': 1, 'hard': 1, 'time': 1, 'sitting': 1, 'through': 1, 'one': 1, 'positively': 1, 'thrilling': 1, 'combination': 1, 'ethnography': 1, 'all': 1, 'intrigue': 1, 'betrayal': 1, 'deceit': 1, 'murder': 1, 'shakespearean': 1, 'tragedy': 1, 'or': 1, 'juicy': 1, 'soap': 1, 'opera': 1, 'comedy-drama': 1, 'nearly': 1, 'epic': 1, 'proportions': 1, 'rooted': 1, 'in': 1, 'sincere': 1, 'performance': 1, 'by': 1, 'title': 1, 'character': 1, 'undergoing': 1, 'midlife': 1, 'crisis': 1, 'importance': 1, 'being': 1, 'earnest': 1, 'so': 1, 'thick': 1, 'with': 1, 'wit': 1, 'it': 1, 'plays': 1, 'like': 1, 'reading': 1, 'from': 1, 'bartlett': 1, 'familiar': 1, 'quotations': 1}) \n",
+ "\n",
+ "{'': 0, '': 1, 'a': 2, 'of': 3, ',': 4, 'the': 5, '.': 6, 'is': 7, 'and': 8, 'good': 9, 'for': 10, 'which': 11, 'this': 12, \"'s\": 13, 'series': 14, 'escapades': 15, 'demonstrating': 16, 'adage': 17, 'that': 18, 'what': 19, 'goose': 20, 'also': 21, 'gander': 22, 'some': 23, 'occasionally': 24, 'amuses': 25, 'but': 26, 'none': 27, 'amounts': 28, 'to': 29, 'much': 30, 'story': 31, 'quiet': 32, 'introspective': 33, 'entertaining': 34, 'independent': 35, 'worth': 36, 'seeking': 37, 'even': 38, 'fans': 39, 'ismail': 40, 'merchant': 41, 'work': 42, 'i': 43, 'suspect': 44, 'would': 45, 'have': 46, 'hard': 47, 'time': 48, 'sitting': 49, 'through': 50, 'one': 51, 'positively': 52, 'thrilling': 53, 'combination': 54, 'ethnography': 55, 'all': 56, 'intrigue': 57, 'betrayal': 58, 'deceit': 59, 'murder': 60, 'shakespearean': 61, 'tragedy': 62, 'or': 63, 'juicy': 64, 'soap': 65, 'opera': 66, 'comedy-drama': 67, 'nearly': 68, 'epic': 69, 'proportions': 70, 'rooted': 71, 'in': 72, 'sincere': 73, 'performance': 74, 'by': 75, 'title': 76, 'character': 77, 'undergoing': 78, 'midlife': 79, 'crisis': 80, 'importance': 81, 'being': 82, 'earnest': 83, 'so': 84, 'thick': 85, 'with': 86, 'wit': 87, 'it': 88, 'plays': 89, 'like': 90, 'reading': 91, 'from': 92, 'bartlett': 93, 'familiar': 94, 'quotations': 95} \n",
+ "\n",
+ "Vocabulary(['a', 'series', 'of', 'escapades', 'demonstrating']...)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from fastNLP import Vocabulary\n",
+ "\n",
+ "vocab = Vocabulary()\n",
+ "vocab = vocab.from_dataset(dataset, field_name='Sentence')\n",
+ "print(vocab.word_count, '\\n')\n",
+ "print(vocab.word2idx, '\\n')\n",
+ "print(vocab)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f0857ccb",
+ "metadata": {},
+ "source": [
+ "之后,**通过 vocabulary 的 index_dataset 方法**,**调整 dataset 中指定字段的元素**,**使用编号将之代替**\n",
+ "\n",
+ " 使用上述方法,可以将影评数据集中的单词序列转化为词编号序列,为接下来转化为词嵌入序列做准备"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "2f9a04b2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+------------+------------------------------+-----------+\n",
+ "| SentenceId | Sentence | Sentiment |\n",
+ "+------------+------------------------------+-----------+\n",
+ "| 1 | [2, 14, 3, 15, 16, 5, 17,... | negative |\n",
+ "| 2 | [12, 32, 4, 33, 8, 34, 35... | positive |\n",
+ "| 3 | [38, 39, 3, 40, 41, 13, 4... | negative |\n",
+ "| 4 | [2, 52, 53, 54, 3, 55, 8,... | neutral |\n",
+ "| 5 | [2, 67, 3, 68, 69, 70, 71... | positive |\n",
+ "| 6 | [5, 81, 3, 82, 83, 4, 84,... | neutral |\n",
+ "+------------+------------------------------+-----------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "vocab.index_dataset(dataset, field_name='Sentence')\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6b26b707",
+ "metadata": {},
+ "source": [
+ "最后,使用相同方法,再将`dataset`中`Sentiment`字段中的`negative`、`neutral`、`positive`转化为数字编号"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "5f5eed18",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'negative': 0, 'positive': 1, 'neutral': 2}\n",
+ "+------------+------------------------------+-----------+\n",
+ "| SentenceId | Sentence | Sentiment |\n",
+ "+------------+------------------------------+-----------+\n",
+ "| 1 | [2, 14, 3, 15, 16, 5, 17,... | 0 |\n",
+ "| 2 | [12, 32, 4, 33, 8, 34, 35... | 1 |\n",
+ "| 3 | [38, 39, 3, 40, 41, 13, 4... | 0 |\n",
+ "| 4 | [2, 52, 53, 54, 3, 55, 8,... | 2 |\n",
+ "| 5 | [2, 67, 3, 68, 69, 70, 71... | 1 |\n",
+ "| 6 | [5, 81, 3, 82, 83, 4, 84,... | 2 |\n",
+ "+------------+------------------------------+-----------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "target_vocab = Vocabulary(padding=None, unknown=None)\n",
+ "\n",
+ "target_vocab.from_dataset(dataset, field_name='Sentiment')\n",
+ "print(target_vocab.word2idx)\n",
+ "target_vocab.index_dataset(dataset, field_name='Sentiment')\n",
+ "print(dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "eed7ea64",
+ "metadata": {},
+ "source": [
+ "在最后的最后,通过以下的一张图,来总结本章关于`dataset`和`vocabulary`主要知识点的讲解,以及两者的联系\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "35b4f0f7",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_2.ipynb b/docs/source/tutorials/fastnlp_tutorial_2.ipynb
new file mode 100644
index 00000000..546e471d
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_2.ipynb
@@ -0,0 +1,884 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# T2. databundle 和 tokenizer 的基本使用\n",
+ "\n",
+ " 1 fastNLP 中 dataset 的延伸\n",
+ "\n",
+ " 1.1 databundle 的概念与使用\n",
+ "\n",
+ " 2 fastNLP 中的 tokenizer\n",
+ " \n",
+ " 2.1 PreTrainedTokenizer 的概念\n",
+ "\n",
+ " 2.2 BertTokenizer 的基本使用\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. fastNLP 中 dataset 的延伸\n",
+ "\n",
+ "### 1.1 databundle 的概念与使用\n",
+ "\n",
+ "在`fastNLP 1.0`中,在常用的数据加载模块`DataLoader`和数据集`DataSet`模块之间,还存在\n",
+ "\n",
+ " 一个中间模块,即 **数据包 DataBundle 模块**,可以从`fastNLP.io`路径中导入该模块\n",
+ "\n",
+ "在`fastNLP 1.0`中,**一个 databundle 数据包包含若干 dataset 数据集和 vocabulary 词汇表**\n",
+ "\n",
+ " 分别存储在`datasets`和`vocabs`两个变量中,所以了解`databundle`数据包之前\n",
+ "\n",
+ "需要首先**复习 dataset 数据集和 vocabulary 词汇表**,**下面的一串代码**,**你知道其大概含义吗?**\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/6 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+------------------------------------------+----------+\n",
+ "| text | label |\n",
+ "+------------------------------------------+----------+\n",
+ "| ['a', 'series', 'of', 'escapades', 'd... | negative |\n",
+ "| ['this', 'quiet', ',', 'introspective... | positive |\n",
+ "| ['even', 'fans', 'of', 'ismail', 'mer... | negative |\n",
+ "| ['the', 'importance', 'of', 'being', ... | neutral |\n",
+ "+------------------------------------------+----------+\n",
+ "+------------------------------------------+----------+\n",
+ "| text | label |\n",
+ "+------------------------------------------+----------+\n",
+ "| ['a', 'comedy-drama', 'of', 'nearly',... | positive |\n",
+ "| ['a', 'positively', 'thrilling', 'com... | neutral |\n",
+ "+------------------------------------------+----------+\n",
+ "{'': 0, '': 1, 'negative': 2, 'positive': 3, 'neutral': 4}\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "from fastNLP import DataSet\n",
+ "from fastNLP import Vocabulary\n",
+ "from fastNLP.io import DataBundle\n",
+ "\n",
+ "datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n",
+ "datasets.rename_field('Sentence', 'text')\n",
+ "datasets.rename_field('Sentiment', 'label')\n",
+ "datasets.apply_more(lambda ins:{'label': ins['label'].lower(), \n",
+ " 'text': ins['text'].lower().split()},\n",
+ " progress_bar='tqdm')\n",
+ "datasets.delete_field('SentenceId')\n",
+ "train_ds, test_ds = datasets.split(ratio=0.7)\n",
+ "datasets = {'train': train_ds, 'test': test_ds}\n",
+ "print(datasets['train'])\n",
+ "print(datasets['test'])\n",
+ "\n",
+ "vocabs = {}\n",
+ "vocabs['label'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='label')\n",
+ "vocabs['text'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='text')\n",
+ "print(vocabs['label'].word2idx)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "上述代码的含义是:从`test4dataset`的 6 条数据中,划分 4 条训练集(`int(6*0.7) = 4`),2 条测试集\n",
+ "\n",
+ " 修改相关字段名称,删除序号字段,同时将标签都设为小写,对文本进行分词\n",
+ "\n",
+ " 接着通过`concat`方法拼接测试集训练集,注意设置`inplace=False`,生成临时的新数据集\n",
+ "\n",
+ " 使用`from_dataset`方法从拼接的数据集中抽取词汇表,为将数据集中的单词替换为序号做准备\n",
+ "\n",
+ "由此就可以得到**数据集字典 datasets**(**对应训练集、测试集**)和**词汇表字典 vocabs**(**对应数据集各字段**)\n",
+ "\n",
+ " 然后就可以初始化`databundle`了,通过`print`可以观察其大致结构,效果如下"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "In total 2 datasets:\n",
+ "\ttrain has 4 instances.\n",
+ "\ttest has 2 instances.\n",
+ "In total 2 vocabs:\n",
+ "\tlabel has 5 entries.\n",
+ "\ttext has 96 entries.\n",
+ "\n",
+ "['train', 'test']\n",
+ "['label', 'text']\n"
+ ]
+ }
+ ],
+ "source": [
+ "data_bundle = DataBundle(datasets=datasets, vocabs=vocabs)\n",
+ "print(data_bundle)\n",
+ "print(data_bundle.get_dataset_names())\n",
+ "print(data_bundle.get_vocab_names())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "此外,也可以通过`data_bundle`的`num_dataset`和`num_vocab`返回数据表和词汇表个数\n",
+ "\n",
+ " 通过`data_bundle`的`iter_datasets`和`iter_vocabs`遍历数据表和词汇表"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "In total 2 datasets:\n",
+ "\ttrain has 4 instances.\n",
+ "\ttest has 2 instances.\n",
+ "In total 2 datasets:\n",
+ "\tlabel has 5 entries.\n",
+ "\ttext has 96 entries.\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"In total %d datasets:\" % data_bundle.num_dataset)\n",
+ "for name, dataset in data_bundle.iter_datasets():\n",
+ " print(\"\\t%s has %d instances.\" % (name, len(dataset)))\n",
+ "print(\"In total %d datasets:\" % data_bundle.num_dataset)\n",
+ "for name, vocab in data_bundle.iter_vocabs():\n",
+ " print(\"\\t%s has %d entries.\" % (name, len(vocab)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "在数据包`databundle`中,也有和数据集`dataset`类似的四个`apply`函数,即\n",
+ "\n",
+ " `apply`函数、`apply_field`函数、`apply_field_more`函数和`apply_more`函数\n",
+ "\n",
+ " 负责对数据集进行预处理,如下所示是`apply_more`函数的示例,其他函数类似\n",
+ "\n",
+ "此外,通过`get_dataset`函数,可以通过数据表名`name`称找到对应数据表\n",
+ "\n",
+ " 通过`get_vocab`函数,可以通过词汇表名`field_name`称找到对应词汇表"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+------------------------------+----------+-----+\n",
+ "| text | label | len |\n",
+ "+------------------------------+----------+-----+\n",
+ "| ['a', 'series', 'of', 'es... | negative | 37 |\n",
+ "| ['this', 'quiet', ',', 'i... | positive | 11 |\n",
+ "| ['even', 'fans', 'of', 'i... | negative | 21 |\n",
+ "| ['the', 'importance', 'of... | neutral | 20 |\n",
+ "+------------------------------+----------+-----+\n"
+ ]
+ }
+ ],
+ "source": [
+ "data_bundle.apply_more(lambda ins:{'len': len(ins['text'])}, progress_bar='tqdm')\n",
+ "print(data_bundle.get_dataset('train'))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. fastNLP 中的 tokenizer\n",
+ "\n",
+ "### 2.1 PreTrainTokenizer 的提出\n",
+ "\n",
+ "在`fastNLP 1.0`中,**使用 PreTrainedTokenizer 模块来为数据集中的词语进行词向量的标注**\n",
+ "\n",
+ " 需要注意的是,`PreTrainedTokenizer`模块的下载和导入**需要确保环境安装了 transformers 模块**\n",
+ "\n",
+ " 这是因为 `fastNLP 1.0`中`PreTrainedTokenizer`模块的实现基于`Huggingface Transformers`库\n",
+ "\n",
+ "**Huggingface Transformers 是一个开源的**,**基于 transformer 模型结构提供的预训练语言库**\n",
+ "\n",
+ " 包含了多种经典的基于`transformer`的预训练模型,如`BERT`、`BART`、`RoBERTa`、`GPT2`、`CPT`\n",
+ "\n",
+ " 更多相关内容可以参考`Huggingface Transformers`的[相关论文](https://arxiv.org/pdf/1910.03771.pdf)、[官方文档](https://huggingface.co/transformers/)以及[的代码仓库](https://github.com/huggingface/transformers)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.2 BertTokenizer 的基本使用\n",
+ "\n",
+ "在`fastNLP 1.0`中,以`PreTrainedTokenizer`为基类,泛化出多个子类,实现基于`BERT`等模型的标注\n",
+ "\n",
+ " 本节以`BertTokenizer`模块为例,展示`PreTrainedTokenizer`模块的使用方法与应用实例\n",
+ "\n",
+ "**BertTokenizer 的初始化包括 导入模块和导入数据 两步**,先通过从`fastNLP.transformers.torch`中\n",
+ "\n",
+ " 导入`BertTokenizer`模块,再**通过 from_pretrained 方法指定 tokenizer 参数类型下载**\n",
+ "\n",
+ " 其中,**'bert-base-uncased' 指定 tokenizer 使用的预训练 BERT 类型**:单词不区分大小写\n",
+ "\n",
+ " **模块层数 L=12**,**隐藏层维度 H=768**,**自注意力头数 A=12**,**总参数量 110M**\n",
+ "\n",
+ " 另外,模型参数自动下载至 home 目录下的`~\\.cache\\huggingface\\transformers`文件夹中"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "from fastNLP.transformers.torch import BertTokenizer\n",
+ "\n",
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "通过变量`vocab_size`和`vocab_files_names`可以查看`BertTokenizer`的词汇表的大小和对应文件\n",
+ "\n",
+ " 通过变量`vocab`可以访问`BertTokenizer`预训练的词汇表(由于内容过大就不演示了"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "30522 {'vocab_file': 'vocab.txt'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(tokenizer.vocab_size, tokenizer.vocab_files_names)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "通过变量`all_special_tokens`或通过变量`special_tokens_map`可以**查看 BertTokenizer 内置的特殊词素**\n",
+ "\n",
+ " 包括**未知符 '[UNK]'**, **断句符 '[SEP]'**, **补零符 '[PAD]'**, **分类符 '[CLS]'**, **掩码 '[MASK]'**\n",
+ "\n",
+ "通过变量`all_special_ids`可以**查看 BertTokenizer 内置的特殊词素对应的词汇表编号**,相同功能\n",
+ "\n",
+ " 也可以直接通过查看`pad_token`,值为`'[UNK]'`,和`pad_token_id`,值为`0`,等变量来实现"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "pad_token [PAD] 0\n",
+ "unk_token [UNK] 100\n",
+ "cls_token [CLS] 101\n",
+ "sep_token [SEP] 102\n",
+ "msk_token [MASK] 103\n",
+ "all_tokens ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] [100, 102, 0, 101, 103]\n",
+ "{'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print('pad_token', tokenizer.pad_token, tokenizer.pad_token_id) \n",
+ "print('unk_token', tokenizer.unk_token, tokenizer.unk_token_id) \n",
+ "print('cls_token', tokenizer.cls_token, tokenizer.cls_token_id) \n",
+ "print('sep_token', tokenizer.sep_token, tokenizer.sep_token_id)\n",
+ "print('msk_token', tokenizer.mask_token, tokenizer.mask_token_id)\n",
+ "print('all_tokens', tokenizer.all_special_tokens, tokenizer.all_special_ids)\n",
+ "print(tokenizer.special_tokens_map)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "此外,还可以添加其他特殊字符,例如起始符`[BOS]`、终止符`[EOS]`,添加后词汇表编号也会相应改变\n",
+ "\n",
+ " *但是如何添加这两个之外的字符,并且如何将这两个的编号设置为 [UNK] 之外的编号???*"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "bos_token [BOS] 100\n",
+ "eos_token [EOS] 100\n",
+ "all_tokens ['[BOS]', '[EOS]', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] [100, 100, 100, 102, 0, 101, 103]\n",
+ "{'bos_token': '[BOS]', 'eos_token': '[EOS]', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer.bos_token = '[BOS]'\n",
+ "tokenizer.eos_token = '[EOS]'\n",
+ "# tokenizer.bos_token_id = 104\n",
+ "# tokenizer.eos_token_id = 105\n",
+ "print('bos_token', tokenizer.bos_token, tokenizer.bos_token_id)\n",
+ "print('eos_token', tokenizer.eos_token, tokenizer.eos_token_id)\n",
+ "print('all_tokens', tokenizer.all_special_tokens, tokenizer.all_special_ids)\n",
+ "print(tokenizer.special_tokens_map)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "在`BertTokenizer`中,**使用 tokenize 函数和 convert_tokens_to_string 函数可以实现文本和词素列表的互转**\n",
+ "\n",
+ " 此外,**使用 convert_tokens_to_ids 函数和 convert_ids_to_tokens 函数则可以实现词素和词素编号的互转**\n",
+ "\n",
+ " 上述四个函数的使用效果如下所示,此处可以明显看出,`tokenizer`分词和传统分词的不同效果,例如`'##cap'`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 1012]\n",
+ "['a', 'series', 'of', 'es', '##cap', '##ades', 'demonstrating', 'the', 'ada', '##ge', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gan', '##der', ',', 'some', 'of', 'which', 'occasionally', 'am', '##uses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']\n",
+ "a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"a series of escapades demonstrating the adage that what is \" \\\n",
+ " \"good for the goose is also good for the gander , some of which \" \\\n",
+ " \"occasionally amuses but none of which amounts to much of a story .\" \n",
+ "tks = ['a', 'series', 'of', 'es', '##cap', '##ades', 'demonstrating', 'the', \n",
+ " 'ada', '##ge', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', \n",
+ " 'is', 'also', 'good', 'for', 'the', 'gan', '##der', ',', 'some', 'of', \n",
+ " 'which', 'occasionally', 'am', '##uses', 'but', 'none', 'of', 'which', \n",
+ " 'amounts', 'to', 'much', 'of', 'a', 'story', '.']\n",
+ "ids = [ 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, \n",
+ " 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204,\n",
+ " 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572,\n",
+ " 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037,\n",
+ " 2466, 1012]\n",
+ "\n",
+ "tokens = tokenizer.tokenize(text)\n",
+ "print(tokenizer.convert_tokens_to_ids(tokens))\n",
+ "\n",
+ "ids = tokenizer.convert_tokens_to_ids(tokens)\n",
+ "print(tokenizer.convert_ids_to_tokens(ids))\n",
+ "\n",
+ "print(tokenizer.convert_tokens_to_string(tokens))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "在`BertTokenizer`中,还有另外两个函数可以实现分词标注,分别是 **encode 和 decode 函数**,**可以直接实现**\n",
+ "\n",
+ " **文本字符串和词素编号列表的互转**,但是编码过程中会按照`BERT`的规则,**在句子首末加入 [CLS] 和 [SEP]**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 1012, 102]\n",
+ "[CLS] a series of escapades demonstrating the adage that what is good for the goose is also good for the gander, some of which occasionally amuses but none of which amounts to much of a story. [SEP]\n"
+ ]
+ }
+ ],
+ "source": [
+ "enc = tokenizer.encode(text)\n",
+ "print(tokenizer.encode(text))\n",
+ "dec = tokenizer.decode(enc)\n",
+ "print(tokenizer.decode(enc))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "在`encode`函数之上,还有`encode_plus`函数,这也是在数据预处理中,`BertTokenizer`模块最常用到的函数\n",
+ "\n",
+ " **encode 函数的参数**,**encode_plus 函数都有**;**encode 函数词素编号列表**,**encode_plus 函数返回字典**\n",
+ "\n",
+ "在`encode_plus`函数的返回值中,字段`input_ids`表示词素编号,其余两个字段后文有详细解释\n",
+ "\n",
+ " **字段 token_type_ids 详见 text_pairs 的示例**,**字段 attention_mask 详见 batch_text 的示例**\n",
+ "\n",
+ "在`encode_plus`函数的参数中,参数`add_special_tokens`表示是否按照`BERT`的规则,加入相关特殊字符\n",
+ "\n",
+ " 参数`max_length`表示句子截取最大长度(算特殊字符),在参数`truncation=True`时会自动截取\n",
+ "\n",
+ " 参数`return_attention_mask`约定返回的字典中是否包括`attention_mask`字段,以上案例如下"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'input_ids': [101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"a series of escapades demonstrating the adage that what is good for the goose is also good for \"\\\n",
+ " \"the gander , some of which occasionally amuses but none of which amounts to much of a story .\" \n",
+ "\n",
+ "encoded = tokenizer.encode_plus(text=text, add_special_tokens=True, max_length=32, \n",
+ " truncation=True, return_attention_mask=True)\n",
+ "print(encoded)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "在`encode_plus`函数之上,还有`batch_encode_plus`函数(类似地,在`decode`之上,还有`batch_decode`\n",
+ "\n",
+ " 两者参数类似,**batch_encode_plus 函数针对批量文本 batch_text**,**或者批量句对 text_pairs**\n",
+ "\n",
+ "在针对批量文本`batch_text`的示例中,注意`batch_encode_plus`函数返回字典中的`attention_mask`字段\n",
+ "\n",
+ " 可以发现,**attention_mask 字段通过 01 标注出词素序列中该位置是否为补零**,可以用做自注意力的掩模"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'input_ids': [[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 102, 0, 0], [101, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 102], [101, 2070, 1997, 2029, 5681, 2572, 25581, 102, 0, 0, 0, 0, 0, 0, 0], [101, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 102, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]}\n"
+ ]
+ }
+ ],
+ "source": [
+ "batch_text = [\"a series of escapades demonstrating the adage that\",\n",
+ " \"what is good for the goose is also good for the gander\",\n",
+ " \"some of which occasionally amuses\",\n",
+ " \"but none of which amounts to much of a story\" ]\n",
+ "\n",
+ "encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=batch_text, padding=True,\n",
+ " add_special_tokens=True, max_length=16, truncation=True, \n",
+ " return_attention_mask=True)\n",
+ "print(encoded)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "而在针对批量句对`text_pairs`的示例中,注意`batch_encode_plus`函数返回字典中的`attention_mask`字段\n",
+ "\n",
+ " 可以发现,**token_type_ids 字段通过 01 标注出词素序列中该位置为句对中的第几句**,句对用 [SEP] 分割"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'input_ids': [[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 102, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 102], [101, 2070, 1997, 2029, 5681, 2572, 25581, 102, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 102, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]}\n"
+ ]
+ }
+ ],
+ "source": [
+ "text_pairs = [(\"a series of escapades demonstrating the adage that\",\n",
+ " \"what is good for the goose is also good for the gander\"),\n",
+ " (\"some of which occasionally amuses\",\n",
+ " \"but none of which amounts to much of a story\")]\n",
+ "\n",
+ "encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=text_pairs, padding=True,\n",
+ " add_special_tokens=True, max_length=32, truncation=True, \n",
+ " return_attention_mask=True)\n",
+ "print(encoded)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "回到`encode_plus`上,在接下来的示例中,**使用内置的 functools.partial 模块构造 encode 函数**\n",
+ "\n",
+ " 接着**使用该函数对 databundle 进行数据预处理**,由于`tokenizer.encode_plus`返回的是一个字典\n",
+ "\n",
+ " 读入的是一个字段,所以此处使用`apply_field_more`方法,得到结果自动并入`databundle`中如下"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "functools.partial(, max_length=32, truncation=True, return_attention_mask=True)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+------------------+----------+-----+------------------+--------------------+--------------------+\n",
+ "| text | label | len | input_ids | token_type_ids | attention_mask |\n",
+ "+------------------+----------+-----+------------------+--------------------+--------------------+\n",
+ "| ['a', 'series... | negative | 37 | [101, 1037, 2... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
+ "| ['this', 'qui... | positive | 11 | [101, 2023, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
+ "| ['even', 'fan... | negative | 21 | [101, 2130, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
+ "| ['the', 'impo... | neutral | 20 | [101, 1996, 5... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
+ "+------------------+----------+-----+------------------+--------------------+--------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "encode = partial(tokenizer.encode_plus, max_length=32, truncation=True,\n",
+ " return_attention_mask=True)\n",
+ "print(encode)\n",
+ "\n",
+ "data_bundle.apply_field_more(encode, field_name='text', progress_bar='tqdm')\n",
+ "print(data_bundle.datasets['train'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "经过`tokenizer`的处理,原始数据集中的文本被替换为词素编号列表,此时,调用`databundle`模块的\n",
+ "\n",
+ " **set_pad 函数**,**将 databundle 的补零符编号 pad_val 和 tokenizer 补零符编号 pad_token_id 统一**\n",
+ "\n",
+    "  该函数同时将`databundle`的`'input_ids'`字段添加到对应数据集的`collator`中（见`tutorial 3.`）"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{}\n",
+ "{}\n",
+ "{'input_ids': {'pad_val': 0, 'dtype': None, 'backend': 'auto', 'pad_fn': None}}\n",
+ "{'input_ids': {'pad_val': 0, 'dtype': None, 'backend': 'auto', 'pad_fn': None}}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(data_bundle.get_dataset('train').collator.input_fields)\n",
+ "print(data_bundle.get_dataset('test').collator.input_fields)\n",
+ "data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
+ "print(data_bundle.get_dataset('train').collator.input_fields)\n",
+ "print(data_bundle.get_dataset('test').collator.input_fields)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "最后,使用`from_dataset`、`index_dataset`和`iter_datasets`方法,为处理数据集的`'label'`字段编码\n",
+ "\n",
+ " 接着**通过 set_ignore 函数**,**指定 databundle 的部分字段**,如`'text'`等,**在划分 batch 时不再出现**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+----------------+----------+-----+----------------+--------------------+--------------------+--------+\n",
+ "| text | label | len | input_ids | token_type_ids | attention_mask | target |\n",
+ "+----------------+----------+-----+----------------+--------------------+--------------------+--------+\n",
+ "| ['a', 'seri... | negative | 37 | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 0 |\n",
+ "| ['this', 'q... | positive | 11 | [101, 2023,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n",
+ "| ['even', 'f... | negative | 21 | [101, 2130,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 0 |\n",
+ "| ['the', 'im... | neutral | 20 | [101, 1996,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 2 |\n",
+ "+----------------+----------+-----+----------------+--------------------+--------------------+--------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "target_vocab = Vocabulary(padding=None, unknown=None)\n",
+ "\n",
+ "target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n",
+ "target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n",
+ " new_field_name='target')\n",
+ "\n",
+ "data_bundle.set_ignore('text', 'len', 'label') \n",
+ "print(data_bundle.datasets['train'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "以上就是使用`dataset`、`vocabulary`、`databundle`和`tokenizer`实现输入文本数据的读取\n",
+ "\n",
+ " 分词标注、序列化的全部预处理过程,通过下方的代码梳理,相信你会有更详细的了解\n",
+ "\n",
+ "```python\n",
+ "# 首先,导入预训练的 BertTokenizer,这里使用 'bert-base-uncased' 版本\n",
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
+ "\n",
+ "# 接着,导入数据,先生成为 dataset 形式,再变成 dataset-dict,并转为 databundle 形式\n",
+ "datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n",
+ "train_ds, test_ds = datasets.split(ratio=0.7)\n",
+ "data_bundle = DataBundle(datasets={'train': train_ds, 'test': test_ds})\n",
+ "\n",
+ "# 然后,通过 tokenizer.encode_plus 函数,进行文本分词标注、修改并补充数据包内容\n",
+ "encode = partial(tokenizer.encode_plus, max_length=100, truncation=True,\n",
+ " return_attention_mask=True)\n",
+ "data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n",
+ "\n",
+ "# 在修改好 'text' 字段的文本信息后,接着处理 'label' 字段的预测信息\n",
+ "target_vocab = Vocabulary(padding=None, unknown=None)\n",
+ "target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n",
+ "target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n",
+ " new_field_name='target')\n",
+ "\n",
+ "# 最后,通过 data_bundle 的其他一些函数,完成善后内容\n",
+ "data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
+ "data_bundle.set_ignore('SentenceId', 'Sentiment', 'Sentence') \n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "\n",
+ "\n",
+ "在接下来的`tutorial 3.`中,将会介绍`fastNLP v1.0`中的`dataloader`模块,会涉及本章中\n",
+ "\n",
+ " 提到的`collator`模块,`fastNLP`的多框架适应以及完整的数据加载过程,敬请期待"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ },
+ "pycharm": {
+ "stem_cell": {
+ "cell_type": "raw",
+ "metadata": {
+ "collapsed": false
+ },
+ "source": []
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_3.ipynb b/docs/source/tutorials/fastnlp_tutorial_3.ipynb
new file mode 100644
index 00000000..4100105a
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_3.ipynb
@@ -0,0 +1,621 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "213d538c",
+ "metadata": {},
+ "source": [
+ "# T3. dataloader 的内部结构和基本使用\n",
+ "\n",
+ " 1 fastNLP 中的 dataloader\n",
+ " \n",
+ " 1.1 dataloader 的基本介绍\n",
+ "\n",
+ " 1.2 dataloader 的函数创建\n",
+ "\n",
+ " 2 fastNLP 中 dataloader 的延伸\n",
+ "\n",
+ " 2.1 collator 的概念与使用\n",
+ "\n",
+ " 2.2 结合 datasets 框架"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "85857115",
+ "metadata": {},
+ "source": [
+ "## 1. fastNLP 中的 dataloader\n",
+ "\n",
+ "### 1.1 dataloader 的基本介绍\n",
+ "\n",
+ "在`fastNLP 1.0`的开发中,最关键的开发目标就是**实现 fastNLP 对当前主流机器学习框架**,例如\n",
+ "\n",
+ " **当下流行的 pytorch**,以及**国产的 paddle 、jittor 和 oneflow 的兼容**,扩大受众的同时,也是助力国产\n",
+ "\n",
+ "本着分而治之的思想,我们可以将`fastNLP 1.0`对`pytorch`、`paddle`、`jittor`、`oneflow`框架的兼容,划分为\n",
+ "\n",
+ " **对数据预处理**、**批量 batch 的划分与补齐**、**模型训练**、**模型评测**,**四个部分的兼容**\n",
+ "\n",
+ " 针对数据预处理,我们已经在`tutorial-1`中介绍了`dataset`和`vocabulary`的使用\n",
+ "\n",
+ " 而结合`tutorial-0`,我们可以发现**数据预处理环节本质上是框架无关的**\n",
+ "\n",
+ " 因为在不同框架下,读取的原始数据格式都差异不大,彼此也很容易转换\n",
+ "\n",
+ "只有涉及到张量、模型,不同框架才展现出其各自的特色:**pytorch 和 oneflow 中的 tensor 和 nn.Module**\n",
+ "\n",
+ " **在 paddle 中称为 tensor 和 nn.Layer**,**在 jittor 中则称为 Var 和 Module**\n",
+ "\n",
+ " 因此,**模型训练、模型评测**,**是兼容的重难点**,我们将会在`tutorial-5`中详细介绍\n",
+ "\n",
+    "  针对批量`batch`的处理,作为`fastNLP 1.0`中框架无关部分向框架相关部分的过渡\n",
+ "\n",
+ " 就是`dataloader`模块的职责,这也是本篇教程`tutorial-3`讲解的重点\n",
+ "\n",
+ "**dataloader 模块的职责**,详细划分可以包含以下三部分,**采样划分、补零对齐、框架匹配**\n",
+ "\n",
+ " 第一,确定`batch`大小,确定采样方式,划分后通过迭代器即可得到`batch`序列\n",
+ "\n",
+ " 第二,对于序列处理,这也是`fastNLP`主要针对的,将同个`batch`内的数据对齐\n",
+ "\n",
+ " 第三,**batch 内数据格式要匹配框架**,**但 batch 结构需保持一致**,**参数匹配机制**\n",
+ "\n",
+ " 对此,`fastNLP 1.0`给出了 **TorchDataLoader 、 PaddleDataLoader 、 JittorDataLoader 和 OneflowDataLoader**\n",
+ "\n",
+ " 分别针对并匹配不同框架,但彼此之间参数名、属性、方法仍然类似,前两者大致如下表所示\n",
+ "\n",
+ "名称|参数|属性|功能|内容\n",
+ "----|----|----|----|----|\n",
+ " `dataset` | √ | √ | 指定`dataloader`的数据内容 | |\n",
+ " `batch_size` | √ | √ | 指定`dataloader`的`batch`大小 | 默认`16` |\n",
+ " `shuffle` | √ | √ | 指定`dataloader`的数据是否打乱 | 默认`False` |\n",
+ " `collate_fn` | √ | √ | 指定`dataloader`的`batch`打包方法 | 视框架而定 |\n",
+ " `sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n",
+ " `batch_sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n",
+ " `drop_last` | √ | √ | 指定`dataloader`划分`batch`时是否丢弃剩余的 | 默认`False` |\n",
+ " `cur_batch_indices` | | √ | 记录`dataloader`当前遍历批量序号 | |\n",
+ " `num_workers` | √ | √ | 指定`dataloader`开启子进程数量 | 默认`0` |\n",
+ " `worker_init_fn` | √ | √ | 指定`dataloader`子进程初始方法 | 默认`None` |\n",
+ " `generator` | √ | √ | 指定`dataloader`子进程随机种子 | 默认`None` |\n",
+ " `prefetch_factor` | | √ | 指定为每个`worker`装载的`sampler`数量 | 默认`2` |"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "60a8a224",
+ "metadata": {},
+ "source": [
+ " 论及`dataloader`的函数,其中,`get_batch_indices`用来获取当前遍历到的`batch`序号,其他函数\n",
+ "\n",
+ " 包括`set_ignore`、`set_pad`和`databundle`类似,请参考`tutorial-2`,此处不做更多介绍\n",
+ "\n",
+ " 以下是`tutorial-2`中已经介绍过的数据预处理流程,接下来是对相关数据进行`dataloader`处理"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "aca72b49",
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[38;5;2m[i 0604 15:44:29.773860 92 log.cc:351] Load log_sync: 1\u001b[m\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n",
+ "| SentenceId | Sentence | Sentiment | input_ids | token_type_ids | attention_mask | target |\n",
+ "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n",
+ "| 1 | A series of... | negative | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n",
+ "| 4 | A positivel... | neutral | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 2 |\n",
+ "| 3 | Even fans o... | negative | [101, 2130,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n",
+ "| 5 | A comedy-dr... | positive | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 0 |\n",
+ "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append('..')\n",
+ "\n",
+ "import pandas as pd\n",
+ "from functools import partial\n",
+ "from fastNLP.transformers.torch import BertTokenizer\n",
+ "\n",
+ "from fastNLP import DataSet\n",
+ "from fastNLP import Vocabulary\n",
+ "from fastNLP.io import DataBundle\n",
+ "\n",
+ "\n",
+ "class PipeDemo:\n",
+ " def __init__(self, tokenizer='bert-base-uncased'):\n",
+ " self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n",
+ "\n",
+ " def process_from_file(self, path='./data/test4dataset.tsv'):\n",
+ " datasets = DataSet.from_pandas(pd.read_csv(path, sep='\\t'))\n",
+ " train_ds, test_ds = datasets.split(ratio=0.7)\n",
+ " train_ds, dev_ds = datasets.split(ratio=0.8)\n",
+ " data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n",
+ "\n",
+ " encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n",
+ " return_attention_mask=True)\n",
+ " data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n",
+ " \n",
+ " target_vocab = Vocabulary(padding=None, unknown=None)\n",
+ "\n",
+ " target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n",
+ " target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n",
+ " new_field_name='target')\n",
+ "\n",
+ " data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n",
+ " data_bundle.set_ignore('SentenceId', 'Sentence', 'Sentiment') \n",
+ " return data_bundle\n",
+ "\n",
+ " \n",
+ "pipe = PipeDemo(tokenizer='bert-base-uncased')\n",
+ "\n",
+ "data_bundle = pipe.process_from_file('./data/test4dataset.tsv')\n",
+ "\n",
+ "print(data_bundle.get_dataset('train'))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "76e6b8ab",
+ "metadata": {},
+ "source": [
+ "### 1.2 dataloader 的函数创建\n",
+ "\n",
+ "在`fastNLP 1.0`中,**更方便、可能更常用的 dataloader 创建方法是通过 prepare_xx_dataloader 函数**\n",
+ "\n",
+ " 例如下方的`prepare_torch_dataloader`函数,指定必要参数,读取数据集,生成对应`dataloader`\n",
+ "\n",
+ " 类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`\n",
+ "\n",
+    "同时我们还可以发现,在`fastNLP 1.0`中,**batch 表示为字典 dict 类型**,**key 值就是原先数据集中各个字段**\n",
+ "\n",
+ " **除去经过 DataBundle.set_ignore 函数隐去的部分**,而`value`值为`pytorch`框架对应的`torch.Tensor`类型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "5fd60e42",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " ['input_ids', 'token_type_ids', 'attention_mask', 'target']\n",
+ "{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
+ " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),\n",
+ " 'input_ids': tensor([[ 101, 1037, 4038, 1011, 3689, 1997, 3053, 8680, 19173, 15685,\n",
+ " 1999, 1037, 18006, 2836, 2011, 1996, 2516, 2839, 14996, 3054,\n",
+ " 15509, 5325, 1012, 102, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0],\n",
+ " [ 101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262,\n",
+ " 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036,\n",
+ " 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681,\n",
+ " 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997,\n",
+ " 1037, 2466, 1012, 102],\n",
+ " [ 101, 2130, 4599, 1997, 19214, 6432, 1005, 1055, 2147, 1010,\n",
+ " 1045, 8343, 1010, 2052, 2031, 1037, 2524, 2051, 3564, 2083,\n",
+ " 2023, 2028, 1012, 102, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0],\n",
+ " [ 101, 1037, 13567, 26162, 5257, 1997, 3802, 7295, 9888, 1998,\n",
+ " 2035, 1996, 20014, 27611, 1010, 14583, 1010, 11703, 20175, 1998,\n",
+ " 4028, 1997, 1037, 8101, 2319, 10576, 2030, 1037, 28900, 7815,\n",
+ " 3850, 1012, 102, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0]]),\n",
+ " 'target': tensor([0, 1, 1, 2]),\n",
+ " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n"
+ ]
+ }
+ ],
+ "source": [
+ "from fastNLP import prepare_torch_dataloader\n",
+ "\n",
+ "train_dataset = data_bundle.get_dataset('train')\n",
+ "evaluate_dataset = data_bundle.get_dataset('dev')\n",
+ "\n",
+ "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
+ "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)\n",
+ "\n",
+ "print(type(train_dataloader))\n",
+ "\n",
+ "import pprint\n",
+ "\n",
+ "for batch in train_dataloader:\n",
+ " print(type(batch), type(batch['input_ids']), list(batch))\n",
+ " pprint.pprint(batch, width=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9f457a6e",
+ "metadata": {},
+ "source": [
+    "之所以说`prepare_xx_dataloader`函数更方便,是因为其**导入对象不仅可以是 DataSet 类型**,**还可以**\n",
+ "\n",
+ " **是 DataBundle 类型**,不过数据集名称需要是`'train'`、`'dev'`、`'test'`供`fastNLP`识别\n",
+ "\n",
+ "例如下方就是**直接通过 prepare_paddle_dataloader 函数生成基于 PaddleDataLoader 的字典**\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "7827557d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from fastNLP import prepare_paddle_dataloader\n",
+ "\n",
+ "dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)\n",
+ "\n",
+ "print(type(dl_bundle['train']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d898cf40",
+ "metadata": {},
+ "source": [
+ " 而在接下来`trainer`的初始化过程中,按如下方式使用即可,除了初始化时`driver='paddle'`外\n",
+ "\n",
+ " 这里也可以看出`trainer`模块中,**evaluate_dataloaders 的设计允许评测可以针对多个数据集**\n",
+ "\n",
+ "```python\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " train_dataloader=dl_bundle['train'],\n",
+ " optimizers=optimizer,\n",
+ "\t...\n",
+ "\tdriver='paddle',\n",
+ "\tdevice='gpu',\n",
+ "\t...\n",
+ " evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']}, \n",
+ " metrics={'acc': Accuracy()},\n",
+ "\t...\n",
+ ")\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d74d0523",
+ "metadata": {},
+ "source": [
+ "## 2. fastNLP 中 dataloader 的延伸\n",
+ "\n",
+ "### 2.1 collator 的概念与使用\n",
+ "\n",
+ "在`fastNLP 1.0`中,在数据加载模块`dataloader`内部,如之前表格所列举的,还存在其他的一些模块\n",
+ "\n",
+ " 例如,**实现序列的补零对齐的核对器 collator 模块**;注:`collate vt. 整理(文件或书等);核对,校勘`\n",
+ "\n",
+ "在`fastNLP 1.0`中,虽然`dataloader`随框架不同,但`collator`模块却是统一的,主要属性、方法如下表所示\n",
+ "\n",
+ "名称|属性|方法|功能|内容\n",
+ " ----|----|----|----|----|\n",
+ " `backend` | √ | | 记录`collator`对应框架 | 字符串型,如`'torch'` |\n",
+ " `padders` | √ | | 记录各字段对应的`padder`,每个负责具体补零对齐 | 字典类型 |\n",
+ " `ignore_fields` | √ | | 记录`dataloader`采样`batch`时不予考虑的字段 | 集合类型 |\n",
+ " `input_fields` | √ | | 记录`collator`每个字段的补零值、数据类型等 | 字典类型 |\n",
+ " `set_backend` | | √ | 设置`collator`对应框架 | 字符串型,如`'torch'` |\n",
+ " `set_ignore` | | √ | 设置`dataloader`采样`batch`时不予考虑的字段 | 字符串型,表示`field_name` |\n",
+ " `set_pad` | | √ | 设置`collator`每个字段的补零值、数据类型等 | |"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "d0795b3e",
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "train_dataloader.collate_fn\n",
+ "\n",
+ "print(type(train_dataloader.collate_fn))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5f816ef5",
+ "metadata": {},
+ "source": [
+ "此外,还可以 **手动定义 dataloader 中的 collate_fn**,而不是使用`fastNLP 1.0`中自带的`collator`模块\n",
+ "\n",
+ " 该函数的定义可以大致如下,需要注意的是,**定义 collate_fn 之前需要了解 batch 作为字典的格式**\n",
+ "\n",
+ " 该函数通过`collate_fn`参数传入`dataloader`,**在 batch 分发**(**而不是 batch 划分**)**时调用**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "ff8e405e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "def collate_fn(batch):\n",
+ " input_ids, atten_mask, labels = [], [], []\n",
+ " max_length = [0] * 3\n",
+ " for each_item in batch:\n",
+ " input_ids.append(each_item['input_ids'])\n",
+ " max_length[0] = max(len(each_item['input_ids']), max_length[0])\n",
+ " atten_mask.append(each_item['token_type_ids'])\n",
+ " max_length[1] = max(len(each_item['token_type_ids']), max_length[1])\n",
+ " labels.append(each_item['attention_mask'])\n",
+ " max_length[2] = max(len(each_item['attention_mask']), max_length[2])\n",
+ "\n",
+ " for i in range(3):\n",
+ " each = (input_ids, atten_mask, labels)[i]\n",
+ " for item in each:\n",
+ " item.extend([0] * (max_length[i] - len(item)))\n",
+ " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n",
+ " 'token_type_ids': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n",
+ " 'attention_mask': torch.cat([torch.tensor(item) for item in labels], dim=0)}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "487b75fb",
+ "metadata": {},
+ "source": [
+ "注意:使用自定义的`collate_fn`函数,`trainer`的`collate_fn`变量也会自动调整为`function`类型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "e916d1ac",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0]),\n",
+ " 'input_ids': tensor([[ 101, 1037, 4038, 1011, 3689, 1997, 3053, 8680, 19173, 15685,\n",
+ " 1999, 1037, 18006, 2836, 2011, 1996, 2516, 2839, 14996, 3054,\n",
+ " 15509, 5325, 1012, 102, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0],\n",
+ " [ 101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262,\n",
+ " 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036,\n",
+ " 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681,\n",
+ " 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997,\n",
+ " 1037, 2466, 1012, 102],\n",
+ " [ 101, 2130, 4599, 1997, 19214, 6432, 1005, 1055, 2147, 1010,\n",
+ " 1045, 8343, 1010, 2052, 2031, 1037, 2524, 2051, 3564, 2083,\n",
+ " 2023, 2028, 1012, 102, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0],\n",
+ " [ 101, 1037, 13567, 26162, 5257, 1997, 3802, 7295, 9888, 1998,\n",
+ " 2035, 1996, 20014, 27611, 1010, 14583, 1010, 11703, 20175, 1998,\n",
+ " 4028, 1997, 1037, 8101, 2319, 10576, 2030, 1037, 28900, 7815,\n",
+ " 3850, 1012, 102, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0]]),\n",
+ " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n"
+ ]
+ }
+ ],
+ "source": [
+ "train_dataloader = prepare_torch_dataloader(train_dataset, collate_fn=collate_fn, shuffle=True)\n",
+ "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, collate_fn=collate_fn, shuffle=True)\n",
+ "\n",
+ "print(type(train_dataloader))\n",
+ "print(type(train_dataloader.collate_fn))\n",
+ "\n",
+ "for batch in train_dataloader:\n",
+ " pprint.pprint(batch, width=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0bd98365",
+ "metadata": {},
+ "source": [
+ "### 2.2 fastNLP 与 datasets 的结合\n",
+ "\n",
+ "从`tutorial-1`至`tutorial-3`,我们已经完成了对`fastNLP v1.0`数据读取、预处理、加载,整个流程的介绍\n",
+ "\n",
+ " 不过在实际使用中,我们往往也会采取更为简便的方法读取数据,例如使用`huggingface`的`datasets`模块\n",
+ "\n",
+ "**使用 datasets 模块中的 load_dataset 函数**,通过指定数据集两级的名称,示例中即是**GLUE 标准中的 SST-2 数据集**\n",
+ "\n",
+ " 即可以快速从网上下载好`SST-2`数据集读入,之后以`pandas.DataFrame`作为中介,再转化成`fastNLP.DataSet`\n",
+ "\n",
+ " 之后的步骤就和其他关于`dataset`、`databundle`、`vocabulary`、`dataloader`中介绍的相关使用相同了"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "91879c30",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "639a0ad3c63944c6abef4e8ee1f7bf7c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "sst2data = load_dataset('glue', 'sst2')\n",
+ "\n",
+ "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ },
+ "pycharm": {
+ "stem_cell": {
+ "cell_type": "raw",
+ "metadata": {
+ "collapsed": false
+ },
+ "source": []
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_4.ipynb b/docs/source/tutorials/fastnlp_tutorial_4.ipynb
new file mode 100644
index 00000000..909991e1
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_4.ipynb
@@ -0,0 +1,2614 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "fdd7ff16",
+ "metadata": {},
+ "source": [
+ "# T4. fastNLP 中的预定义模型\n",
+ "\n",
+ " 1 fastNLP 中 modules 的介绍\n",
+ " \n",
+ " 1.1 modules 模块、models 模块 简介\n",
+ "\n",
+ " 1.2 示例一:modules 实现 LSTM 分类\n",
+ "\n",
+ " 2 fastNLP 中 models 的介绍\n",
+ " \n",
+ " 2.1 示例一:models 实现 CNN 分类\n",
+ "\n",
+ " 2.3 示例二:models 实现 BiLSTM 标注"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d3d65d53",
+ "metadata": {},
+ "source": [
+ "## 1. fastNLP 中 modules 模块的介绍\n",
+ "\n",
+ "### 1.1 modules 模块、models 模块 简介\n",
+ "\n",
+ "在`fastNLP 1.0`中,**modules.torch 路径下定义了一些基于 pytorch 实现的基础模块**\n",
+ "\n",
+ " 包括长短期记忆网络`LSTM`、条件随机场`CRF`、`transformer`的编解码器模块等,详见下表\n",
+ "\n",
+ "代码名称|简要介绍|代码路径\n",
+ "----|----|----|\n",
+ " `LSTM` | 轻量封装`pytorch`的`LSTM` | `/modules/torch/encoder/lstm.py` |\n",
+ " `Seq2SeqEncoder` | 序列变换编码器,基类 | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
+ " `LSTMSeq2SeqEncoder` | 序列变换编码器,基于`LSTM` | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
+ " `TransformerSeq2SeqEncoder` | 序列变换编码器,基于`transformer` | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
+ " `StarTransformer` | `Star-Transformer`的编码器部分 | `/modules/torch/encoder/star_transformer.py` |\n",
+ " `VarRNN` | 实现`Variational Dropout RNN` | `/modules/torch/encoder/variational_rnn.py` |\n",
+ " `VarLSTM` | 实现`Variational Dropout LSTM` | `/modules/torch/encoder/variational_rnn.py` |\n",
+ " `VarGRU` | 实现`Variational Dropout GRU` | `/modules/torch/encoder/variational_rnn.py` |\n",
+ " `MLP` | 多层感知机模型 | `/modules/torch/decoder/mlp.py` |\n",
+ " `ConditionalRandomField` | 条件随机场模型 | `/modules/torch/decoder/crf.py` |\n",
+ " `Seq2SeqDecoder` | 序列变换解码器,基类 | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
+ " `LSTMSeq2SeqDecoder` | 序列变换解码器,基于`LSTM` | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
+ " `TransformerSeq2SeqDecoder` | 序列变换解码器,基于`transformer` | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
+ " `SequenceGenerator` | 序列生成,封装`Seq2SeqDecoder` | `/models/torch/sequence_labeling.py` |\n",
+ " `TimestepDropout` | 在每个`timestamp`上`dropout` | `/modules/torch/dropout.py` |"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "89ffcf07",
+ "metadata": {},
+ "source": [
+ " **models.torch 路径下定义了一些基于 pytorch 、 modules 实现的预定义模型** \n",
+ "\n",
+ " 例如基于`CNN`的分类模型、基于`BiLSTM+CRF`的标注模型、基于[双仿射注意力机制](https://arxiv.org/pdf/1611.01734.pdf)的分析模型\n",
+ "\n",
+ " 基于`modules.torch`中的`LSTM`/`transformer`编/解码器模块的序列变换/生成模型,详见下表\n",
+ "\n",
+ "代码名称|简要介绍|代码路径\n",
+ "----|----|----|\n",
+ "| `BiaffineParser` | 句法分析模型,基于双仿射注意力 | `/models/torch/biaffine_parser.py` |\n",
+ "| `CNNText` | 文本分类模型,基于`CNN` | `/models/torch/cnn_text_classification.py` |\n",
+ "| `Seq2SeqModel` | 序列变换,基类`encoder+decoder` | `/models/torch/seq2seq_model.py` |\n",
+ "| `LSTMSeq2SeqModel` | 序列变换,基于`LSTM` | `/models/torch/seq2seq_model.py` |\n",
+ "| `TransformerSeq2SeqModel` | 序列变换,基于`transformer` | `/models/torch/seq2seq_model.py` |\n",
+ "| `SequenceGeneratorModel` | 封装`Seq2SeqModel`,结合`SequenceGenerator` | `/models/torch/seq2seq_generator.py` |\n",
+ "| `SeqLabeling` | 标注模型,基类`LSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n",
+ "| `BiLSTMCRF` | 标注模型,`BiLSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n",
+ "| `AdvSeqLabel` | 标注模型,`LN+BiLSTM*2+LN+FC+CRF` | `/models/torch/sequence_labeling.py` |"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "61318354",
+ "metadata": {},
+ "source": [
+ "上述`fastNLP`模块,不仅**为入门级用户提供了简单易用的工具**,以解决各种`NLP`任务,或复现相关论文\n",
+ "\n",
+ " 同时**也为专业研究人员提供了便捷可操作的接口**,封装部分代码的同时,也能指定参数修改细节\n",
+ "\n",
+ " 在接下来的`tutorial`中,我们将通过`SST-2`分类和`CoNLL-2003`标注,展示相关模型使用\n",
+ "\n",
+ "注一:**SST**,**单句情感分类**数据集,包含电影评论和对应情感极性,1 对应正面情感,0 对应负面情感\n",
+ "\n",
+ " 数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n",
+ "\n",
+ "注二:**CoNLL-2003**,**文本语法标注**数据集,包含语句和对应的词性标签`pos_tags`(名动形数量代)\n",
+ "\n",
+ " 语法结构标签`chunk_tags`(主谓宾定状补)、命名实体标签`ner_tags`(人名、组织名、地名、时间等)\n",
+ "\n",
+ " 数据集包括三部分:训练集 14041 条,验证集 3250 条,测试集 3453 条,更多参考[原始论文](https://aclanthology.org/W03-0419.pdf)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2a36bbe4",
+ "metadata": {},
+ "source": [
+ "### 1.2 示例一:modules 实现 LSTM 分类\n",
+ "\n",
+ "\n",
+ " 本示例使用`fastNLP 1.0`中预定义模型`modules`模块,基于`LSTM`模型,实现`SST-2`文本二分类任务\n",
+ "\n",
+ "数据使用方面:首先,**使用 datasets 模块中的 load_dataset 函数**,以如下形式,指定`SST-2`数据集加载\n",
+ "\n",
+ " 首次下载保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "1aa5cf6d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b8bdfdc011d349e38a1aa2aff35b2482",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "sst2data = load_dataset('glue', 'sst2')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c476abe7",
+ "metadata": {},
+ "source": [
+ " 接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n",
+ "\n",
+ " **使用 apply_more 函数、 Vocabulary 模块的 from_/index_dataset 函数预处理数据**\n",
+ "\n",
+ " 并结合`delete_field`函数删除字段调整格式,`split`函数划分测试集和验证集\n",
+ "\n",
+ " **仅保留 'words' 字段表示输入文本单词序号序列、 'target' 字段表示文本对应预测输出结果**\n",
+ "\n",
+ " 两者**对应到 CNNText 中 train_step 函数和 evaluate_step 函数的签名/输入参数**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "357ea748",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[38;5;2m[i 0604 16:19:46.727257 48 log.cc:351] Load log_sync: 1\u001b[m\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/6000 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append('..')\n",
+ "\n",
+ "from fastNLP import DataSet\n",
+ "\n",
+ "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
+ "\n",
+ "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n",
+ " progress_bar=\"tqdm\")\n",
+ "dataset.delete_field('sentence')\n",
+ "dataset.delete_field('label')\n",
+ "dataset.delete_field('idx')\n",
+ "\n",
+ "from fastNLP import Vocabulary\n",
+ "\n",
+ "vocab = Vocabulary()\n",
+ "vocab.from_dataset(dataset, field_name='words')\n",
+ "vocab.index_dataset(dataset, field_name='words')\n",
+ "\n",
+ "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "96380c67",
+ "metadata": {},
+ "source": [
+ " 然后,使用`tutorial-3`中的知识,**通过 prepare_torch_dataloader 处理数据集得到 dataloader**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "b9dd1273",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import prepare_torch_dataloader\n",
+ "\n",
+ "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
+ "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "eb75aaba",
+ "metadata": {},
+ "source": [
+ "模型使用方面,这里使用`Embedding`、`LSTM`、`MLP`等模块搭建模型,方法类似`pytorch`,结构如下所示\n",
+ "\n",
+ "```\n",
+ "ClsByModules(\n",
+ " (embedding): Embedding(\n",
+ " (embed): Embedding(8458, 100)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (lstm): LSTM(\n",
+ " (lstm): LSTM(100, 64, num_layers=2, batch_first=True, bidirectional=True)\n",
+ " )\n",
+ " (mlp): MLP(\n",
+ " (hiddens): ModuleList()\n",
+ " (output): Linear(in_features=128, out_features=2, bias=True)\n",
+ " (dropout): Dropout(p=0.5, inplace=False)\n",
+ " )\n",
+ " (loss_fn): CrossEntropyLoss()\n",
+ ")\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "0b25b25c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "from fastNLP.modules.torch import LSTM, MLP\n",
+ "from fastNLP.embeddings.torch import Embedding\n",
+ "\n",
+ "\n",
+ "class ClsByModules(nn.Module):\n",
+ " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
+ " nn.Module.__init__(self)\n",
+ "\n",
+ " self.embedding = Embedding((vocab_size, embedding_dim))\n",
+ " self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n",
+ " self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n",
+ " \n",
+ " self.loss_fn = nn.CrossEntropyLoss()\n",
+ "\n",
+ " def forward(self, words):\n",
+ " output = self.embedding(words)\n",
+ " output, (hidden, cell) = self.lstm(output)\n",
+ " output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n",
+ " return output\n",
+ " \n",
+ " def train_step(self, words, target):\n",
+ " pred = self(words)\n",
+ " return {\"loss\": self.loss_fn(pred, target)}\n",
+ "\n",
+ " def evaluate_step(self, words, target):\n",
+ " pred = self(words)\n",
+ " pred = torch.max(pred, dim=-1)[1]\n",
+ " return {\"pred\": pred, \"target\": target}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4890de5a",
+ "metadata": {},
+ "source": [
+ " 接着,初始化模型`model`实例,同时,使用`torch.optim.AdamW`初始化`optimizer`实例"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "9dbbf50d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = ClsByModules(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n",
+ "\n",
+ "from torch.optim import AdamW\n",
+ "\n",
+ "optimizers = AdamW(params=model.parameters(), lr=5e-5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "054538f5",
+ "metadata": {},
+ "source": [
+ " 最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "7a93432f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import Trainer, Accuracy\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver='torch',\n",
+ " device=0, # 'cuda'\n",
+ " n_epochs=10,\n",
+ " optimizers=optimizers,\n",
+ " train_dataloader=train_dataloader,\n",
+ " evaluate_dataloaders=evaluate_dataloader,\n",
+ " metrics={'acc': Accuracy()}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "31102e0f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[16:20:10] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[16:20:10]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=908530;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=864197;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.525,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 84.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m84.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.54375,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 87.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.54375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m87.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.55,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 88.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.55\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m88.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.625,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 100.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.625\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.65,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 104.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.65\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m104.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.69375,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 111.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.69375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m111.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.675,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 108.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.66875,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 107.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.66875\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m107.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.675,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 108.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.68125,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 109.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.68125\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m109.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.run(num_eval_batch_per_dl=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "8bc4bfb2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'acc#acc': 0.712222, 'total#acc': 900.0, 'correct#acc': 641.0}"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.evaluator.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "07538876",
+ "metadata": {},
+ "source": [
+ " 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "1b52eafd",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "383"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import gc\n",
+ "\n",
+ "del model\n",
+ "del trainer\n",
+ "del dataset\n",
+ "del sst2data\n",
+ "\n",
+ "gc.collect()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d9443213",
+ "metadata": {},
+ "source": [
+ "## 2. fastNLP 中 models 模块的介绍\n",
+ "\n",
+ "### 2.1 示例一:models 实现 CNN 分类\n",
+ "\n",
+ " 本示例使用`fastNLP 1.0`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n",
+ "\n",
+ "数据使用方面,此处沿用在上个示例中展示的`SST-2`数据集,数据加载过程相同且已经执行过了,因此简略\n",
+ "\n",
+ "模型使用方面,如上所述,这里使用**基于卷积神经网络 CNN 的预定义文本分类模型 CNNText**,结构如下所示\n",
+ "\n",
+ " 首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n",
+ "\n",
+ " **感受野为 1 、 3 、 5 的卷积算子变换至 30 维、 40 维、 50 维的卷积特征**,再将三者拼接\n",
+ "\n",
+ " 最终再次通过`dropout`层、线性变换层,映射至二元的输出值,对应两个分类结果上的几率`logits`\n",
+ "\n",
+ "```\n",
+ "CNNText(\n",
+ " (embed): Embedding(\n",
+ " (embed): Embedding(5194, 100)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (conv_pool): ConvMaxpool(\n",
+ " (convs): ModuleList(\n",
+ " (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
+ " (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
+ " (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
+ " )\n",
+ " )\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (fc): Linear(in_features=120, out_features=2, bias=True)\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "对应到代码上,**从 fastNLP.models.torch 路径下导入 CNNText**,初始化`CNNText`和`optimizer`实例\n",
+ "\n",
+ " 注意:初始化`CNNText`时,**二元组参数 embed 、分类数量 num_classes 是必须传入的**,其中\n",
+ "\n",
+ " **embed 表示嵌入层的嵌入抽取矩阵大小**,因此第二个元素对应的是默认隐藏层维度 `100` 维"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "f6e76e2e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP.models.torch import CNNText\n",
+ "\n",
+ "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
+ "\n",
+ "from torch.optim import AdamW\n",
+ "\n",
+ "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0cc5ca10",
+ "metadata": {},
+ "source": [
+ " 最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "50a13ee5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import Trainer, Accuracy\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver='torch',\n",
+ " device=0, # 'cuda'\n",
+ " n_epochs=10,\n",
+ " optimizers=optimizers,\n",
+ " train_dataloader=train_dataloader,\n",
+ " evaluate_dataloaders=evaluate_dataloader,\n",
+ " metrics={'acc': Accuracy()}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "28903a7d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[16:21:57] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[16:21:57]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=813103;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=271516;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.654444,\n",
+ " \"total#acc\": 900.0,\n",
+ " \"correct#acc\": 589.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.654444\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m589.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.767778,\n",
+ " \"total#acc\": 900.0,\n",
+ " \"correct#acc\": 691.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.767778\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m691.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.797778,\n",
+ " \"total#acc\": 900.0,\n",
+ " \"correct#acc\": 718.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.797778\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m718.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.803333,\n",
+ " \"total#acc\": 900.0,\n",
+ " \"correct#acc\": 723.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.803333\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m723.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.807778,\n",
+ " \"total#acc\": 900.0,\n",
+ " \"correct#acc\": 727.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.807778\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m727.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.812222,\n",
+ " \"total#acc\": 900.0,\n",
+ " \"correct#acc\": 731.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.812222\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m731.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.804444,\n",
+ " \"total#acc\": 900.0,\n",
+ " \"correct#acc\": 724.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.804444\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m724.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.811111,\n",
+ " \"total#acc\": 900.0,\n",
+ " \"correct#acc\": 730.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.811111,\n",
+ " \"total#acc\": 900.0,\n",
+ " \"correct#acc\": 730.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.806667,\n",
+ " \"total#acc\": 900.0,\n",
+ " \"correct#acc\": 726.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.806667\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m726.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "f47a6a35",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'acc#acc': 0.806667, 'total#acc': 900.0, 'correct#acc': 726.0}"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.evaluator.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5b5c0446",
+ "metadata": {},
+ "source": [
+ " 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "e9e70f88",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "344"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import gc\n",
+ "\n",
+ "del model\n",
+ "del trainer\n",
+ "\n",
+ "gc.collect()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6aec2a19",
+ "metadata": {},
+ "source": [
+ "### 2.2 示例二:models 实现 BiLSTM 标注\n",
+ "\n",
+ " 通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n",
+ "\n",
+ " 针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n",
+ "\n",
+ " 避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n",
+ "\n",
+ "模型使用方面,如上所述,这里使用**基于双向 LSTM +条件随机场 CRF 的标注模型 BiLSTMCRF**,结构如下所示\n",
+ "\n",
+ " 其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层退学概率、`LSTM`层数可调\n",
+ "\n",
+ "```\n",
+ "BiLSTMCRF(\n",
+ " (embed): Embedding(7590, 100)\n",
+ " (lstm): LSTM(\n",
+ " (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n",
+ " )\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (fc): Linear(in_features=200, out_features=9, bias=True)\n",
+ " (crf): ConditionalRandomField()\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "数据使用方面,此处仍然**使用 datasets 模块中的 load_dataset 函数**,以如下形式,加载`CoNLL-2003`数据集\n",
+ "\n",
+ " 首次下载后会保存至`~.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "03e66686",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "593bc03ed5914953ab94268ff2f01710",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "ner2data = load_dataset('conll2003', 'conll2003')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fc505631",
+ "metadata": {},
+ "source": [
+ "紧接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n",
+ "\n",
+ " 完成数据集格式调整、文本序列化等操作;此处**需要 'words' 、 'seq_len' 、 'target' 三个字段**\n",
+ "\n",
+ "此外,**需要定义 NER 标签到标签序号的映射**(**词汇表 label_vocab**),数据集中标签已经完成了序号映射\n",
+ "\n",
+ " 所以需要人工定义**9 个标签对应之前的 9 个分类目标**;数据集说明中规定,`'O'`表示其他标签\n",
+ "\n",
+ " **后缀 '-PER' 、 '-ORG' 、 '-LOC' 、 '-MISC' 对应人名、组织名、地名、时间等其他命名**\n",
+ "\n",
+ " **前缀 'B-' 表示起始标签、 'I-' 表示终止标签**;例如,`'B-PER'`表示人名实体的起始标签"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "1f88cad4",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/4000 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append('..')\n",
+ "\n",
+ "from fastNLP import DataSet\n",
+ "\n",
+ "dataset = DataSet.from_pandas(ner2data['train'].to_pandas())[:4000]\n",
+ "\n",
+ "dataset.apply_more(lambda ins:{'words': ins['tokens'], 'seq_len': len(ins['tokens']), 'target': ins['ner_tags']}, \n",
+ " progress_bar=\"tqdm\")\n",
+ "dataset.delete_field('tokens')\n",
+ "dataset.delete_field('ner_tags')\n",
+ "dataset.delete_field('pos_tags')\n",
+ "dataset.delete_field('chunk_tags')\n",
+ "dataset.delete_field('id')\n",
+ "\n",
+ "from fastNLP import Vocabulary\n",
+ "\n",
+ "token_vocab = Vocabulary()\n",
+ "token_vocab.from_dataset(dataset, field_name='words')\n",
+ "token_vocab.index_dataset(dataset, field_name='words')\n",
+ "label_vocab = Vocabulary(padding=None, unknown=None)\n",
+ "label_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n",
+ "\n",
+ "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d9889427",
+ "metadata": {},
+ "source": [
+ "然后,同样使用`tutorial-3`中的知识,通过`prepare_torch_dataloader`处理数据集得到`dataloader`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "7802a072",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import prepare_torch_dataloader\n",
+ "\n",
+ "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
+ "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2bc7831b",
+ "metadata": {},
+ "source": [
+ "接着,**从 fastNLP.models.torch 路径下导入 BiLSTMCRF**,初始化`BiLSTMCRF`实例和优化器\n",
+ "\n",
+ " 注意:初始化`BiLSTMCRF`时,和`CNNText`相同,**参数 embed 、 num_classes 是必须传入的**\n",
+ "\n",
+ " 隐藏层维度`hidden_size`默认`100`维,调整`150`维;退学概率默认`0.1`,调整`0.2`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "4e12c09f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP.models.torch import BiLSTMCRF\n",
+ "\n",
+ "model = BiLSTMCRF(embed=(len(token_vocab), 150), num_classes=len(label_vocab), \n",
+ " num_layers=1, hidden_size=150, dropout=0.2)\n",
+ "\n",
+ "from torch.optim import AdamW\n",
+ "\n",
+ "optimizers = AdamW(params=model.parameters(), lr=1e-3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bf30608f",
+ "metadata": {},
+ "source": [
+ "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练\n",
+ "\n",
+ " **使用 SpanFPreRecMetric 作为 NER 的评价标准**,详细请参考接下来的`tutorial-5`\n",
+ "\n",
+ " 同时,**初始化时需要添加 vocabulary 形式的标签与序号之间的映射 tag_vocab**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "cbd6c205",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import Trainer, SpanFPreRecMetric\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver='torch',\n",
+ " device=0, # 'cuda'\n",
+ " n_epochs=10,\n",
+ " optimizers=optimizers,\n",
+ " train_dataloader=train_dataloader,\n",
+ " evaluate_dataloaders=evaluate_dataloader,\n",
+ " metrics={'F1': SpanFPreRecMetric(tag_vocab=label_vocab)}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "0f8eff34",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[16:23:41] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[16:23:41]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=565652;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=224849;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#F1\": 0.169014,\n",
+ " \"pre#F1\": 0.170732,\n",
+ " \"rec#F1\": 0.167331\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.169014\u001b[0m,\n",
+ " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.170732\u001b[0m,\n",
+ " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.167331\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#F1\": 0.361809,\n",
+ " \"pre#F1\": 0.312139,\n",
+ " \"rec#F1\": 0.430279\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.361809\u001b[0m,\n",
+ " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.312139\u001b[0m,\n",
+ " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.430279\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#F1\": 0.525,\n",
+ " \"pre#F1\": 0.475728,\n",
+ " \"rec#F1\": 0.585657\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n",
+ " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.475728\u001b[0m,\n",
+ " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.585657\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#F1\": 0.627306,\n",
+ " \"pre#F1\": 0.584192,\n",
+ " \"rec#F1\": 0.677291\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.627306\u001b[0m,\n",
+ " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.584192\u001b[0m,\n",
+ " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.677291\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#F1\": 0.710937,\n",
+ " \"pre#F1\": 0.697318,\n",
+ " \"rec#F1\": 0.7251\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.710937\u001b[0m,\n",
+ " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.697318\u001b[0m,\n",
+ " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.7251\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#F1\": 0.739563,\n",
+ " \"pre#F1\": 0.738095,\n",
+ " \"rec#F1\": 0.741036\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.739563\u001b[0m,\n",
+ " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.738095\u001b[0m,\n",
+ " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#F1\": 0.748491,\n",
+ " \"pre#F1\": 0.756098,\n",
+ " \"rec#F1\": 0.741036\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.748491\u001b[0m,\n",
+ " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.756098\u001b[0m,\n",
+ " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#F1\": 0.716763,\n",
+ " \"pre#F1\": 0.69403,\n",
+ " \"rec#F1\": 0.741036\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.716763\u001b[0m,\n",
+ " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.69403\u001b[0m,\n",
+ " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#F1\": 0.768293,\n",
+ " \"pre#F1\": 0.784232,\n",
+ " \"rec#F1\": 0.752988\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.768293\u001b[0m,\n",
+ " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.784232\u001b[0m,\n",
+ " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.752988\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"f#F1\": 0.757692,\n",
+ " \"pre#F1\": 0.732342,\n",
+ " \"rec#F1\": 0.784861\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.757692\u001b[0m,\n",
+ " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.732342\u001b[0m,\n",
+ " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.784861\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.run(num_eval_batch_per_dl=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "37871d6b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'f#F1': 0.766798, 'pre#F1': 0.741874, 'rec#F1': 0.793456}"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.evaluator.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "96bae094",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_5.ipynb b/docs/source/tutorials/fastnlp_tutorial_5.ipynb
new file mode 100644
index 00000000..ab759feb
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_5.ipynb
@@ -0,0 +1,1242 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "fdd7ff16",
+ "metadata": {},
+ "source": [
+ "# T5. trainer 和 evaluator 的深入介绍\n",
+ "\n",
+ " 1 fastNLP 中 driver 的补充介绍\n",
+ " \n",
+ " 1.1 trainer 和 driver 的构想 \n",
+ "\n",
+ " 1.2 device 与 多卡训练\n",
+ "\n",
+ " 2 fastNLP 中的更多 metric 类型\n",
+ "\n",
+ " 2.1 预定义的 metric 类型\n",
+ "\n",
+ " 2.2 自定义的 metric 类型\n",
+ "\n",
+ " 3 fastNLP 中 trainer 的补充介绍\n",
+ "\n",
+ " 3.1 trainer 的内部结构"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "08752c5a",
+ "metadata": {
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "## 1. fastNLP 中 driver 的补充介绍\n",
+ "\n",
+ "### 1.1 trainer 和 driver 的构想\n",
+ "\n",
+ "在`fastNLP 1.0`中,模型训练最关键的模块便是**训练模块 trainer 、评测模块 evaluator 、驱动模块 driver**,\n",
+ "\n",
+ " 在`tutorial 0`中,已经简单介绍过上述三个模块:**driver 用来控制训练评测中的 model 的最终运行**\n",
+ "\n",
+ " **evaluator 封装评测的 metric**,**trainer 封装训练的 optimizer**,**也可以包括 evaluator**\n",
+ "\n",
+ "之所以做出上述的划分,其根本目的在于要**达成对于多个 python 学习框架**,**例如 pytorch 、 paddle 、 jittor 的兼容**\n",
+ "\n",
+ " 对于训练环节,其伪代码如下方左边紫色一栏所示,由于**不同框架对模型、损失、张量的定义各有不同**,所以将训练环节\n",
+ "\n",
+ " 划分为**框架无关的循环控制、批量分发部分**,**由 trainer 模块负责**实现,对应的伪代码如下方中间一栏所示\n",
+ "\n",
+ " 以及**随框架不同的模型调用、数值优化部分**,**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n",
+ "\n",
+ "|训练过程|框架无关 对应`Trainer`|框架相关 对应`Driver`\n",
+ "|----|----|----|\n",
+ "| try: | try: | |\n",
+    "| for epoch in 1:n_epochs: | for epoch in 1:n_epochs: |  |\n",
+ "| for step in 1:total_steps: | for step in 1:total_steps: | |\n",
+ "| batch = fetch_batch() | batch = fetch_batch() | |\n",
+ "| loss = model.forward(batch) | | loss = model.forward(batch) |\n",
+ "| loss.backward() | | loss.backward() |\n",
+ "| model.clear_grad() | | model.clear_grad() |\n",
+ "| model.update() | | model.update() |\n",
+ "| if need_save: | if need_save: | |\n",
+ "| model.save() | | model.save() |\n",
+ "| except: | except: | |\n",
+ "| process_exception() | process_exception() | |"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3e55f07b",
+ "metadata": {},
+ "source": [
+ " 对于评测环节,其伪代码如下方左边紫色一栏所示,同样由于不同框架对模型、损失、张量的定义各有不同,所以将评测环节\n",
+ "\n",
+ " 划分为**框架无关的循环控制、分发汇总部分**,**由 evaluator 模块负责**实现,对应的伪代码如下方中间一栏所示\n",
+ "\n",
+ " 以及**随框架不同的模型调用、评测计算部分**,同样**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n",
+ "\n",
+ "|评测过程|框架无关 对应`Evaluator`|框架相关 对应`Driver`\n",
+ "|----|----|----|\n",
+ "| try: | try: | |\n",
+ "| model.set_eval() | model.set_eval() | |\n",
+ "| for step in 1:total_steps: | for step in 1:total_steps: | |\n",
+ "| batch = fetch_batch() | batch = fetch_batch() | |\n",
+ "| outputs = model.evaluate(batch) | | outputs = model.evaluate(batch) |\n",
+ "| metric.compute(batch, outputs) | | metric.compute(batch, outputs) |\n",
+ "| results = metric.get_metric() | results = metric.get_metric() | |\n",
+ "| except: | except: | |\n",
+ "| process_exception() | process_exception() | |"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "94ba11c6",
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "source": [
+ "由此,从程序员的角度,`fastNLP v1.0` **通过一个 driver 让基于 pytorch 、 paddle 、 jittor 、 oneflow 框架的模型**\n",
+ "\n",
+ " **都能在相同的 trainer 和 evaluator 上运行**,这也**是 fastNLP v1.0 相比于之前版本的一大亮点**\n",
+ "\n",
+ " 而从`driver`的角度,`fastNLP v1.0`通过定义一个`driver`基类,**将所有张量转化为 numpy.tensor**\n",
+ "\n",
+ " 并由此泛化出`torch_driver`、`paddle_driver`、`jittor_driver`三个子类,从而实现了\n",
+ "\n",
+ " 对`pytorch`、`paddle`、`jittor`的兼容,有关后两者的实践请参考接下来的`tutorial-6`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ab1cea7d",
+ "metadata": {},
+ "source": [
+ "### 1.2 device 与 多卡训练\n",
+ "\n",
+ "**fastNLP v1.0 支持多卡训练**,实现方法则是**通过将 trainer 中的 device 设置为对应显卡的序号列表**\n",
+ "\n",
+ " 由单卡切换成多卡,无论是数据、模型还是评测都会面临一定的调整,`fastNLP v1.0`保证:\n",
+ "\n",
+ " 数据拆分时,不同卡之间相互协调,所有数据都可以被训练,且不会使用到相同的数据\n",
+ "\n",
+ " 模型训练时,模型之间需要交换梯度;评测计算时,每张卡先各自计算,再汇总结果\n",
+ "\n",
+ " 例如,在评测计算运行`get_metric`函数时,`fastNLP v1.0`将自动按照`self.right`和`self.total`\n",
+ "\n",
+ " 指定的 **aggregate_method 方法**,默认为`sum`,将每张卡上结果汇总起来,因此最终\n",
+ "\n",
+ " 在调用`get_metric`方法时,`Accuracy`类能够返回全部的统计结果,代码如下\n",
+ " \n",
+ "```python\n",
+ "trainer = Trainer(\n",
+ " model=model, # model 基于 pytorch 实现 \n",
+ " train_dataloader=train_dataloader,\n",
+ " optimizers=optimizer,\n",
+ " ...\n",
+ " driver='torch', # driver 使用 torch_driver \n",
+ " device=[0, 1], # gpu 选择 cuda:0 + cuda:1\n",
+ " ...\n",
+ " evaluate_dataloaders=evaluate_dataloader,\n",
+ " metrics={'acc': Accuracy()},\n",
+ " ...\n",
+ " )\n",
+ "\n",
+ "class Accuracy(Metric):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.register_element(name='total', value=0, aggregate_method='sum')\n",
+ " self.register_element(name='right', value=0, aggregate_method='sum')\n",
+ "```\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e2e0a210",
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "source": [
+ "注:`fastNLP v1.0`中要求`jupyter`不能多卡,仅能单卡,故在所有`tutorial`中均不作相关演示"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8d19220c",
+ "metadata": {},
+ "source": [
+ "## 2. fastNLP 中的更多 metric 类型\n",
+ "\n",
+ "### 2.1 预定义的 metric 类型\n",
+ "\n",
+ "在`fastNLP 1.0`中,除了前几篇`tutorial`中经常见到的**正确率 Accuracy**,还有其他**预定义的评测标准 metric**\n",
+ "\n",
+ " 包括**所有 metric 的基类 Metric**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n",
+ "\n",
+ " **适用于分类语境下的 F1 值 ClassifyFPreRecMetric**(其中也包括召回率`Pre`、精确率`Rec`\n",
+ "\n",
+ " **适用于抽取语境下的 F1 值 SpanFPreRecMetric**;相关基本信息内容见下表,之后是详细分析\n",
+ "\n",
+ "代码名称|简要介绍|代码路径\n",
+ "----|----|----|\n",
+ " `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n",
+ " `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n",
+ " `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n",
+ " `ClassifyFPreRecMetric` | 召回率、精确率、F1值,适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n",
+ " `SpanFPreRecMetric` | 召回率、精确率、F1值,适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fdc083a3",
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "source": [
+ " 如`tutorial-0`中所述,所有的`metric`都包含`get_metric`和`update`函数,其中\n",
+ "\n",
+ " **update 函数更新单个 batch 的统计量**,**get_metric 函数返回最终结果**,并打印显示\n",
+ "\n",
+ "\n",
+ "### 2.1.1 Accuracy 与 TransformersAccuracy\n",
+ "\n",
+    "`Accuracy`,正确率,预测正确的数据`right_num`在总数据`total_num`中的占比(公式就不用列了\n",
+ "\n",
+ " `get_metric`函数打印格式为 **{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}**\n",
+ "\n",
+ " 一般在初始化时不需要传参,`fastNLP`会根据`update`函数的传入参数确定对应后台框架`backend`\n",
+ "\n",
+ " **update 函数的参数包括 pred 、 target 、 seq_len**,**后者用来标记批次中每笔数据的长度**\n",
+ "\n",
+ "`TransformersAccuracy`,继承自`Accuracy`,只是为了兼容`Transformers`框架中相关模型\n",
+ "\n",
+ " 在`update`函数中,将`Transformers`框架输出的`attention_mask`参数转化为`seq_len`参数\n",
+ "\n",
+ "\n",
+ "### 2.1.2 ClassifyFPreRecMetric 与 SpanFPreRecMetric\n",
+ "\n",
+ "`ClassifyFPreRecMetric`,分类评价,`SpanFPreRecMetric`,抽取评价,后者在`tutorial-4`中已出现\n",
+ "\n",
+ " 两者的相同之处在于:**第一**,**都包括召回率/查全率 ec**、**精确率/查准率 Pre**、**F1 值**这三个指标\n",
+ "\n",
+ " `get_metric`函数打印格式为 **{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}**\n",
+ "\n",
+ " 三者的计算公式如下,其中`beta`默认为`1`,即`F1`值是召回率`Rec`和精确率`Pre`的调和平均数\n",
+ "\n",
+ "$$\\text{召回率}\\ Rec=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有本来是正例的数量}}\\qquad \\text{精确率}\\ Pre=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有预测为正例的数量}}$$\n",
+ "\n",
+ "$$F_{beta} = \\frac{(1 + {beta}^{2})*(Pre*Rec)}{({beta}^{2}*Pre + Rec)}$$\n",
+ "\n",
+ " **第二**,可以通过参数`only_gross`为`False`,要求返回所有类别的`Rec-Pre-F1`,同时`F1`值又根据参数`f_type`又分为\n",
+ "\n",
+ " **micro F1**(**直接统计所有类别的 Rec-Pre-F1**)、**macro F1**(**统计各类别的 Rec-Pre-F1 再算术平均**)\n",
+ "\n",
+ " **第三**,两者在初始化时还可以**传入基于 fastNLP.Vocabulary 的 tag_vocab 参数记录数据集中的标签序号**\n",
+ "\n",
+ " **与标签名称之间的映射**,通过字符串列表`ignore_labels`参数,指定若干标签不用于`Rec-Pre-F1`的计算\n",
+ "\n",
+ "两者的不同之处在于:`ClassifyFPreRecMetric`针对简单的分类问题,每个分类标签之间彼此独立,不构成标签对\n",
+ "\n",
+ " **SpanFPreRecMetric 针对更复杂的抽取问题**,**规定标签 B-xx 和 I-xx 或 B-xx 和 E-xx 构成标签对**\n",
+ "\n",
+ " 在计算`Rec-Pre-F1`时,`ClassifyFPreRecMetric`只需要考虑标签本身是否正确这就足够了,但是\n",
+ "\n",
+ " 对于`SpanFPreRecMetric`,需要保证**标签符合规则且覆盖的区间与正确结果重合才算正确**\n",
+ "\n",
+ " 因此回到`tutorial-4`中`CoNLL-2003`的`NER`任务,如果评测方法选择`ClassifyFPreRecMetric`\n",
+ "\n",
+ " 或者`Accuracy`,会发现虽然评测结果显示很高,这是因为选择的评测方法要求太低\n",
+ "\n",
+ " 最后通过`CoNLL-2003`的词性标注`POS`任务简单演示下`ClassifyFPreRecMetric`相关的使用\n",
+ "\n",
+ "```python\n",
+ "from fastNLP import Vocabulary\n",
+ "from fastNLP import ClassifyFPreRecMetric\n",
+ "\n",
+ "tag_vocab = Vocabulary(padding=None, unknown=None) # 记录序号与标签之间的映射\n",
+ "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n",
+ " 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n",
+ " 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n",
+ " 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n",
+ " 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP+', 'WRB', ]) # CoNLL-2003 中的 pos_tags\n",
+ "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n",
+ "\n",
+ "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab, \n",
+ " ignore_labels=ignore_labels, # 表示评测/优化中不考虑上述标签的正误/损失\n",
+ " only_gross=True, # 默认为 True 表示输出所有类别的综合统计结果\n",
+ " f_type='micro') # 默认为 'micro' 表示统计所有类别的 Rec-Pre-F1\n",
+ "metrics = {'F1': FPreRec}\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8a22f522",
+ "metadata": {},
+ "source": [
+ "### 2.2 自定义的 metric 类型\n",
+ "\n",
+ "如上文所述,`Metric`作为所有`metric`的基类,`Accuracy`等都是其子类,同样地,对于**自定义的 metric 类型**\n",
+ "\n",
+ " 也**需要继承自 Metric 类**,同时**内部自定义好 __init__ 、 update 和 get_metric 函数**\n",
+ "\n",
+ " 在`__init__`函数中,根据需求定义评测时需要用到的变量,此处沿用`Accuracy`中的`total_num`和`right_num`\n",
+ "\n",
+ " 在`update`函数中,根据需求定义评测变量的更新方式,需要注意的是如`tutorial-0`中所述,**update`的参数名**\n",
+ "\n",
+ " **需要待评估模型在 evaluate_step 中的输出名称一致**,由此**和数据集中对应字段名称一致**,即**参数匹配**\n",
+ "\n",
+ " 在`fastNLP v1.0`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n",
+ "\n",
+ " 此处仍然沿用,因为接下来会需要使用`fastNLP`函数的与定义模型,其输入参数格式即使如此\n",
+ "\n",
+ " 在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n",
+ "\n",
+ " 其中,字串`'prefix'`表示该`metric`的名称,会对应显示到`trainer`的`progress bar`中\n",
+ "\n",
+ "根据上述要求,这里简单定义了一个名为`MyMetric`的评测模块,用于分类问题的评测,以此展开一个实例展示"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "08a872e9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append('..')\n",
+ "\n",
+ "from fastNLP import Metric\n",
+ "\n",
+ "class MyMetric(Metric):\n",
+ "\n",
+ " def __init__(self):\n",
+ " Metric.__init__(self)\n",
+ " self.total_num = 0\n",
+ " self.right_num = 0\n",
+ "\n",
+ " def update(self, pred, target):\n",
+ " self.total_num += target.size(0)\n",
+ " self.right_num += target.eq(pred).sum().item()\n",
+ "\n",
+ " def get_metric(self, reset=True):\n",
+ " acc = self.right_num / self.total_num\n",
+ " if reset:\n",
+ " self.total_num = 0\n",
+ " self.right_num = 0\n",
+ " return {'prefix': acc}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0155f447",
+ "metadata": {},
+ "source": [
+ " 数据使用方面,此处仍然使用`datasets`模块中的`load_dataset`函数,加载`SST-2`二分类数据集"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "5ad81ac7",
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ef923b90b19847f4916cccda5d33fc36",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "sst2data = load_dataset('glue', 'sst2')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e9d81760",
+ "metadata": {},
+ "source": [
+ " 在数据预处理中,需要注意的是,这里原本应该根据`metric`和`model`的输入参数格式,调整\n",
+ "\n",
+ " 数据集中表示预测目标的字段,调整为`target`,在后文中会揭晓为什么,以及如何补救"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "cfb28b1b",
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/6000 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from fastNLP import DataSet\n",
+ "\n",
+ "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
+ "\n",
+ "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split()}, progress_bar=\"tqdm\")\n",
+ "dataset.delete_field('sentence')\n",
+ "dataset.delete_field('idx')\n",
+ "\n",
+ "from fastNLP import Vocabulary\n",
+ "\n",
+ "vocab = Vocabulary()\n",
+ "vocab.from_dataset(dataset, field_name='words')\n",
+ "vocab.index_dataset(dataset, field_name='words')\n",
+ "\n",
+ "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n",
+ "\n",
+ "from fastNLP import prepare_torch_dataloader\n",
+ "\n",
+ "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
+ "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "af3f8c63",
+ "metadata": {},
+ "source": [
+ " 模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "2fd210c5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP.models.torch import CNNText\n",
+ "\n",
+ "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
+ "\n",
+ "from torch.optim import AdamW\n",
+ "\n",
+ "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6e723b87",
+ "metadata": {},
+ "source": [
+ "## 3. fastNLP 中 trainer 的补充介绍\n",
+ "\n",
+ "### 3.1 trainer 的内部结构\n",
+ "\n",
+ "在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经展示了\n",
+ "\n",
+ " 很多`trainer`的使用案例,这里通过表格,相对完整地介绍`trainer`模块的属性和初始化参数(标粗为必选参数\n",
+ "\n",
+ "\n",
+ "名称|参数|属性|功能|内容\n",
+ "----|----|----|----|----|\n",
+ "| **model** | √ | √ | 指定`trainer`控制的模型 | 视框架而定,如`torch.nn.Module` |\n",
+ "| `device` | √ | | 指定`trainer`运行的卡位 | 例如`'cpu'`、`'cuda'`、`0`、`[0, 1]`等 |\n",
+ "| | | √ | 记录`trainer`运行的卡位 | `Device`类型,在初始化阶段生成 |\n",
+ "| **driver** | √ | | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n",
+ "| | | √ | 记录`trainer`驱动的框架 | `Driver`类型,在初始化阶段生成 |\n",
+ "| `n_epochs` | √ | - | 指定`trainer`迭代的轮数 | 默认`20`,记录在`driver.n_epochs`中 |\n",
+ "| **optimizers** | √ | √ | 指定`trainer`优化的方法 | 视框架而定,如`torch.optim.Adam` |\n",
+ "| `metrics` | √ | √ | 指定`trainer`评测的方法 | 字典类型,如`{'acc': Metric()}` |\n",
+ "| `evaluator` | | √ | 内置的`trainer`评测模块 | `Evaluator`类型,在初始化阶段生成 |\n",
+ "| `input_mapping` | √ | √ | 调整`dataloader`的参数不匹配 | 函数类型,输出字典匹配`forward`输入参数 |\n",
+ "| `output_mapping` | √ | √ | 调整`forward`输出的参数不匹配 | 函数类型,输出字典匹配`xx_step`输入参数 |\n",
+ "| **train_dataloader** | √ | √ | 指定`trainer`训练的数据 | `DataLoader`类型,生成视框架而定 |\n",
+ "| `evaluate_dataloaders` | √ | √ | 指定`trainer`评测的数据 | `DataLoader`类型,生成视框架而定 |\n",
+ "| `train_fn` | √ | √ | 指定`trainer`获取某个批次的损失值 | 函数类型,默认为`model.train_step` |\n",
+ "| `evaluate_fn` | √ | √ | 指定`trainer`获取某个批次的评估量 | 函数类型,默认为`model.evaluate_step` |\n",
+ "| `batch_step_fn` | √ | √ | 指定`trainer`训练时前向传输一个批次的方式 | 函数类型,默认为`TrainBatchLoop.batch_step_fn` |\n",
+ "| `evaluate_batch_step_fn` | √ | √ | 指定`trainer`评测时前向传输一个批次的方式 | 函数类型,默认为`EvaluateBatchLoop.batch_step_fn` |\n",
+ "| `accumulation_steps` | √ | √ | 指定`trainer`训练时反向传播的频率 | 默认为`1`,即每个批次都反向传播 |\n",
+ "| `evaluate_every` | √ | √ | 指定`evaluator`评测时计算的频率 | 默认`-1`表示每个循环一次,相反`1`表示每个批次一次 |\n",
+ "| `progress_bar` | √ | √ | 指定`trainer`训练和评测时的进度条样式 | 包括`'auto'`、`'tqdm'`、`'raw'`、`'rich'` |\n",
+ "| `callbacks` | √ | | 指定`trainer`训练时需要触发的函数 | `Callback`列表类型,详见`tutorial-7` |\n",
+ "| `callback_manager` | | √ | 记录与管理`callbacks`相关内容 | `CallbackManager`类型,详见`tutorial-7` |\n",
+ "| `monitor` | √ | √ | 辅助部分的`callbacks`相关内容 | 字符串/函数类型,详见`tutorial-7` |\n",
+ "| `marker` | √ | √ | 标记`trainer`实例,辅助`callbacks`相关内容 | 字符串型,详见`tutorial-7` |\n",
+ "| `trainer_state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `TrainerState`类型,详见`tutorial-7` |\n",
+ "| `state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `State`类型,详见`tutorial-7` |\n",
+ "| `fp16` | √ | √ | 指定`trainer`是否进行混合精度训练 | 布尔类型,默认`False` |\n",
+ "\n",
+    "其中,**input_mapping 和 output_mapping** 定义形式如下:输入字典形式的数据,根据参数匹配要求调整数据格式,这里就回应了前文未在数据集预处理时调整格式的问题,**总之参数匹配的要求一定要满足**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "de96c1d1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def input_mapping(data):\n",
+ " data['target'] = data['label']\n",
+ " return data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2fc8b9f3",
+ "metadata": {},
+ "source": [
+ " 而`trainer`模块的基础方法列表如下,相关进阶操作,如`on`系列函数、`callback`控制,请参考后续的`tutorial-7`\n",
+ "\n",
+ "|名称|功能|主要参数|\n",
+ "|----|----|----|\n",
+ "| `run` | 控制`trainer`中模型的训练和评测 | 详见后文 |\n",
+ "| `train_step` | 实现`trainer`训练中一个批数据的前向传播过程 | 输入`batch` |\n",
+ "| `backward` | 实现`trainer`训练中一次损失的反向传播过程 | 输入`output` |\n",
+ "| `zero_grad` | 实现`trainer`训练中`optimizers`的梯度置零 | 无输入 |\n",
+ "| `step` | 实现`trainer`训练中`optimizers`的参数更新 | 无输入 |\n",
+ "| `epoch_evaluate` | 实现`trainer`训练中每个循环的评测,实际是否执行取决于评测频率 | 无输入 |\n",
+ "| `step_evaluate` | 实现`trainer`训练中每个批次的评测,实际是否执行取决于评测频率 | 无输入 |\n",
+ "| `save_model` | 保存`trainer`中的模型参数/状态字典至`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`False` |\n",
+ "| `load_model` | 加载`trainer`中的模型参数/状态字典自`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只加载状态字典,默认`True` |\n",
+ "| `save_checkpoint` | 保存`trainer`中模型参数/状态字典 以及 `callback`、`sampler` 和`optimizer`的状态至`fastnlp_model/checkpoint.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`True` |\n",
+ "| `load_checkpoint` | 加载`trainer`中模型参数/状态字典 以及 `callback`、`sampler` 和`optimizer`的状态自`fastnlp_model/checkpoint.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`True` `resume_training`指明是否只精确到上次训练的批量,默认`True` |\n",
+ "| `add_callback_fn` | 在`trainer`初始化后添加`callback`函数 | 输入`event`指明回调时机,`fn`指明回调函数 |\n",
+ "| `on` | 函数修饰器,将一个函数转变为`callback`函数 | 详见`tutorial-7` |\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1e21df35",
+ "metadata": {},
+ "source": [
+ "紧接着,初始化`trainer`实例,继续完成`SST-2`分类,其中`metrics`输入的键值对,字串`'suffix'`和之前定义的\n",
+ "\n",
+ " 字串`'prefix'`将拼接在一起显示到`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "926a9c50",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import Trainer\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver='torch',\n",
+ " device=0, # 'cuda'\n",
+ " n_epochs=10,\n",
+ " optimizers=optimizers,\n",
+ " input_mapping=input_mapping,\n",
+ " train_dataloader=train_dataloader,\n",
+ " evaluate_dataloaders=evaluate_dataloader,\n",
+ " metrics={'suffix': MyMetric()}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1b2e8b7",
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "source": [
+ "最后就是`run`函数的使用,关于其参数,这里也以表格形式列出,由此就解答了`num_eval_batch_per_dl=10`的含义\n",
+ "\n",
+ "|名称|功能|默认值|\n",
+ "|----|----|----|\n",
+ "| `num_train_batch_per_epoch` | 指定`trainer`训练时,每个循环计算批量数目 | 整数类型,默认`-1`,表示训练时,每个循环计算所有批量 |\n",
+ "| `num_eval_batch_per_dl` | 指定`trainer`评测时,每个循环计算批量数目 | 整数类型,默认`-1`,表示评测时,每个循环计算所有批量 |\n",
+ "| `num_eval_sanity_batch` | 指定`trainer`训练开始前,试探性评测批量数目 | 整数类型,默认`2`,表示训练开始前评估两个批量 |\n",
+ "| `resume_from` | 指定`trainer`恢复状态的路径,需要是文件夹 | 字符串型,默认`None`,使用可参考`CheckpointCallback` |\n",
+ "| `resume_training` | 指定`trainer`恢复状态的程度 | 布尔类型,默认`True`恢复所有状态,`False`仅恢复`model`和`optimizers`状态 |"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "43be274f",
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[09:30:35] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"prefix#suffix\": 0.6875\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"prefix#suffix\": 0.8125\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"prefix#suffix\": 0.80625\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"prefix#suffix\": 0.825\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"prefix#suffix\": 0.8125\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"prefix#suffix\": 0.80625\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"prefix#suffix\": 0.80625\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"prefix#suffix\": 0.8\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"prefix#suffix\": 0.80625\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"prefix#suffix\": 0.80625\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.run(num_eval_batch_per_dl=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f1abfa0a",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ },
+ "pycharm": {
+ "stem_cell": {
+ "cell_type": "raw",
+ "metadata": {
+ "collapsed": false
+ },
+ "source": []
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_6.ipynb b/docs/source/tutorials/fastnlp_tutorial_6.ipynb
new file mode 100644
index 00000000..63f7481e
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_6.ipynb
@@ -0,0 +1,1646 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "fdd7ff16",
+ "metadata": {},
+ "source": [
+ "# T6. fastNLP 与 paddle 或 jittor 的结合\n",
+ "\n",
+ " 1 fastNLP 结合 paddle 训练模型\n",
+ " \n",
+ " 1.1 关于 paddle 的简单介绍\n",
+ "\n",
+ " 1.2 使用 paddle 搭建并训练模型\n",
+ "\n",
+ " 2 fastNLP 结合 jittor 训练模型\n",
+ "\n",
+ " 2.1 关于 jittor 的简单介绍\n",
+ "\n",
+ " 2.2 使用 jittor 搭建并训练模型\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "08752c5a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6b13d42c39ba455eb370bf2caaa3a264",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "sst2data = load_dataset('glue', 'sst2')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "7e8cc210",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[38;5;2m[i 0604 21:01:38.510813 72 log.cc:351] Load log_sync: 1\u001b[m\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing: 0%| | 0/6000 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " True\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append('..')\n",
+ "\n",
+ "from fastNLP import DataSet\n",
+ "\n",
+ "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
+ "\n",
+ "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n",
+ " progress_bar=\"tqdm\")\n",
+ "dataset.delete_field('sentence')\n",
+ "dataset.delete_field('label')\n",
+ "dataset.delete_field('idx')\n",
+ "\n",
+ "from fastNLP import Vocabulary\n",
+ "\n",
+ "vocab = Vocabulary()\n",
+ "vocab.from_dataset(dataset, field_name='words')\n",
+ "vocab.index_dataset(dataset, field_name='words')\n",
+ "\n",
+ "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n",
+ "print(type(train_dataset), isinstance(train_dataset, DataSet))\n",
+ "\n",
+ "from fastNLP.io import DataBundle\n",
+ "\n",
+ "data_bundle = DataBundle(datasets={'train': train_dataset, 'dev': evaluate_dataset})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "57a3272f",
+ "metadata": {},
+ "source": [
+ "## 1. fastNLP 结合 paddle 训练模型\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "e31b3198",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import paddle\n",
+ "import paddle.nn as nn\n",
+ "import paddle.nn.functional as F\n",
+ "\n",
+ "\n",
+ "class ClsByPaddle(nn.Layer):\n",
+ " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, dropout=0.5):\n",
+ " nn.Layer.__init__(self)\n",
+ " self.hidden_dim = hidden_dim\n",
+ "\n",
+ " self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n",
+ " \n",
+ " self.conv1 = nn.Sequential(nn.Conv1D(embedding_dim, 30, 1, padding=0), nn.ReLU())\n",
+ " self.conv2 = nn.Sequential(nn.Conv1D(embedding_dim, 40, 3, padding=1), nn.ReLU())\n",
+ " self.conv3 = nn.Sequential(nn.Conv1D(embedding_dim, 50, 5, padding=2), nn.ReLU())\n",
+ "\n",
+ " self.mlp = nn.Sequential(('dropout', nn.Dropout(p=dropout)),\n",
+ " ('linear_1', nn.Linear(120, hidden_dim)),\n",
+ " ('activate', nn.ReLU()),\n",
+ " ('linear_2', nn.Linear(hidden_dim, output_dim)))\n",
+ " \n",
+ " self.loss_fn = nn.MSELoss()\n",
+ "\n",
+ " def forward(self, words):\n",
+ " output = self.embedding(words).transpose([0, 2, 1])\n",
+ " conv1, conv2, conv3 = self.conv1(output), self.conv2(output), self.conv3(output)\n",
+ "\n",
+ " pool1 = F.max_pool1d(conv1, conv1.shape[-1]).squeeze(2)\n",
+ " pool2 = F.max_pool1d(conv2, conv2.shape[-1]).squeeze(2)\n",
+ " pool3 = F.max_pool1d(conv3, conv3.shape[-1]).squeeze(2)\n",
+ "\n",
+ " pool = paddle.concat([pool1, pool2, pool3], axis=1)\n",
+ " output = self.mlp(pool)\n",
+ " return output\n",
+ " \n",
+ " def train_step(self, words, target):\n",
+ " pred = self(words)\n",
+ " target = paddle.stack((1 - target, target), axis=1).cast(pred.dtype)\n",
+ " return {'loss': self.loss_fn(pred, target)}\n",
+ "\n",
+ " def evaluate_step(self, words, target):\n",
+ " pred = self(words)\n",
+ " pred = paddle.argmax(pred, axis=-1)\n",
+ " return {'pred': pred, 'target': target}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "c63b030f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "W0604 21:02:25.453869 19014 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 11.1, Runtime API Version: 10.2\n",
+ "W0604 21:02:26.061690 19014 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "ClsByPaddle(\n",
+ " (embedding): Embedding(8458, 100, sparse=False)\n",
+ " (conv1): Sequential(\n",
+ " (0): Conv1D(100, 30, kernel_size=[1], data_format=NCL)\n",
+ " (1): ReLU()\n",
+ " )\n",
+ " (conv2): Sequential(\n",
+ " (0): Conv1D(100, 40, kernel_size=[3], padding=1, data_format=NCL)\n",
+ " (1): ReLU()\n",
+ " )\n",
+ " (conv3): Sequential(\n",
+ " (0): Conv1D(100, 50, kernel_size=[5], padding=2, data_format=NCL)\n",
+ " (1): ReLU()\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (dropout): Dropout(p=0.5, axis=None, mode=upscale_in_train)\n",
+ " (linear_1): Linear(in_features=120, out_features=64, dtype=float32)\n",
+ " (activate): ReLU()\n",
+ " (linear_2): Linear(in_features=64, out_features=2, dtype=float32)\n",
+ " )\n",
+ " (loss_fn): MSELoss()\n",
+ ")"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model = ClsByPaddle(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n",
+ "\n",
+ "model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "2997c0aa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from paddle.optimizer import AdamW\n",
+ "\n",
+ "optimizers = AdamW(parameters=model.parameters(), learning_rate=5e-4)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "ead35fb8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import prepare_paddle_dataloader\n",
+ "\n",
+ "train_dataloader = prepare_paddle_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
+ "evaluate_dataloader = prepare_paddle_dataloader(evaluate_dataset, batch_size=16)\n",
+ "\n",
+ "# dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "25e8da83",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import Trainer, Accuracy\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver='paddle',\n",
+ " device='gpu', # 'cpu', 'gpu', 'gpu:x'\n",
+ " n_epochs=10,\n",
+ " optimizers=optimizers,\n",
+ " train_dataloader=train_dataloader, # dl_bundle['train'],\n",
+ " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'], \n",
+ " metrics={'acc': Accuracy()}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "d63c5d74",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[21:03:08] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[21:03:08]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=894986;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=567751;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+ "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+ ".get_parent()\n",
+ " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n",
+ "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n",
+ "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n",
+ "safe. \n",
+ "Deprecated in NumPy 1.20; for more details and guidance: \n",
+ "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
+ " if data.dtype == np.object:\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n",
+ "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n",
+ "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n",
+ "safe. \n",
+ "Deprecated in NumPy 1.20; for more details and guidance: \n",
+ "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
+ " if data.dtype == np.object:\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.78125,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 125.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.7875,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 126.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.8,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 128.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.79375,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 127.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.81875,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 131.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.8,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 128.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.80625,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 129.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.79375,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 127.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.7875,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 126.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.8,\n",
+ " \"total#acc\": 160.0,\n",
+ " \"correct#acc\": 128.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.run(num_eval_batch_per_dl=10) "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cb9a0b3c",
+ "metadata": {},
+ "source": [
+ "## 2. fastNLP 结合 jittor 训练模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "c600191d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jittor\n",
+ "import jittor.nn as nn\n",
+ "\n",
+ "from jittor import Module\n",
+ "\n",
+ "\n",
+ "class ClsByJittor(Module):\n",
+ " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
+ " Module.__init__(self)\n",
+ " self.hidden_dim = hidden_dim\n",
+ "\n",
+ " self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n",
+ " self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True, # 默认 batch_first=False\n",
+ " num_layers=num_layers, bidirectional=True, dropout=dropout)\n",
+ " self.mlp = nn.Sequential([nn.Dropout(p=dropout),\n",
+ " nn.Linear(hidden_dim * 2, hidden_dim * 2),\n",
+ " nn.ReLU(),\n",
+ " nn.Linear(hidden_dim * 2, output_dim),\n",
+ " nn.Sigmoid(),])\n",
+ "\n",
+ " self.loss_fn = nn.MSELoss()\n",
+ "\n",
+ " def execute(self, words):\n",
+ " output = self.embedding(words)\n",
+ " output, (hidden, cell) = self.lstm(output)\n",
+ " output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), dim=1))\n",
+ " return output\n",
+ " \n",
+ " def train_step(self, words, target):\n",
+ " pred = self(words)\n",
+ " target = jittor.stack((1 - target, target), dim=1)\n",
+ " return {'loss': self.loss_fn(pred, target)}\n",
+ "\n",
+ " def evaluate_step(self, words, target):\n",
+ " pred = self(words)\n",
+ " pred = jittor.argmax(pred, dim=-1)[0]\n",
+ " return {'pred': pred, 'target': target}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "a94ed8c4",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ClsByJittor(\n",
+ " embedding: Embedding(8458, 100)\n",
+ " lstm: LSTM(100, 64, 2, bias=True, batch_first=True, dropout=0.5, bidirectional=True, proj_size=0)\n",
+ " mlp: Sequential(\n",
+ " 0: Dropout(0.5, is_train=False)\n",
+ " 1: Linear(128, 128, float32[128,], None)\n",
+ " 2: relu()\n",
+ " 3: Linear(128, 2, float32[2,], None)\n",
+ " 4: Sigmoid()\n",
+ " )\n",
+ " loss_fn: MSELoss(mean)\n",
+ ")"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n",
+ "\n",
+ "model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "6d15ebc1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from jittor.optim import AdamW\n",
+ "\n",
+ "optimizers = AdamW(params=model.parameters(), lr=5e-3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "95d8d09e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import prepare_jittor_dataloader\n",
+ "\n",
+ "train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
+ "evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n",
+ "\n",
+ "# dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "917eab81",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP import Trainer, Accuracy\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver='jittor',\n",
+ " device='gpu', # 'cpu', 'gpu', 'cuda'\n",
+ " n_epochs=10,\n",
+ " optimizers=optimizers,\n",
+ " train_dataloader=train_dataloader, # dl_bundle['train'],\n",
+ " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'],\n",
+ " metrics={'acc': Accuracy()}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "f7c4ac5a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[21:05:51] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[21:05:51]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=69759;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=202322;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Compiling Operators(5/6) used: 8.31s eta: 1.66s 6/6) used: 9.33s eta: 0s \n",
+ "\n",
+ "Compiling Operators(31/31) used: 7.31s eta: 0s \n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.61875,\n",
+ " \"total#acc\": 160,\n",
+ " \"correct#acc\": 99\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.61875\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m99\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.7,\n",
+ " \"total#acc\": 160,\n",
+ " \"correct#acc\": 112\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m112\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.725,\n",
+ " \"total#acc\": 160,\n",
+ " \"correct#acc\": 116\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.725\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m116\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.74375,\n",
+ " \"total#acc\": 160,\n",
+ " \"correct#acc\": 119\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.75625,\n",
+ " \"total#acc\": 160,\n",
+ " \"correct#acc\": 121\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.75625,\n",
+ " \"total#acc\": 160,\n",
+ " \"correct#acc\": 121\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.73125,\n",
+ " \"total#acc\": 160,\n",
+ " \"correct#acc\": 117\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.73125\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m117\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.7625,\n",
+ " \"total#acc\": 160,\n",
+ " \"correct#acc\": 122\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.74375,\n",
+ " \"total#acc\": 160,\n",
+ " \"correct#acc\": 119\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.7625,\n",
+ " \"total#acc\": 160,\n",
+ " \"correct#acc\": 122\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.run(num_eval_batch_per_dl=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3df5f425",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_e1.ipynb b/docs/source/tutorials/fastnlp_tutorial_e1.ipynb
new file mode 100644
index 00000000..af8e60a0
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_e1.ipynb
@@ -0,0 +1,1280 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " 从这篇开始,我们将开启 **fastNLP v1.0 tutorial 的 example 系列**,在接下来的\n",
+ "\n",
+ " 每篇`tutorial`里,我们将会介绍`fastNLP v1.0`在自然语言处理任务上的应用实例"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[, , ]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from pygments.plugin import find_plugin_lexers\n",
+ "print(list(find_plugin_lexers()))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# E1. 使用 Bert + fine-tuning 完成 SST-2 分类\n",
+ "\n",
+ " 1 基础介绍:`GLUE`通用语言理解评估、`SST-2`文本情感二分类数据集 \n",
+ "\n",
+ " 2 准备工作:加载`tokenizer`、预处理`dataset`、`dataloader`使用\n",
+ "\n",
+ " 3 模型训练:加载`distilbert-base`、`fastNLP`参数匹配、`fine-tuning`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4.18.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.optim import AdamW\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "\n",
+ "import transformers\n",
+ "from transformers import AutoTokenizer\n",
+ "from transformers import AutoModelForSequenceClassification\n",
+ "\n",
+ "import sys\n",
+ "sys.path.append('..')\n",
+ "\n",
+ "import fastNLP\n",
+ "from fastNLP import Trainer\n",
+ "from fastNLP import Accuracy\n",
+ "\n",
+ "print(transformers.__version__)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. 基础介绍:GLUE 通用语言理解评估、SST-2 文本情感二分类数据集\n",
+ "\n",
+ " 本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`fine-tuning`方式\n",
+ "\n",
+ " 调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST-2`\n",
+ "\n",
+ "**GLUE**,**全称 General Language Understanding Evaluation**,**通用语言理解评估**,\n",
+ "\n",
+ " 包含9个数据集,各语料的语言均为英语,涉及多个自然语言理解`NLU`任务,包括\n",
+ "\n",
+ " **CoLA**,文本分类任务,预测单句语法正误分类;**SST-2**,文本分类任务,预测单句情感二分类\n",
+ "\n",
+ " **MRPC**,句对分类任务,预测句对语义一致性;**STS-B**,相似度打分任务,预测句对语义相似度回归\n",
+ "\n",
+ " **QQP**,句对分类任务,预测问题对语义一致性;**MNLI**,文本推理任务,预测句对蕴含/矛盾/中立预测\n",
+ "\n",
+ " **QNLI / RTE / WNLI**,文本推理,预测是否蕴含二分类(其中,`QNLI`从`SQuAD`转化而来\n",
+ "\n",
+ " 诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n",
+ "\n",
+ " 此处,我们使用`SST-2`来训练`bert`,实现文本分类,其他任务描述见下图"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n",
+ "\n",
+ "task = 'sst2'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "**SST**,**全称`Stanford Sentiment Treebank**,**斯坦福情感树库**,**单句情感分类**数据集\n",
+ "\n",
+ " 包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 负面情感\n",
+ "\n",
+ " 数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n",
+ "\n",
+ "对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST-2`数据集,自动加载\n",
+ "\n",
+ " 首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c5915debacf9443986b5b3b34870b303",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "dataset = load_dataset('glue', task)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " 加载之后,根据`GLUE`中`SST-2`数据集的格式,尝试打印部分数据,检查加载结果"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Sentence: hide new secretions from the parental units \n"
+ ]
+ }
+ ],
+ "source": [
+ "task_to_keys = {\n",
+ " 'cola': ('sentence', None),\n",
+ " 'mnli': ('premise', 'hypothesis'),\n",
+ " 'mnli': ('premise', 'hypothesis'),\n",
+ " 'mrpc': ('sentence1', 'sentence2'),\n",
+ " 'qnli': ('question', 'sentence'),\n",
+ " 'qqp': ('question1', 'question2'),\n",
+ " 'rte': ('sentence1', 'sentence2'),\n",
+ " 'sst2': ('sentence', None),\n",
+ " 'stsb': ('sentence1', 'sentence2'),\n",
+ " 'wnli': ('sentence1', 'sentence2'),\n",
+ "}\n",
+ "\n",
+ "sentence1_key, sentence2_key = task_to_keys[task]\n",
+ "\n",
+ "if sentence2_key is None:\n",
+ " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n",
+ "else:\n",
+ " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n",
+ " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. 准备工作:加载 tokenizer、预处理 dataset、dataloader 使用\n",
+ "\n",
+ " 接下来进入模型训练的准备工作,分别需要使用`tokenizer`模块对数据集进行分词与标注\n",
+ "\n",
+ " 定义`SeqClsDataset`对应`dataloader`模块用来实现数据集在训练/测试时的加载\n",
+ "\n",
+ "此处的`tokenizer`和`SequenceClassificationModel`都是基于**distilbert-base-uncased 模型**\n",
+ "\n",
+ " 即使用较小的、不区分大小写的数据集,**对 bert-base 进行知识蒸馏后的版本**,结构上\n",
+ "\n",
+ " 包含**1个编码层**、**6个自注意力层**,**参数量`66M**,详解见本篇末尾,更多请参考[DistilBert论文](https://arxiv.org/pdf/1910.01108.pdf)\n",
+ "\n",
+ "首先,通过从`transformers`库中导入 **AutoTokenizer 模块**,**使用 from_pretrained 函数初始化**\n",
+ "\n",
+ " 此处的`use_fast`表示是否使用`tokenizer`的快速版本;尝试序列化示例数据,检查加载结果\n",
+ "\n",
+ " 需要注意的是,处理后返回的两个键值,**'input_ids'**表示原始文本对应的词素编号序列\n",
+ "\n",
+ " **'attention_mask'**表示自注意力运算时的掩码(标上`0`的部分对应`padding`的内容"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_checkpoint = 'distilbert-base-uncased'\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
+ "\n",
+ "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "接着,定义预处理函数,**通过 dataset.map 方法**,**将数据集中的文本**,**替换为词素编号序列**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n",
+ "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n",
+ "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n"
+ ]
+ }
+ ],
+ "source": [
+ "def preprocess_function(examples):\n",
+ " if sentence2_key is None:\n",
+ " return tokenizer(examples[sentence1_key], truncation=True)\n",
+ " return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n",
+ "\n",
+ "encoded_dataset = dataset.map(preprocess_function, batched=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "然后,通过继承`torch`中的`Dataset`类,定义`SeqClsDataset`类,需要注意的是\n",
+ "\n",
+ " 其中,**\\_\\_getitem\\_\\_ 函数各返回值引用的键值**,**必须和原始数据集中的属性对应**\n",
+ "\n",
+ " 例如,`'label'`是`SST-2`数据集中原有的内容(包括`'sentence'`和`'label'`\n",
+ "\n",
+ " `'input_ids'`和`'attention_mask'`则是`tokenizer`处理后添加的字段"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class SeqClsDataset(Dataset):\n",
+ " def __init__(self, dataset):\n",
+ " Dataset.__init__(self)\n",
+ " self.dataset = dataset\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.dataset)\n",
+ "\n",
+ " def __getitem__(self, item):\n",
+ " item = self.dataset[item]\n",
+ " return item['input_ids'], item['attention_mask'], [item['label']] "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "再然后,**定义校对函数 collate_fn 对齐同个 batch 内的每笔数据**,需要注意的是该函数的\n",
+ "\n",
+ " **返回值必须是字典**,**键值必须同待训练模型的 train_step 和 evaluate_step 函数的参数**\n",
+ "\n",
+ " **相对应**;这也就是在`tutorial-0`中便被强调的,`fastNLP v1.0`的第一条**参数匹配**机制"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def collate_fn(batch):\n",
+ " input_ids, atten_mask, labels = [], [], []\n",
+ " max_length = [0] * 3\n",
+ " for each_item in batch:\n",
+ " input_ids.append(each_item[0])\n",
+ " max_length[0] = max(max_length[0], len(each_item[0]))\n",
+ " atten_mask.append(each_item[1])\n",
+ " max_length[1] = max(max_length[1], len(each_item[1]))\n",
+ " labels.append(each_item[2])\n",
+ " max_length[2] = max(max_length[2], len(each_item[2]))\n",
+ "\n",
+ " for i in range(3):\n",
+ " each = (input_ids, atten_mask, labels)[i]\n",
+ " for item in each:\n",
+ " item.extend([0] * (max_length[i] - len(item)))\n",
+ " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n",
+ " 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n",
+ " 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "最后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset_train = SeqClsDataset(encoded_dataset['train'])\n",
+ "dataloader_train = DataLoader(dataset=dataset_train, \n",
+ " batch_size=32, shuffle=True, collate_fn=collate_fn)\n",
+ "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n",
+ "dataloader_valid = DataLoader(dataset=dataset_valid, \n",
+ " batch_size=32, shuffle=False, collate_fn=collate_fn)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. 模型训练:加载 distilbert-base、fastNLP 参数匹配、fine-tuning\n",
+ "\n",
+ " 最后就是模型训练的,分别需要使用`distilbert-base-uncased`搭建分类模型\n",
+ "\n",
+ " 初始化优化器`optimizer`、训练模块`trainer`,通过`run`函数完成训练\n",
+ "\n",
+ "此处使用的`nn.Module`模块搭建模型,与`tokenizer`类似,通过从`transformers`库中\n",
+ "\n",
+ " 导入`AutoModelForSequenceClassification`模块,基于`distilbert-base-uncased`模型初始\n",
+ "\n",
+ "需要注意的是**AutoModelForSequenceClassification 模块的输入参数和输出结构**\n",
+ "\n",
+ " 一方面,可以**通过输入标签值 labels**,**使用模块内的损失函数计算损失 loss**\n",
+ "\n",
+ " 并且可以选择输入是词素编号序列`input_ids`,还是词素嵌入序列`inputs_embeds`\n",
+ "\n",
+ " 另方面,该模块不会直接输出预测结果,而是会**输出各预测分类上的几率 logits**\n",
+ "\n",
+ " 基于上述描述,此处完成了中`train_step`和`evaluate_step`函数的定义\n",
+ "\n",
+ " 同样需要注意,函数的返回值体现了`fastNLP v1.0`的第二条**参数匹配**机制"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class SeqClsModel(nn.Module):\n",
+ " def __init__(self, num_labels, model_checkpoint):\n",
+ " nn.Module.__init__(self)\n",
+ " self.num_labels = num_labels\n",
+ " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
+ " num_labels=num_labels)\n",
+ "\n",
+ " def forward(self, input_ids, attention_mask, labels=None):\n",
+ " output = self.back_bone(input_ids=input_ids, \n",
+ " attention_mask=attention_mask, labels=labels)\n",
+ " return output\n",
+ "\n",
+ " def train_step(self, input_ids, attention_mask, labels):\n",
+ " loss = self(input_ids, attention_mask, labels).loss\n",
+ " return {'loss': loss}\n",
+ "\n",
+ " def evaluate_step(self, input_ids, attention_mask, labels):\n",
+ " pred = self(input_ids, attention_mask, labels).logits\n",
+ " pred = torch.max(pred, dim=-1)[1]\n",
+ " return {'pred': pred, 'target': labels}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias']\n",
+ "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "num_labels = 3 if task == 'mnli' else 1 if task == 'stsb' else 2\n",
+ "\n",
+ "model = SeqClsModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
+ "\n",
+ "optimizers = AdamW(params=model.parameters(), lr=5e-5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "然后,使用之前完成的`dataloader_train`和`dataloader_valid`,定义训练模块`trainer`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver='torch',\n",
+ " device=0, # 'cuda'\n",
+ " n_epochs=10,\n",
+ " optimizers=optimizers,\n",
+ " train_dataloader=dataloader_train,\n",
+ " evaluate_dataloaders=dataloader_valid,\n",
+ " metrics={'acc': Accuracy()}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "最后,使用`trainer.run`方法,训练模型,`n_epochs`参数中已经指定需要迭代`10`轮\n",
+ "\n",
+ " `num_eval_batch_per_dl`参数则指定每次只对验证集中的`10`个`batch`进行评估"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[09:12:45] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[09:12:45]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=408427;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=303634;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.884375,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 283.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.878125,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 281.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.884375,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 283.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.9,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 288.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m288.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.8875,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 284.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.88125,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 282.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.88125\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m282.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.875,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 280.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.865625,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 277.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.865625\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m277.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.884375,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 283.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.878125,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 281.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.run(num_eval_batch_per_dl=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'acc#acc': 0.884174, 'total#acc': 872.0, 'correct#acc': 771.0}"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.evaluator.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 附:`DistilBertForSequenceClassification`模块结构\n",
+ "\n",
+ "```\n",
+ "\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.7.13 ('fnlp-paddle')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ },
+ "pycharm": {
+ "stem_cell": {
+ "cell_type": "raw",
+ "metadata": {
+ "collapsed": false
+ },
+ "source": []
+ }
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_e2.ipynb b/docs/source/tutorials/fastnlp_tutorial_e2.ipynb
new file mode 100644
index 00000000..588ee8c3
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_e2.ipynb
@@ -0,0 +1,1082 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# E2. 使用 Bert + prompt 完成 SST-2 分类\n",
+ "\n",
+ " 1 基础介绍:`prompt-based model`简介、与`fastNLP`的结合\n",
+ "\n",
+ " 2 准备工作:`P-Tuning v2`原理概述、`P-Tuning v2`模型搭建\n",
+ "\n",
+ " 3 模型训练:加载`tokenizer`、预处理`dataset`、模型训练与分析"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. 基础介绍:prompt-based model 简介、与 fastNLP 的结合\n",
+ "\n",
+ " 本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`prompt-based tuning`方式\n",
+ "\n",
+ " 微调`bert-base-uncased`模型,实现文本情感的二分类,在此之前本示例\n",
+ "\n",
+ " 将首先简单介绍提示学习模型的研究,以及与`fastNLP v1.0`结合的优势\n",
+ "\n",
+ "**prompt**,**提示词**,最早出自论文[Exploiting Cloze Questions for Few Shot TC and NLI](https://arxiv.org/pdf/2001.07676.pdf)中的 **PET 模型**\n",
+ "\n",
+ " 全称 **Pattern-Exploiting Training**,虽然文中并没有提到`prompt`的说法,但仍被视为开山之作\n",
+ "\n",
+ " 其大致思路包括,对于文本分类任务,假定输入文本为`\" X . \"`,设计**输入模板 template**,**后来被称为 prompt**\n",
+ "\n",
+ " 将输入重构为`\" X . It is [MASK] . \"`,**诱导或刺激语言模型在 [MASK] 位置生成含有情感倾向的词汇**\n",
+ "\n",
+ " 接着将该词汇**输入分类器中**,**后来被称为 verbalizer**,从而得到该语句对应的情感倾向,实现文本分类\n",
+ "\n",
+ " 其主要贡献在于,通过构造`prompt`,诱导/刺激预训练模型生成期望适应下游任务特征,适合少样本学习的需求\n",
+ "\n",
+ "\n",
+ "\n",
+ "**prompt-based tuning**,**基于提示的微调**,将`prompt`应用于**参数高效微调**,**parameter-efficient tuning**\n",
+ "\n",
+ " 通过**设计模板调整模型输入**或者**调整模型内部状态**,**固定预训练模型**,**诱导/刺激模型**调整输出以适应\n",
+ "\n",
+ " 当前任务,极大降低了训练开销,也省去了`verbalizer`的构造,更多参考[prompt综述](https://arxiv.org/pdf/2107.13586.pdf)、[DeltaTuning综述](https://arxiv.org/pdf/2203.06904.pdf)\n",
+ "\n",
+ " 以下列举些经典的`prompt-based tuning`案例,简单地介绍下`prompt-based tuning`的脉络\n",
+ "\n",
+ " **案例一**:**PrefixTuning**,详细内容参考[PrefixTuning论文](https://arxiv.org/pdf/2101.00190.pdf)\n",
+ "\n",
+ " 其主要贡献在于,**提出连续的、非人工构造的、任务导向的 prompt**,即**前缀 prefix**,**调整**\n",
+ "\n",
+ " **模型内部更新状态**,诱导模型在特定任务下生成期望目标,降低优化难度,提升微调效果\n",
+ "\n",
+ " 其主要研究对象,是`GPT2`和`BART`,主要面向生成任务`NLG`,如`table-to-text`和摘要\n",
+ "\n",
+ " **案例二**:**P-Tuning v1**,详细内容参考[P-Tuning-v1论文](https://arxiv.org/pdf/2103.10385.pdf)\n",
+ "\n",
+ " 其主要贡献在于,**通过连续的、非人工构造的 prompt 调整模型输入**,取代原先基于单词设计的\n",
+ "\n",
+ " 但离散且不易于优化的`prompt`;同时也**证明了 GPT2 在语言理解任务上仍然是可以胜任的**\n",
+ "\n",
+ " 其主要研究对象,是`GPT2`,主要面向知识探测`knowledge probing`和自然语言理解`NLU`\n",
+ "\n",
+ " **案例三**:**PromptTuning**,详细内容参考[PromptTuning论文](https://arxiv.org/pdf/2104.08691.pdf)\n",
+ "\n",
+ " 其主要贡献在于,通过连续的`prompt`调整模型输入,**证明了 prompt-based tuning 的效果**\n",
+ "\n",
+ " **随模型参数量的增加而提升**,最终**在 10B 左右追上了全参数微调 fine-tuning 的效果**\n",
+ "\n",
+ " 其主要面向自然语言理解`NLU`,通过为每个任务定义不同的`prompt`,从而支持多任务语境\n",
+ "\n",
+ "通过上述介绍可以发现`prompt-based tuning`只是模型微调方式,独立于预训练模型基础`backbone`\n",
+ "\n",
+ " 目前,加载预训练模型的主流方法是使用**transformers 模块**,而实现微调的框架则\n",
+ "\n",
+ " 可以是`pytorch`、`paddle`、`jittor`等,而不同框架间又存在不兼容的问题\n",
+ "\n",
+ " 因此,**使用 fastNLP v1.0 实现 prompt-based tuning**,可以**很好地解决 paddle 等框架**\n",
+ "\n",
+ " **和 transformers 模块之间的桥接**(`transformers`模块基于`pytorch`实现)\n",
+ "\n",
+ "本示例仍使用了`tutorial-E1`的`SST-2`数据集、`distilbert-base-uncased`模型(便于比较\n",
+ "\n",
+ " 使用`pytorch`框架,通过将连续的`prompt`与`model`拼接,解决`SST-2`二分类任务"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4.18.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.optim import AdamW\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "\n",
+ "import transformers\n",
+ "from transformers import AutoTokenizer\n",
+ "from transformers import AutoModelForSequenceClassification\n",
+ "\n",
+ "import sys\n",
+ "sys.path.append('..')\n",
+ "\n",
+ "import fastNLP\n",
+ "from fastNLP import Trainer\n",
+ "from fastNLP.core.metrics import Accuracy\n",
+ "\n",
+ "print(transformers.__version__)\n",
+ "\n",
+ "task = 'sst2'\n",
+ "model_checkpoint = 'distilbert-base-uncased' # 'bert-base-uncased'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. 准备工作:P-Tuning v2 原理概述、P-Tuning v2 模型搭建\n",
+ "\n",
+ " 本示例使用`P-Tuning v2`作为`prompt-based tuning`与`fastNLP v1.0`结合的案例\n",
+ "\n",
+ " 以下首先简述`P-Tuning v2`的论文原理,并由此引出`fastNLP v1.0`的代码实践\n",
+ "\n",
+ "**P-Tuning v2**出自论文[Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf)\n",
+ "\n",
+ " 其主要贡献在于,**在 PrefixTuning 等深度提示学习基础上**,**提升了其在分类标注等 NLU 任务的表现**\n",
+ "\n",
+ " 并使之在中等规模模型,主要是**参数量在 100M-1B 区间的模型上**,**获得与全参数微调相同的效果**\n",
+ "\n",
+ " 其结构如图所示,通过**在输入序列的分类符 [CLS] 之前**,**加入前缀序列**(**序号对应嵌入是待训练的连续值向量**\n",
+ "\n",
+ " **刺激模型在新任务下**,从`[CLS]`对应位置,**输出符合微调任务的输出**,从而达到适应微调任务的目的\n",
+ "\n",
+ "\n",
+ "\n",
+ "本示例使用`bert-base-uncased`模型,作为`P-Tuning v2`的基础`backbone`,设置`requires_grad=False`\n",
+ "\n",
+ " 固定其参数不参与训练,**设置 pre_seq_len 长的 prefix_tokens 作为输入的提示前缀序列**\n",
+ "\n",
+ " **使用基于 nn.Embedding 的 prefix_encoder 为提示前缀嵌入**,通过`get_prompt`函数获取,再将之\n",
+ "\n",
+ " 拼接至批量内每笔数据前得到`inputs_embeds`,同时更新自注意力掩码`attention_mask`\n",
+ "\n",
+ " 将`inputs_embeds`、`attention_mask`和`labels`输入`backbone`,**得到输出包括 loss 和 logits**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class SeqClsModel(nn.Module):\n",
+ " def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n",
+ " nn.Module.__init__(self)\n",
+ " self.num_labels = num_labels\n",
+ " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
+ " num_labels=num_labels)\n",
+ " self.embeddings = self.back_bone.get_input_embeddings()\n",
+ "\n",
+ " for param in self.back_bone.parameters():\n",
+ " param.requires_grad = False\n",
+ " \n",
+ " self.pre_seq_len = pre_seq_len\n",
+ " self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n",
+ " self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)\n",
+ " \n",
+ " def get_prompt(self, batch_size):\n",
+ " prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)\n",
+ " prompts = self.prefix_encoder(prefix_tokens)\n",
+ " return prompts\n",
+ "\n",
+ " def forward(self, input_ids, attention_mask, labels=None):\n",
+ " \n",
+ " batch_size = input_ids.shape[0]\n",
+ " raw_embedding = self.embeddings(input_ids)\n",
+ " \n",
+ " prompts = self.get_prompt(batch_size=batch_size)\n",
+ " inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n",
+ " prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)\n",
+ " attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n",
+ "\n",
+ " outputs = self.back_bone(inputs_embeds=inputs_embeds, \n",
+ " attention_mask=attention_mask, labels=labels)\n",
+ " return outputs\n",
+ "\n",
+ " def train_step(self, input_ids, attention_mask, labels):\n",
+ " loss = self(input_ids, attention_mask, labels).loss\n",
+ " return {'loss': loss}\n",
+ "\n",
+ " def evaluate_step(self, input_ids, attention_mask, labels):\n",
+ " pred = self(input_ids, attention_mask, labels).logits\n",
+ " pred = torch.max(pred, dim=-1)[1]\n",
+ " return {'pred': pred, 'target': labels}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器\n",
+ "\n",
+ " 根据`P-Tuning v2`论文:`Generally, simple classification tasks prefer shorter prompts (less than 20)`\n",
+ "\n",
+ " 此处`pre_seq_len`参数设定为`20`,学习率相应做出调整,其他内容和`tutorial-E1`中的内容一致"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n",
+ "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = SeqClsModel(model_checkpoint=model_checkpoint, num_labels=2, pre_seq_len=20)\n",
+ "\n",
+ "optimizers = AdamW(params=model.parameters(), lr=1e-2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. 模型训练:加载 tokenizer、预处理 dataset、模型训练与分析\n",
+ "\n",
+ " 本示例沿用`tutorial-E1`中的数据集,即使用`GLUE`评估基准中的`SST-2`数据集\n",
+ "\n",
+ " 以`bert-base-uncased`模型作为基准,基于`P-Tuning v2`方式微调\n",
+ "\n",
+ " 数据集加载相关代码流程见下,内容和`tutorial-E1`中的内容基本一致\n",
+ "\n",
+ "首先,使用`datasets.load_dataset`加载数据集,使用`transformers.AutoTokenizer`\n",
+ "\n",
+ " 构建`tokenizer`实例,通过`dataset.map`使用`tokenizer`将文本替换为词素序号序列"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "21cbd92c3397497d84dc10f017ec96f4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset, load_metric\n",
+ "\n",
+ "dataset = load_dataset('glue', task)\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow\n",
+ "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow\n",
+ "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f44c5576b89f9e6b.arrow\n"
+ ]
+ }
+ ],
+ "source": [
+ "def preprocess_function(examples):\n",
+ " return tokenizer(examples['sentence'], truncation=True)\n",
+ "\n",
+ "encoded_dataset = dataset.map(preprocess_function, batched=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "然后,定义`SeqClsDataset`类、定义校对函数`collate_fn`,这里沿用`tutorial-E1`中的内容\n",
+ "\n",
+ " 同样需要注意/强调的是,**\\_\\_getitem\\_\\_ 函数的返回值必须和原始数据集中的属性对应**\n",
+ "\n",
+ " **collate_fn 函数的返回值必须和 train_step 和 evaluate_step 函数的参数匹配**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class SeqClsDataset(Dataset):\n",
+ " def __init__(self, dataset):\n",
+ " Dataset.__init__(self)\n",
+ " self.dataset = dataset\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.dataset)\n",
+ "\n",
+ " def __getitem__(self, item):\n",
+ " item = self.dataset[item]\n",
+ " return item['input_ids'], item['attention_mask'], [item['label']] \n",
+ "\n",
+ "def collate_fn(batch):\n",
+ " input_ids, atten_mask, labels = [], [], []\n",
+ " max_length = [0] * 3\n",
+ " for each_item in batch:\n",
+ " input_ids.append(each_item[0])\n",
+ " max_length[0] = max(max_length[0], len(each_item[0]))\n",
+ " atten_mask.append(each_item[1])\n",
+ " max_length[1] = max(max_length[1], len(each_item[1]))\n",
+ " labels.append(each_item[2])\n",
+ " max_length[2] = max(max_length[2], len(each_item[2]))\n",
+ "\n",
+ " for i in range(3):\n",
+ " each = (input_ids, atten_mask, labels)[i]\n",
+ " for item in each:\n",
+ " item.extend([0] * (max_length[i] - len(item)))\n",
+ " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n",
+ " 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n",
+ " 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "再然后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset_train = SeqClsDataset(encoded_dataset['train'])\n",
+ "dataloader_train = DataLoader(dataset=dataset_train, \n",
+ " batch_size=32, shuffle=True, collate_fn=collate_fn)\n",
+ "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n",
+ "dataloader_valid = DataLoader(dataset=dataset_valid, \n",
+ " batch_size=32, shuffle=False, collate_fn=collate_fn)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "最后,使用之前完成的`dataloader_train`和`dataloader_valid`,定义训练模块`trainer`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver='torch',\n",
+ " device=1, # [0, 1],\n",
+ " n_epochs=10,\n",
+ " optimizers=optimizers,\n",
+ " train_dataloader=dataloader_train,\n",
+ " evaluate_dataloaders=dataloader_valid,\n",
+ " metrics={'acc': Accuracy()}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " 使用`trainer.run`方法训练模型,同样每次只对验证集中的`10`个`batch`进行评估"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[22:53:00] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[22:53:00]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=406635;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=951504;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.540625,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 173.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.540625\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m173.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.5,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 160.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.5\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.509375,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 163.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.509375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m163.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.634375,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 203.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.634375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m203.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.6125,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 196.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.6125\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m196.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.675,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 216.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m216.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.64375,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 206.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.64375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m206.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.665625,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 213.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.665625\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m213.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.659375,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 211.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.659375\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m211.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#acc\": 0.696875,\n",
+ " \"total#acc\": 320.0,\n",
+ " \"correct#acc\": 223.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.696875\u001b[0m,\n",
+ " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m223.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.run(num_eval_batch_per_dl=10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "可以发现,其效果远远逊色于`fine-tuning`,这是因为`P-Tuning v2`虽然能够适应参数量\n",
+ "\n",
+ " 在`100M-1B`区间的模型,但是,**distilbert-base 的参数量仅为 66M**,无法触及其下限\n",
+ "\n",
+ "另一方面,**fastNLP v1.0 不支持 jupyter 多卡**,所以无法在笔者的电脑/服务器上,完成\n",
+ "\n",
+ " 合适规模模型的学习,例如`110M`的`bert-base`模型,以及`340M`的`bert-large`模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'acc#acc': 0.737385, 'total#acc': 872.0, 'correct#acc': 643.0}"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.evaluator.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ },
+ "pycharm": {
+ "stem_cell": {
+ "cell_type": "raw",
+ "metadata": {
+ "collapsed": false
+ },
+ "source": []
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb b/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb
new file mode 100644
index 00000000..a5883416
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb
@@ -0,0 +1,1086 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# E3. 使用 paddlenlp 和 fastNLP 实现中文文本情感分析\n",
+ "\n",
+ "本篇教程属于 **fastNLP v1.0 tutorial 的 paddle examples 系列**。在本篇教程中,我们将为您展示如何使用 `paddlenlp` 自然语言处理库和 `fastNLP` 来完成比较简单的情感分析任务。\n",
+ "\n",
+ "1. 基础介绍:飞桨自然语言处理库 ``paddlenlp`` 和语义理解框架 ``ERNIE``\n",
+ "\n",
+ "2. 准备工作:使用 ``tokenizer`` 处理数据并构造 ``dataloader``\n",
+ "\n",
+ "3. 模型训练:加载 ``ERNIE`` 预训练模型,使用 ``fastNLP`` 进行训练"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. 基础介绍:飞桨自然语言处理库 paddlenlp 和语义理解框架 ERNIE\n",
+ "\n",
+ "#### 1.1 飞桨自然语言处理库 paddlenlp\n",
+ "\n",
+ "``paddlenlp`` 是由百度以飞桨 ``PaddlePaddle`` 为核心开发的自然语言处理库,集成了多个数据集和 NLP 模型,包括百度自研的语义理解框架 ``ERNIE`` 。在本篇教程中,我们会以 ``paddlenlp`` 为基础,使用模型 ``ERNIE`` 完成中文情感分析任务。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2.3.3\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append(\"../\")\n",
+ "\n",
+ "import paddle\n",
+ "import paddlenlp\n",
+ "from paddlenlp.transformers import AutoTokenizer\n",
+ "from paddlenlp.transformers import AutoModelForSequenceClassification\n",
+ "\n",
+ "print(paddlenlp.__version__)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 1.2 语义理解框架 ERNIE\n",
+ "\n",
+ "``ERNIE(Enhanced Representation from kNowledge IntEgration)`` 是百度提出的基于知识增强的持续学习语义理解框架,至今已有 ``ERNIE 2.0``、``ERNIE 3.0``、``ERNIE-M``、``ERNIE-tiny`` 等多种预训练模型。``ERNIE 1.0`` 采用``Transformer Encoder`` 作为其语义表示的骨架,并改进了两种 ``mask`` 策略,分别为基于**短语**和**实体**(人名、组织等)的策略。在 ``ERNIE`` 中,由多个字组成的短语或者实体将作为一个统一单元,在训练的时候被统一地 ``mask`` 掉,这样可以潜在地学习到知识的依赖以及更长的语义依赖来让模型更具泛化性。\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "``ERNIE 2.0`` 则提出了连续学习(``Continual Learning``)的概念,即首先用一个简单的任务来初始化模型,在更新时用前一个任务训练好的参数作为下一个任务模型初始化的参数。这样在训练新的任务时,模型便可以记住之前学习到的知识,使得模型在新任务上获得更好的表现。``ERNIE 2.0`` 分别构建了词法、语法、语义不同级别的预训练任务,并使用不同的 task id 来标示不同的任务,在共计16个中英文任务上都取得了SOTA效果。\n",
+ "\n",
+ "\n",
+ "\n",
+ "``ERNIE 3.0`` 将自回归和自编码网络融合在一起进行预训练,其中自编码网络采用 ``ERNIE 2.0`` 的多任务学习增量式构建预训练任务,持续进行语义理解学习。其中自编码网络增加了知识增强的预训练任务。自回归网络则基于 ``Tranformer-XL`` 结构,支持长文本语言模型建模,并在多个自然语言处理任务中取得了SOTA的效果。\n",
+ "\n",
+ "\n",
+ "\n",
+ "接下来,我们将展示如何在 ``fastNLP`` 中使用基于 ``paddle`` 的 ``ERNIE 1.0`` 框架进行中文情感分析。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. 使用 tokenizer 处理数据并构造 dataloader\n",
+ "\n",
+ "#### 2.1 加载中文数据集 ChnSentiCorp\n",
+ "\n",
+ "``ChnSentiCorp`` 数据集是由中国科学院发布的中文句子级情感分析数据集,包含了从网络上获取的酒店、电影、书籍等多个领域的评论,每条评论都被划分为两个标签:消极(``0``)和积极(``1``),可以用于二分类的中文情感分析任务。通过 ``paddlenlp.datasets.load_dataset`` 函数,我们可以加载并查看 ``ChnSentiCorp`` 数据集的内容。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "训练集大小: 9600\n",
+ "{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': ''}\n",
+ "{'text': '15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错', 'label': 1, 'qid': ''}\n",
+ "{'text': '房间太小。其他的都一般。。。。。。。。。', 'label': 0, 'qid': ''}\n"
+ ]
+ }
+ ],
+ "source": [
+ "from paddlenlp.datasets import load_dataset\n",
+ "\n",
+ "train_dataset, val_dataset, test_dataset = load_dataset(\"chnsenticorp\", splits=[\"train\", \"dev\", \"test\"])\n",
+ "print(\"训练集大小:\", len(train_dataset))\n",
+ "for i in range(3):\n",
+ " print(train_dataset[i])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 2.2 处理数据\n",
+ "\n",
+ "可以看到,原本的数据集仅包含中文的文本和标签,这样的数据是无法被模型识别的。同英文文本分类任务一样,我们需要使用 ``tokenizer`` 对文本进行分词并转换为数字形式的结果。我们可以加载已经预训练好的中文分词模型 ``ernie-1.0-base-zh``,将分词的过程写在函数 ``_process`` 中,然后调用数据集的 ``map`` 函数对每一条数据进行分词。其中:\n",
+ "- 参数 ``max_length`` 代表句子的最大长度;\n",
+ "- ``padding=\"max_length\"`` 表示将长度不足的结果 padding 至和最大长度相同;\n",
+ "- ``truncation=True`` 表示将长度过长的句子进行截断。\n",
+ "\n",
+ "至此,我们得到了每条数据长度均相同的数据集。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m[2022-06-22 21:31:04,168] [ INFO]\u001b[0m - We are using to load 'ernie-1.0-base-zh'.\u001b[0m\n",
+ "\u001b[32m[2022-06-22 21:31:04,171] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': '', 'input_ids': [1, 352, 790, 1252, 409, 283, 509, 5, 250, 196, 113, 10, 58, 518, 4, 9, 128, 70, 1495, 1855, 339, 293, 45, 302, 233, 554, 4, 544, 637, 1134, 774, 6, 494, 2068, 6, 278, 191, 6, 634, 99, 6, 2678, 144, 7, 149, 1573, 62, 12043, 661, 737, 371, 435, 7, 689, 4, 255, 201, 559, 407, 1308, 12043, 2275, 1110, 11, 19, 842, 5, 1207, 878, 4, 196, 198, 321, 96, 4, 16, 93, 291, 464, 1099, 10, 692, 811, 12043, 392, 5, 748, 1134, 10, 213, 220, 5, 4, 201, 559, 723, 595, 12043, 231, 112, 1114, 4, 7, 689, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}\n"
+ ]
+ }
+ ],
+ "source": [
+ "max_len = 128\n",
+ "model_checkpoint = \"ernie-1.0-base-zh\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
+ "def _process(data):\n",
+ " data.update(tokenizer(\n",
+ " data[\"text\"],\n",
+ " max_length=max_len,\n",
+ " padding=\"max_length\",\n",
+ " truncation=True,\n",
+ " return_attention_mask=True,\n",
+ " ))\n",
+ " return data\n",
+ "\n",
+ "train_dataset.map(_process, num_workers=5)\n",
+ "val_dataset.map(_process, num_workers=5)\n",
+ "test_dataset.map(_process, num_workers=5)\n",
+ "\n",
+ "print(train_dataset[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "得到数据集之后,我们便可以将数据集包裹在 ``PaddleDataLoader`` 中,用于之后的训练。``fastNLP`` 提供的 ``PaddleDataLoader`` 拓展了 ``paddle.io.DataLoader`` 的功能,详情可以查看相关的文档。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP.core import PaddleDataLoader\n",
+ "import paddle.nn as nn\n",
+ "\n",
+ "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n",
+ "val_dataloader = PaddleDataLoader(val_dataset, batch_size=32, shuffle=False)\n",
+ "test_dataloader = PaddleDataLoader(test_dataset, batch_size=1, shuffle=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. 模型训练:加载 ERNIE 预训练模型,使用 fastNLP 进行训练\n",
+ "\n",
+ "#### 3.1 使用 ERNIE 预训练模型\n",
+ "\n",
+ "为了实现文本分类,我们首先需要定义文本分类的模型。``paddlenlp.transformers`` 提供了模型 ``AutoModelForSequenceClassification``,我们可以利用它来加载不同权重的文本分类模型。在 ``fastNLP`` 中,我们可以定义 ``train_step`` 和 ``evaluate_step`` 函数来实现训练和验证过程中的不同行为。\n",
+ "\n",
+ "- ``train_step`` 函数在获得返回值 ``logits`` (大小为 ``(batch_size, num_labels)``)后计算交叉熵损失 ``CrossEntropyLoss``,然后将 ``loss`` 放在字典中返回。``fastNLP`` 也支持返回 ``dataclass`` 类型的训练结果,但二者都需要包含名为 **loss** 的键或成员。\n",
+ "- ``evaluate_step`` 函数在获得返回值 ``logits`` 后,将 ``logits`` 和标签 ``label`` 放在字典中返回。\n",
+ "\n",
+ "这两个函数的参数均为数据集中字典**键**的子集,``fastNLP`` 会自动进行参数匹配然后输入到模型中。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m[2022-06-22 21:31:15,577] [ INFO]\u001b[0m - We are using to load 'ernie-1.0-base-zh'.\u001b[0m\n",
+ "\u001b[32m[2022-06-22 21:31:15,580] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "import paddle.nn as nn\n",
+ "\n",
+ "class SeqClsModel(nn.Layer):\n",
+ " def __init__(self, model_checkpoint, num_labels):\n",
+ " super(SeqClsModel, self).__init__()\n",
+ " self.model = AutoModelForSequenceClassification.from_pretrained(\n",
+ " model_checkpoint,\n",
+ " num_classes=num_labels,\n",
+ " )\n",
+ "\n",
+ " def forward(self, input_ids, attention_mask, token_type_ids):\n",
+ " logits = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n",
+ " return logits\n",
+ "\n",
+ " def train_step(self, input_ids, attention_mask, token_type_ids, label):\n",
+ " logits = self(input_ids, attention_mask, token_type_ids)\n",
+ " loss = nn.CrossEntropyLoss()(logits, label)\n",
+ " return {\"loss\": loss}\n",
+ "\n",
+ " def evaluate_step(self, input_ids, attention_mask, token_type_ids, label):\n",
+ " logits = self(input_ids, attention_mask, token_type_ids)\n",
+ " return {'pred': logits, 'target': label}\n",
+ "\n",
+ "model = SeqClsModel(model_checkpoint, num_labels=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2 设置参数并使用 Trainer 开始训练\n",
+ "\n",
+ "现在我们可以着手使用 ``fastNLP.Trainer`` 进行训练了。\n",
+ "\n",
+ "首先,为了高效地训练 ``ERNIE`` 模型,我们最好为学习率指定一定的策略。``paddlenlp`` 提供的 ``LinearDecayWithWarmup`` 可以令学习率在一段时间内从 0 开始线性地增长(预热),然后再线性地衰减至 0 。在本篇教程中,我们将学习率设置为 ``5e-5``,预热时间为 ``0.1``,然后将得到的的 ``lr_scheduler`` 赋值给 ``AdamW`` 优化器。\n",
+ "\n",
+ "其次,我们还可以为 ``Trainer`` 指定多个 ``Callback`` 来在基础的训练过程之外进行额外的定制操作。在本篇教程中,我们使用的 ``Callback`` 有以下三种:\n",
+ "\n",
+ "- ``LRSchedCallback`` - 由于我们使用了 ``Scheduler``,因此需要将 ``lr_scheduler`` 传给该 ``Callback`` 以在训练中进行更新。\n",
+ "- ``LoadBestModelCallback`` - 该 ``Callback`` 会评估结果中的 ``'acc#accuracy'`` 值,保存训练中出现的正确率最高的模型,并在训练结束时加载到模型上,方便对模型进行测试和评估。\n",
+ "\n",
+ "在 ``Trainer`` 中,我们还可以设置 ``metrics`` 来衡量模型的表现。``Accuracy`` 能够根据传入的预测值和真实值计算出模型预测的正确率。还记得模型中 ``evaluate_step`` 函数的返回值吗?键 ``pred`` 和 ``target`` 分别为 ``Accuracy.update`` 的参数名,在验证过程中 ``fastNLP`` 会自动将键和参数名匹配从而计算出正确率,这也是我们规定模型需要返回字典类型数据的原因。\n",
+ "\n",
+ "``Accuracy`` 的返回值包含三个部分:``acc``、``total`` 和 ``correct``,分别代表 ``正确率``、 ``数据总数`` 和 ``预测正确的数目``,这让您能够直观地知晓训练中模型的变化,``LoadBestModelCallback`` 的参数 ``'acc#accuracy'`` 也正是代表了 ``accuracy`` 指标的 ``acc`` 结果。\n",
+ "\n",
+ "在设定好参数之后,调用 ``run`` 函数便可以进行训练和验证了。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[21:31:16] INFO Running evaluator sanity check for 2 batches. trainer.py:631\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[21:31:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=4641;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=822054;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:60 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.895833,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1075.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.895833\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1075.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:120 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.8975,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1077.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.8975\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1077.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:180 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.911667,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1094.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.911667\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1094.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:240 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.9225,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1107.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9225\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1107.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.9275,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1113.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9275\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1113.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:1, Batch:60 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.930833,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1117.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.930833\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1117.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:1, Batch:120 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.935833,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1123.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:1, Batch:180 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.935833,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1123.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:1, Batch:240 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.9375,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1125.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9375\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1125.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:1, Batch:300 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.941667,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1130.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.941667\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1130.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "[21:34:28] INFO Loading best model from fnlp-ernie/2022-0 load_best_model_callback.py:111\n",
+ " 6-22-21_29_12_898095/best_so_far with \n",
+ " acc#accuracy: 0.941667... \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[21:34:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m0\u001b[0m \u001b]8;id=340364;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=763898;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m \u001b[1;36m6\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_898095/best_so_far with \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m acc#accuracy: \u001b[1;36m0.941667\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "[21:34:34] INFO Deleting fnlp-ernie/2022-06-22-21_29_12_8 load_best_model_callback.py:131\n",
+ " 98095/best_so_far... \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[21:34:34]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_8 \u001b]8;id=430330;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=508566;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m 98095/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from fastNLP import LRSchedCallback, LoadBestModelCallback\n",
+ "from fastNLP import Trainer, Accuracy\n",
+ "from paddlenlp.transformers import LinearDecayWithWarmup\n",
+ "\n",
+ "n_epochs = 2\n",
+ "num_training_steps = len(train_dataloader) * n_epochs\n",
+ "lr_scheduler = LinearDecayWithWarmup(5e-5, num_training_steps, 0.1)\n",
+ "optimizer = paddle.optimizer.AdamW(\n",
+ " learning_rate=lr_scheduler,\n",
+ " parameters=model.parameters(),\n",
+ ")\n",
+ "callbacks = [\n",
+ " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n",
+ " LoadBestModelCallback(\"acc#accuracy\", larger_better=True, save_folder=\"fnlp-ernie\"),\n",
+ "]\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver=\"paddle\",\n",
+ " optimizers=optimizer,\n",
+ " device=0,\n",
+ " n_epochs=n_epochs,\n",
+ " train_dataloader=train_dataloader,\n",
+ " evaluate_dataloaders=val_dataloader,\n",
+ " evaluate_every=60,\n",
+ " metrics={\"accuracy\": Accuracy()},\n",
+ " callbacks=callbacks,\n",
+ ")\n",
+ "trainer.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.3 测试和评估\n",
+ "\n",
+ "现在我们已经得到了一个表现良好的 ``ERNIE`` 模型,接下来可以在测试集上测试模型的效果了。``fastNLP.Evaluator`` 提供了定制函数的功能。我们以 ``test_dataloader`` 初始化一个 ``Evaluator``,然后将写好的测试函数 ``test_batch_step_fn`` 传给参数 ``evaluate_batch_step_fn``,``Evaluate`` 在对每个 batch 进行评估时就会调用我们自定义的 ``test_batch_step_fn`` 函数而不是 ``evaluate_step`` 函数。在这里,我们仅测试 5 条数据并输出文本和对应的标签。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "labels: 0\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "labels: 0\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n",
+ "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n",
+ "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n",
+ "集算什么??简直是画蛇添足!!']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n",
+ "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n",
+ "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n",
+ "集算什么??简直是画蛇添足!!']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "labels: 0\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "labels: 0\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n",
+ "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n",
+ "]\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n",
+ "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n",
+ "]\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "labels: 0\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "labels: 0\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "text: ['交通方便;环境很好;服务态度很好 房间较小']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "text: ['交通方便;环境很好;服务态度很好 房间较小']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "labels: 1\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "labels: 1\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n",
+ "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n",
+ "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n",
+ "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n",
+ "。']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n",
+ "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n",
+ "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n",
+ "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n",
+ "。']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "labels: 1\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "labels: 1\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{}"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from fastNLP import Evaluator\n",
+ "def test_batch_step_fn(evaluator, batch):\n",
+ " input_ids = batch[\"input_ids\"]\n",
+ " attention_mask = batch[\"attention_mask\"]\n",
+ " token_type_ids = batch[\"token_type_ids\"]\n",
+ " logits = model(input_ids, attention_mask, token_type_ids)\n",
+ " predict = logits.argmax().item()\n",
+ " print(\"text:\", batch['text'])\n",
+ " print(\"labels:\", predict)\n",
+ "\n",
+ "evaluator = Evaluator(\n",
+ " model=model,\n",
+ " dataloaders=test_dataloader,\n",
+ " driver=\"paddle\",\n",
+ " device=0,\n",
+ " evaluate_batch_step_fn=test_batch_step_fn,\n",
+ ")\n",
+ "evaluator.run(5) "
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.7.13 ('fnlp-paddle')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb b/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb
new file mode 100644
index 00000000..439d7f9f
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb
@@ -0,0 +1,1510 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# E4. 使用 paddlenlp 和 fastNLP 训练中文阅读理解任务\n",
+ "\n",
+ "本篇教程属于 **fastNLP v1.0 tutorial 的 paddle examples 系列**。在本篇教程中,我们将为您展示如何在 `fastNLP` 中通过自定义 `Metric` 和 损失函数来完成进阶的问答任务。\n",
+ "\n",
+ "1. 基础介绍:自然语言处理中的阅读理解任务\n",
+ "\n",
+ "2. 准备工作:加载 `DuReader-robust` 数据集,并使用 `tokenizer` 处理数据\n",
+ "\n",
+ "3. 模型训练:自己定义评测用的 `Metric` 实现更加自由的任务评测"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. 基础介绍:自然语言处理中的阅读理解任务\n",
+ "\n",
+ "阅读理解任务,顾名思义,就是给出一段文字,然后让模型理解这段文字所含的语义。大部分机器阅读理解任务都采用问答式测评,即设计与文章内容相关的自然语言式问题,让模型理解问题并根据文章作答。与文本分类任务不同的是,在阅读理解任务中我们有时需要需要输入“一对”句子,分别代表问题和上下文;答案的格式也分为多种:\n",
+ "\n",
+ "- 多项选择:让模型从多个答案选项中选出正确答案\n",
+ "- 区间答案:答案为上下文的一段子句,需要模型给出答案的起始位置\n",
+ "- 自由回答:不做限制,让模型自行生成答案\n",
+ "- 完形填空:在原文中挖空部分关键词,让模型补全;这类答案往往不需要问题\n",
+ "\n",
+ "如果您对 `transformers` 有所了解的话,其中的 `ModelForQuestionAnswering` 系列模型就可以用于这项任务。阅读理解模型的泛用性是衡量该技术能否在实际应用中大规模落地的重要指标之一,随着当前技术的进步,许多模型虽然能够在一些测试集上取得较好的性能,但在实际应用中,这些模型仍然难以让人满意。在本篇教程中,我们将会为您展示如何训练一个问答模型。\n",
+ "\n",
+ "在这一领域,`SQuAD` 数据集是一个影响深远的数据集。它的全称是斯坦福问答数据集(Stanford Question Answering Dataset),每条数据包含 `(问题,上下文,答案)` 三部分,规模大(约十万条,2.0又新增了五万条),在提出之后很快成为训练问答任务的经典数据集之一。`SQuAD` 数据集有两个指标来衡量模型的表现:`EM`(Exact Match,精确匹配)和 `F1`(模糊匹配)。前者反应了模型给出的答案中有多少和正确答案完全一致,后者则反应了模型给出的答案中与正确答案重叠的部分,均为越高越好。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. 准备工作:加载 DuReader-robust 数据集,并使用 tokenizer 处理数据"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/remote-home/shxing/anaconda3/envs/fnlp-paddle/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2.3.3\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append(\"../\")\n",
+ "import paddle\n",
+ "import paddlenlp\n",
+ "\n",
+ "print(paddlenlp.__version__)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "在数据集方面,我们选用 `DuReader-robust` 中文数据集作为训练数据。它是一种抽取式问答数据集,采用 `SQuAD` 数据格式,能够评估真实应用场景下模型的泛用性。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n",
+ "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n",
+ "\u001b[32m[2022-06-27 19:22:46,998] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'id': '0a25cb4bc1ab6f474c699884e04601e4', 'title': '', 'context': '第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。', 'question': '仙剑奇侠传3第几集上天界', 'answers': {'text': ['第35集'], 'answer_start': [0]}}\n",
+ "{'id': '7de192d6adf7d60ba73ba25cf590cc1e', 'title': '', 'context': '选择燃气热水器时,一定要关注这几个问题:1、出水稳定性要好,不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好,要装有安全报警装置 市场上燃气热水器品牌众多,购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级:9秒速热,可快速进入洗浴模式;水温持久稳定,不会出现忽热忽冷的现象,并通过水量伺服技术将出水温度精确控制在±0.5℃,可满足家里宝贝敏感肌肤洗护需求;配备CO和CH4双气体报警装置更安全(市场上一般多为CO单气体报警)。另外,这款热水器还有智能WIFI互联功能,只需下载个手机APP即可用手机远程操作热水器,实现精准调节水温,满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能,可以有效吸附水中的铁锈、铁屑等微小杂质,防止细菌滋生,使沐浴水质更洁净,长期使用磁化水沐浴更利于身体健康。', 'question': '燃气热水器哪个牌子好', 'answers': {'text': ['方太'], 'answer_start': [110]}}\n",
+ "{'id': 'b9e74d4b9228399b03701d1fe6d52940', 'title': '', 'context': '迈克尔.乔丹在NBA打了15个赛季。他在84年进入nba,期间在1993年10月6日第一次退役改打棒球,95年3月18日重新回归,在99年1月13日第二次退役,后于2001年10月31日复出,在03年最终退役。迈克尔·乔丹(Michael Jordan),1963年2月17日生于纽约布鲁克林,美国著名篮球运动员,司职得分后卫,历史上最伟大的篮球运动员。1984年的NBA选秀大会,乔丹在首轮第3顺位被芝加哥公牛队选中。 1986-87赛季,乔丹场均得到37.1分,首次获得分王称号。1990-91赛季,乔丹连夺常规赛MVP和总决赛MVP称号,率领芝加哥公牛首次夺得NBA总冠军。 1997-98赛季,乔丹获得个人职业生涯第10个得分王,并率领公牛队第六次夺得总冠军。2009年9月11日,乔丹正式入选NBA名人堂。', 'question': '乔丹打了多少个赛季', 'answers': {'text': ['15个'], 'answer_start': [12]}}\n",
+ "训练集大小: 14520\n",
+ "验证集大小: 1417\n"
+ ]
+ }
+ ],
+ "source": [
+ "from paddlenlp.datasets import load_dataset\n",
+ "train_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"train\")\n",
+ "val_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"validation\")\n",
+ "for i in range(3):\n",
+ " print(train_dataset[i])\n",
+ "print(\"训练集大小:\", len(train_dataset))\n",
+ "print(\"验证集大小:\", len(val_dataset))\n",
+ "\n",
+ "MODEL_NAME = \"ernie-1.0-base-zh\"\n",
+ "from paddlenlp.transformers import ErnieTokenizer\n",
+ "tokenizer =ErnieTokenizer.from_pretrained(MODEL_NAME)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 2.1 处理训练集\n",
+ "\n",
+ "对于阅读理解任务,数据处理的方式较为麻烦。接下来我们会为您详细讲解处理函数 `_process_train` 的功能,同时也将通过实践展示关于 `tokenizer` 的更多功能,让您更加深入地了解自然语言处理任务。首先让我们向 `tokenizer` 输入一条数据(以列表的形式):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2\n",
+ "dict_keys(['offset_mapping', 'input_ids', 'token_type_ids', 'overflow_to_sample'])\n"
+ ]
+ }
+ ],
+ "source": [
+ "result = tokenizer(\n",
+ " [train_dataset[0][\"question\"]],\n",
+ " [train_dataset[0][\"context\"]],\n",
+ " stride=128,\n",
+ " max_length=256,\n",
+ " padding=\"max_length\",\n",
+ " return_dict=False\n",
+ ")\n",
+ "\n",
+ "print(len(result))\n",
+ "print(result[0].keys())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "首先不难理解的是,模型必须要同时接受问题(`question`)和上下文(`context`)才能够进行阅读理解,因此我们需要将二者同时进行分词(`tokenize`)。所幸,`Tokenizer` 提供了这一功能,当我们调用 `tokenizer` 的时候,其第一个参数名为 `text`,第二个参数名为 `text_pair`,这使得我们可以同时对一对文本进行分词。同时,`tokenizer` 还需要标记出一条数据中哪些属于问题,哪些属于上下文,这一功能则由 `token_type_ids` 完成。`token_type_ids` 会将输入的第一个文本(问题)标记为 `0`,第二个文本(上下文)标记为 `1`,这样模型在训练时便可以将问题和上下文区分开来:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2]\n",
+ "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓', '缓', '张', '开', '眼', '睛', ',', '景', '天', '又', '惊', '又', '喜', '之', '际', ',', '长', '卿', '和', '紫', '萱', '的', '仙', '船', '驶', '至', ',', '见', '众', '人', '无', '恙', ',', '也', '十', '分', '高', '兴', '。', '众', '人', '登', '船', ',', '用', '尽', '合', '力', '把', '自', '身', '的', '真', '气', '和', '水', '分', '输', '给', '她', '。', '雪', '见', '终', '于', '醒', '过', '来', '了', ',', '但', '却', '一', '脸', '木', '然', ',', '全', '无', '反', '应', '。', '众', '人', '向', '常', '胤', '求', '助', ',', '却', '发', '现', '人', '世', '界', '竟', '没', '有', '雪', '见', '的', '身', '世', '纪', '录', '。', '长', '卿', '询', '问', '清', '微', '的', '身', '世', ',', '清', '微', '语', '带', '双', '关', '说', '一', '切', '上', '了', '天', '界', '便', '有', '答', '案', '。', '长', '卿', '驾', '驶', '仙', '船', ',', '众', '人', '决', '定', '立', '马', '动', '身', ',', '往', '天', '界', '而', '去', '。', '众', '人', '来', '到', '一', '荒', '山', ',', '长', '卿', '指', '出', ',', '魔', '界', '和', '天', '界', '相', '连', '。', '由', '魔', '界', '进', '入', '通', '过', '神', '魔', '之', '井', ',', '便', '可', '登', '天', '。', '众', '人', '至', '魔', '界', '入', '口', ',', '仿', '若', '一', '黑', '色', '的', '蝙', '蝠', '洞', ',', '但', '始', '终', '无', '法', '进', '入', '。', '后', '来', '花', '楹', '发', '现', '只', '要', '有', '翅', '膀', '便', '能', '飞', '入', '[SEP]']\n",
+ "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(result[0][\"input_ids\"])\n",
+ "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"]))\n",
+ "print(result[0][\"token_type_ids\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "根据上面的输出我们可以看出,`tokenizer` 会将数据开头用 `[CLS]` 标记,用 `[SEP]` 来分割句子。同时,根据 `token_type_ids` 得到的 0、1 串,我们也很容易将问题和上下文区分开。顺带一提,如果一条数据进行了 `padding`,那么这部分会被标记为 `0` 。\n",
+ "\n",
+ "在输出的 `keys` 中还有一项名为 `offset_mapping` 的键。该项数据能够表示分词后的每个 `token` 在原文中对应文字或词语的位置。比如我们可以像下面这样将数据打印出来:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7)]\n",
+ "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427]\n",
+ "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓']\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(result[0][\"offset_mapping\"][:20])\n",
+ "print(result[0][\"input_ids\"][:20])\n",
+ "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"])[:20])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`[CLS]` 由于是 `tokenizer` 自己添加进去用于标记数据的 `token`,因此它在原文中找不到任何对应的词语,所以给出的位置范围就是 `(0, 0)`;第二个 `token` 对应第一个 `“仙”` 字,因此映射的位置就是 `(0, 1)`;同理,后面的 `[SEP]` 也不对应任何文字,映射的位置为 `(0, 0)`;而接下来的 `token` 对应 **上下文** 中的第一个字 `“第”`,映射出的位置为 `(0, 1)`;再后面的 `token` 对应原文中的两个字符 `35`,因此其位置映射为 `(1, 3)` 。通过这种手段,我们可以更方便地获取 `token` 与原文的对应关系。\n",
+ "\n",
+ "最后,您也许会注意到我们获取的 `result` 长度为 2 。这是文本在分词后长度超过了 `max_length` 256 ,`tokenizer` 将数据分成了两部分所致。在阅读理解任务中,我们不可能像文本分类那样轻易地将一条数据截断,因为答案很可能就出现在后面被丢弃的那部分数据中,因此,我们需要保留所有的数据(当然,您也可以直接丢弃这些超长的数据)。`overflow_to_sample` 则可以标识当前数据在原数据的索引:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[CLS]仙剑奇侠传3第几集上天界[SEP]第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入[SEP]\n",
+ "overflow_to_sample: 0\n",
+ "[CLS]仙剑奇侠传3第几集上天界[SEP]说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]\n",
+ "overflow_to_sample: 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "for res in result:\n",
+ " tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"])\n",
+ " print(\"\".join(tokens))\n",
+ " print(\"overflow_to_sample: \", res[\"overflow_to_sample\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "将两条数据均输出之后可以看到,它们都出自我们传入的数据,并且存在一部分重合。`tokenizer` 的 `stride` 参数可以设置重合部分的长度,这也可以帮助模型识别被分割开的两条数据;`overflow_to_sample` 的 `0` 则代表它们来自于第 `0` 条数据。\n",
+ "\n",
+ "基于以上信息,我们处理训练集的思路如下:\n",
+ "\n",
+ "1. 通过 `overflow_to_sample` 来获取原来的数据\n",
+ "2. 通过原数据的 `answers` 找到答案的起始位置\n",
+ "3. 通过 `offset_mapping` 给出的映射关系在分词处理后的数据中找到答案的起始位置,分别记录在 `start_pos` 和 `end_pos` 中;如果没有找到答案(比如答案被截断了),那么答案的起始位置就被标记为 `[CLS]` 的位置。\n",
+ "\n",
+ "这样 `_process_train` 函数就呼之欲出了,我们调用 `train_dataset.map` 函数,并将 `batched` 参数设置为 `True` ,将所有数据批量地进行更新。有一点需要注意的是,**在处理过后数据量会增加**。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'offset_mapping': [(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 
174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (0, 0)], 'input_ids': [1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 
283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'overflow_to_sample': 0, 'start_pos': 14, 'end_pos': 16}\n",
+ "处理后的训练集大小: 26198\n"
+ ]
+ }
+ ],
+ "source": [
+ "max_length = 256\n",
+ "doc_stride = 128\n",
+ "def _process_train(data):\n",
+ "\n",
+ " contexts = [data[i][\"context\"] for i in range(len(data))]\n",
+ " questions = [data[i][\"question\"] for i in range(len(data))]\n",
+ "\n",
+ " tokenized_data_list = tokenizer(\n",
+ " questions,\n",
+ " contexts,\n",
+ " stride=doc_stride,\n",
+ " max_length=max_length,\n",
+ " padding=\"max_length\",\n",
+ " return_dict=False\n",
+ " )\n",
+ "\n",
+ " for i, tokenized_data in enumerate(tokenized_data_list):\n",
+ " # 获取 [CLS] 对应的位置\n",
+ " input_ids = tokenized_data[\"input_ids\"]\n",
+ " cls_index = input_ids.index(tokenizer.cls_token_id)\n",
+ "\n",
+ " # 在 tokenize 的过程中,汉字和 token 在位置上并非一一对应的\n",
+ " # 而 offset mapping 记录了每个 token 在原文中对应的起始位置\n",
+ " offsets = tokenized_data[\"offset_mapping\"]\n",
+ " # token_type_ids 记录了一条数据中哪些是问题,哪些是上下文\n",
+ " token_type_ids = tokenized_data[\"token_type_ids\"]\n",
+ "\n",
+ " # 一条数据可能因为长度过长而在 tokenized_data 中存在多个结果\n",
+ " # overflow_to_sample 表示了当前 tokenize_example 属于 data 中的哪一条数据\n",
+ " sample_index = tokenized_data[\"overflow_to_sample\"]\n",
+ " answers = data[sample_index][\"answers\"]\n",
+ "\n",
+ " # answers 和 answer_starts 均为长度为 1 的 list\n",
+ " # 我们可以计算出答案的结束位置\n",
+ " start_char = answers[\"answer_start\"][0]\n",
+ " end_char = start_char + len(answers[\"text\"][0])\n",
+ "\n",
+ " token_start_index = 0\n",
+ " while token_type_ids[token_start_index] != 1:\n",
+ " token_start_index += 1\n",
+ "\n",
+ " token_end_index = len(input_ids) - 1\n",
+ " while token_type_ids[token_end_index] != 1:\n",
+ " token_end_index -= 1\n",
+ " # 分词后一条数据的结尾一定是 [SEP],因此还需要减一\n",
+ " token_end_index -= 1\n",
+ "\n",
+ " if not (offsets[token_start_index][0] <= start_char and\n",
+ " offsets[token_end_index][1] >= end_char):\n",
+ " # 如果答案不在这条数据中,则将答案位置标记为 [CLS] 的位置\n",
+ " tokenized_data_list[i][\"start_pos\"] = cls_index\n",
+ " tokenized_data_list[i][\"end_pos\"] = cls_index\n",
+ " else:\n",
+ " # 否则,我们可以找到答案对应的 token 的起始位置,记录在 start_pos 和 end_pos 中\n",
+ " while token_start_index < len(offsets) and offsets[\n",
+ " token_start_index][0] <= start_char:\n",
+ " token_start_index += 1\n",
+ " tokenized_data_list[i][\"start_pos\"] = token_start_index - 1\n",
+ " while offsets[token_end_index][1] >= end_char:\n",
+ " token_end_index -= 1\n",
+ " tokenized_data_list[i][\"end_pos\"] = token_end_index + 1\n",
+ "\n",
+ " return tokenized_data_list\n",
+ "\n",
+ "train_dataset.map(_process_train, batched=True, num_workers=5)\n",
+ "print(train_dataset[0])\n",
+ "print(\"处理后的训练集大小:\", len(train_dataset))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 2.2 处理验证集\n",
+ "\n",
+ "对于验证集的处理则简单得多,我们只需要保存原数据的 `id` 并将 `offset_mapping` 中不属于上下文的部分设置为 `None` 即可。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def _process_val(data):\n",
+ "\n",
+ " contexts = [data[i][\"context\"] for i in range(len(data))]\n",
+ " questions = [data[i][\"question\"] for i in range(len(data))]\n",
+ "\n",
+ " tokenized_data_list = tokenizer(\n",
+ " questions,\n",
+ " contexts,\n",
+ " stride=doc_stride,\n",
+ " max_length=max_length,\n",
+ " return_dict=False\n",
+ " )\n",
+ "\n",
+ " for i, tokenized_data in enumerate(tokenized_data_list):\n",
+ " token_type_ids = tokenized_data[\"token_type_ids\"]\n",
+ " # 保存数据对应的 id\n",
+ " sample_index = tokenized_data[\"overflow_to_sample\"]\n",
+ " tokenized_data_list[i][\"example_id\"] = data[sample_index][\"id\"]\n",
+ "\n",
+ " # 将不属于 context 的 offset 设置为 None\n",
+ " tokenized_data_list[i][\"offset_mapping\"] = [\n",
+ " (o if token_type_ids[k] == 1 else None)\n",
+ " for k, o in enumerate(tokenized_data[\"offset_mapping\"])\n",
+ " ]\n",
+ "\n",
+ " return tokenized_data_list\n",
+ "\n",
+ "val_dataset.map(_process_val, batched=True, num_workers=5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 2.3 DataLoader\n",
+ "\n",
+ "最后使用 `PaddleDataLoader` 将数据集包裹起来即可。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from fastNLP.core import PaddleDataLoader\n",
+ "\n",
+ "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n",
+ "val_dataloader = PaddleDataLoader(val_dataset, batch_size=16)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. 模型训练:自己定义评测用的 Metric 实现更加自由的任务评测\n",
+ "\n",
+ "#### 3.1 损失函数\n",
+ "\n",
+ "对于阅读理解任务,我们使用的是 `ErnieForQuestionAnswering` 模型。该模型在接受输入后会返回两个值:`start_logits` 和 `end_logits` ,大小均为 `(batch_size, sequence_length)`,反映了每条数据每个词语为答案起始位置的可能性,因此我们需要自定义一个损失函数来计算 `loss`。 `CrossEntropyLossForSquad` 会分别对答案起始位置的预测值和真实值计算交叉熵,最后返回其平均值作为最终的损失。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class CrossEntropyLossForSquad(paddle.nn.Layer):\n",
+ " def __init__(self):\n",
+ " super(CrossEntropyLossForSquad, self).__init__()\n",
+ "\n",
+ " def forward(self, start_logits, end_logits, start_pos, end_pos):\n",
+ " start_pos = paddle.unsqueeze(start_pos, axis=-1)\n",
+ " end_pos = paddle.unsqueeze(end_pos, axis=-1)\n",
+ " start_loss = paddle.nn.functional.softmax_with_cross_entropy(\n",
+ " logits=start_logits, label=start_pos)\n",
+ " start_loss = paddle.mean(start_loss)\n",
+ " end_loss = paddle.nn.functional.softmax_with_cross_entropy(\n",
+ " logits=end_logits, label=end_pos)\n",
+ " end_loss = paddle.mean(end_loss)\n",
+ "\n",
+ " loss = (start_loss + end_loss) / 2\n",
+ " return loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2 定义模型\n",
+ "\n",
+ "模型的核心则是 `ErnieForQuestionAnswering` 的 `ernie-1.0-base-zh` 预训练模型,同时按照 `fastNLP` 的规定定义 `train_step` 和 `evaluate_step` 函数。这里 `evaluate_step` 函数并没有像文本分类那样直接返回该批次数据的评测结果,这一点我们将在下面为您讲解。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m[2022-06-27 19:00:15,825] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n",
+ "W0627 19:00:15.831080 21543 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.2, Runtime API Version: 11.2\n",
+ "W0627 19:00:15.843276 21543 gpu_context.cc:306] device: 0, cuDNN Version: 8.1.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from paddlenlp.transformers import ErnieForQuestionAnswering\n",
+ "\n",
+ "class QAModel(paddle.nn.Layer):\n",
+ " def __init__(self, model_checkpoint):\n",
+ " super(QAModel, self).__init__()\n",
+ " self.model = ErnieForQuestionAnswering.from_pretrained(model_checkpoint)\n",
+ " self.loss_func = CrossEntropyLossForSquad()\n",
+ "\n",
+ " def forward(self, input_ids, token_type_ids):\n",
+ " start_logits, end_logits = self.model(input_ids, token_type_ids)\n",
+ " return start_logits, end_logits\n",
+ "\n",
+ " def train_step(self, input_ids, token_type_ids, start_pos, end_pos):\n",
+ " start_logits, end_logits = self(input_ids, token_type_ids)\n",
+ " loss = self.loss_func(start_logits, end_logits, start_pos, end_pos)\n",
+ " return {\"loss\": loss}\n",
+ "\n",
+ " def evaluate_step(self, input_ids, token_type_ids):\n",
+ " start_logits, end_logits = self(input_ids, token_type_ids)\n",
+ " return {\"start_logits\": start_logits, \"end_logits\": end_logits}\n",
+ "\n",
+ "model = QAModel(MODEL_NAME)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.3 自定义 Metric 进行数据的评估\n",
+ "\n",
+ "`paddlenlp` 为我们提供了评测 `SQuAD` 格式数据集的函数 `compute_prediction` 和 `squad_evaluate`:\n",
+ "- `compute_prediction` 函数要求传入原数据 `examples` 、处理后的数据 `features` 和 `features` 对应的结果 `predictions`(一个包含所有数据 `start_logits` 和 `end_logits` 的元组)\n",
+ "- `squad_evaluate` 要求传入原数据 `examples` 和预测结果 `all_predictions`(通常来自于 `compute_prediction`)\n",
+ "\n",
+ "在使用这两个函数的时候,我们需要向其中传入数据集,但显然根据 `fastNLP` 的设计,我们无法在 `evaluate_step` 里实现这一过程,并且 `fastNLP` 也并没有提供计算 `F1` 和 `EM` 的 `Metric`,故我们需要自己定义用于评测的 `Metric`。\n",
+ "\n",
+ "在初始化之外,一个 `Metric` 还需要实现三个函数:\n",
+ "\n",
+ "1. `reset` - 该函数会在验证数据集的迭代之前被调用,用于清空数据;在我们自定义的 `Metric` 中,我们需要将 `all_start_logits` 和 `all_end_logits` 清空,重新收集每个 `batch` 的结果。\n",
+ "2. `update` - 该函数会在在每个 `batch` 得到结果后被调用,用于更新 `Metric` 的状态;它的参数即为 `evaluate_step` 返回的内容。我们在这里将得到的 `start_logits` 和 `end_logits` 收集起来。\n",
+ "3. `get_metric` - 该函数会在数据集被迭代完毕后调用,用于计算评测的结果。现在我们有了整个验证集的 `all_start_logits` 和 `all_end_logits` ,将他们传入 `compute_predictions` 函数得到预测的结果,并继续使用 `squad_evaluate` 函数得到评测的结果。\n",
+ " - 注:`suqad_evaluate` 函数会自己输出评测结果,为了不让其干扰 `fastNLP` 输出,这里我们使用 `contextlib.redirect_stdout(None)` 将函数的标准输出屏蔽掉。\n",
+ "\n",
+ "综上,`SquadEvaluateMetric` 实现的评估过程是:将验证集中所有数据的 `logits` 收集起来,然后统一传入 `compute_prediction` 和 `squad_evaluate` 中进行评估。值得一提的是,`paddlenlp.datasets.load_dataset` 返回的结果是一个 `MapDataset` 类型,其 `data` 成员为加载时的数据,`new_data` 为经过 `map` 函数处理后更新的数据,因此可以分别作为 `examples` 和 `features` 传入。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fastNLP.core import Metric\n",
+ "from paddlenlp.metrics.squad import squad_evaluate, compute_prediction\n",
+ "import contextlib\n",
+ "\n",
+ "class SquadEvaluateMetric(Metric):\n",
+ " def __init__(self, examples, features, testing=False):\n",
+ " super(SquadEvaluateMetric, self).__init__(\"paddle\", False)\n",
+ " self.examples = examples\n",
+ " self.features = features\n",
+ " self.all_start_logits = []\n",
+ " self.all_end_logits = []\n",
+ " self.testing = testing\n",
+ "\n",
+ " def reset(self):\n",
+ " self.all_start_logits = []\n",
+ " self.all_end_logits = []\n",
+ "\n",
+ " def update(self, start_logits, end_logits):\n",
+ " for start, end in zip(start_logits, end_logits):\n",
+ " self.all_start_logits.append(start.numpy())\n",
+ " self.all_end_logits.append(end.numpy())\n",
+ "\n",
+ " def get_metric(self):\n",
+ " all_predictions, _, _ = compute_prediction(\n",
+ " self.examples, self.features[:len(self.all_start_logits)],\n",
+ " (self.all_start_logits, self.all_end_logits),\n",
+ " False, 20, 30\n",
+ " )\n",
+ " with contextlib.redirect_stdout(None):\n",
+ " result = squad_evaluate(\n",
+ " examples=self.examples,\n",
+ " preds=all_predictions,\n",
+ " is_whitespace_splited=False\n",
+ " )\n",
+ "\n",
+ " if self.testing:\n",
+ " self.print_predictions(all_predictions)\n",
+ " return result\n",
+ "\n",
+ " def print_predictions(self, preds):\n",
+ " for i, data in enumerate(self.examples):\n",
+ " if i >= 5:\n",
+ " break\n",
+ " print()\n",
+ " print(\"原文:\", data[\"context\"])\n",
+ " print(\"问题:\", data[\"question\"], \\\n",
+ " \"答案:\", preds[data[\"id\"]], \\\n",
+ " \"正确答案:\", data[\"answers\"][\"text\"])\n",
+ "\n",
+ "metric = SquadEvaluateMetric(\n",
+ " val_dataloader.dataset.data,\n",
+ " val_dataloader.dataset.new_data,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.4 训练\n",
+ "\n",
+ "至此所有的准备工作已经完成,可以使用 `Trainer` 进行训练了。学习率我们依旧采用线性预热策略 `LinearDecayWithWarmup`,优化器为 `AdamW`;回调模块我们选择 `LRSchedCallback` 更新学习率和 `LoadBestModelCallback` 监视评测结果的 `f1` 分数。初始化好 `Trainer` 之后,就将训练的过程交给 `fastNLP` 吧。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[19:04:54] INFO Running evaluator sanity check for 2 batches. trainer.py:631\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[19:04:54]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=367046;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96810;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:100 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m100\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"exact#squad\": 49.25899788285109,\n",
+ " \"f1#squad\": 66.55559127349602,\n",
+ " \"total#squad\": 1417,\n",
+ " \"HasAns_exact#squad\": 49.25899788285109,\n",
+ " \"HasAns_f1#squad\": 66.55559127349602,\n",
+ " \"HasAns_total#squad\": 1417\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n",
+ " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n",
+ " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:200 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m200\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"exact#squad\": 57.37473535638673,\n",
+ " \"f1#squad\": 70.93036525200617,\n",
+ " \"total#squad\": 1417,\n",
+ " \"HasAns_exact#squad\": 57.37473535638673,\n",
+ " \"HasAns_f1#squad\": 70.93036525200617,\n",
+ " \"HasAns_total#squad\": 1417\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n",
+ " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n",
+ " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"exact#squad\": 63.86732533521524,\n",
+ " \"f1#squad\": 78.62546663568186,\n",
+ " \"total#squad\": 1417,\n",
+ " \"HasAns_exact#squad\": 63.86732533521524,\n",
+ " \"HasAns_f1#squad\": 78.62546663568186,\n",
+ " \"HasAns_total#squad\": 1417\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n",
+ " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n",
+ " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:400 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m400\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"exact#squad\": 64.92589978828511,\n",
+ " \"f1#squad\": 79.36746074079691,\n",
+ " \"total#squad\": 1417,\n",
+ " \"HasAns_exact#squad\": 64.92589978828511,\n",
+ " \"HasAns_f1#squad\": 79.36746074079691,\n",
+ " \"HasAns_total#squad\": 1417\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n",
+ " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n",
+ " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:500 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m500\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"exact#squad\": 65.70218772053634,\n",
+ " \"f1#squad\": 80.33295482054824,\n",
+ " \"total#squad\": 1417,\n",
+ " \"HasAns_exact#squad\": 65.70218772053634,\n",
+ " \"HasAns_f1#squad\": 80.33295482054824,\n",
+ " \"HasAns_total#squad\": 1417\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n",
+ " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n",
+ " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:600 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m600\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"exact#squad\": 65.41990119971771,\n",
+ " \"f1#squad\": 79.7483487059053,\n",
+ " \"total#squad\": 1417,\n",
+ " \"HasAns_exact#squad\": 65.41990119971771,\n",
+ " \"HasAns_f1#squad\": 79.7483487059053,\n",
+ " \"HasAns_total#squad\": 1417\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n",
+ " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n",
+ " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:700 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m700\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"exact#squad\": 66.61961891319689,\n",
+ " \"f1#squad\": 80.32432238994133,\n",
+ " \"total#squad\": 1417,\n",
+ " \"HasAns_exact#squad\": 66.61961891319689,\n",
+ " \"HasAns_f1#squad\": 80.32432238994133,\n",
+ " \"HasAns_total#squad\": 1417\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n",
+ " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n",
+ " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:800 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m800\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"exact#squad\": 65.84333098094567,\n",
+ " \"f1#squad\": 79.23169801265415,\n",
+ " \"total#squad\": 1417,\n",
+ " \"HasAns_exact#squad\": 65.84333098094567,\n",
+ " \"HasAns_f1#squad\": 79.23169801265415,\n",
+ " \"HasAns_total#squad\": 1417\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n",
+ " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n",
+ " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n",
+ " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "[19:20:28] INFO Loading best model from fnlp-ernie-squad/ load_best_model_callback.py:111\n",
+ " 2022-06-27-19_00_15_388554/best_so_far \n",
+ " with f1#squad: 80.33295482054824... \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[19:20:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie-squad/ \u001b]8;id=163935;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=31503;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m \u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_00_15_388554/best_so_far \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m with f1#squad: \u001b[1;36m80.33295482054824\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " INFO Deleting fnlp-ernie-squad/2022-06-27-19_0 load_best_model_callback.py:131\n",
+ " 0_15_388554/best_so_far... \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie-squad/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_0 \u001b]8;id=560859;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=573263;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m 0_15_388554/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from fastNLP import Trainer, LRSchedCallback, LoadBestModelCallback\n",
+ "from paddlenlp.transformers import LinearDecayWithWarmup\n",
+ "\n",
+ "n_epochs = 1\n",
+ "num_training_steps = len(train_dataloader) * n_epochs\n",
+ "lr_scheduler = LinearDecayWithWarmup(3e-5, num_training_steps, 0.1)\n",
+ "optimizer = paddle.optimizer.AdamW(\n",
+ " learning_rate=lr_scheduler,\n",
+ " parameters=model.parameters(),\n",
+ ")\n",
+ "callbacks=[\n",
+ " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n",
+ " LoadBestModelCallback(\"f1#squad\", larger_better=True, save_folder=\"fnlp-ernie-squad\")\n",
+ "]\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " train_dataloader=train_dataloader,\n",
+ " evaluate_dataloaders=val_dataloader,\n",
+ " device=1,\n",
+ " optimizers=optimizer,\n",
+ " n_epochs=n_epochs,\n",
+ " callbacks=callbacks,\n",
+ " evaluate_every=100,\n",
+ " metrics={\"squad\": metric},\n",
+ ")\n",
+ "trainer.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.5 测试\n",
+ "\n",
+ "最后,我们可以使用 `Evaluator` 查看我们训练的结果。我们在之前为 `SquadEvaluateMetric` 设置了 `testing` 参数来在测试阶段进行输出,可以看到,训练的结果还是比较不错的。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n",
+ "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n",
+ "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n",
+ "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n",
+ "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n",
+ "行垫,油墨外露容易脱落。 \n",
+ "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n",
+ "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n",
+ "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n",
+ "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n",
+ "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n",
+ "行垫,油墨外露容易脱落。 \n",
+ "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n",
+ "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n",
+ "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n",
+ "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n",
+ "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n",
+ "10厘米。\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n",
+ "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n",
+ "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n",
+ "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n",
+ "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n",
+ "10厘米。\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n",
+ "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n",
+ "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n",
+ "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n",
+ "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n",
+ "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n",
+ "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n",
+ "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n",
+ "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n",
+ "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n",
+ "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n",
+ "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n",
+ "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n",
+ "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n",
+ "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n",
+ "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n",
+ "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n",
+ "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n",
+ "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n",
+ "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n",
+ "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n",
+ "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n",
+ "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n",
+ "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n",
+ "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n",
+ "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n",
+ "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n",
+ "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n",
+ "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n",
+ ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n",
+ "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n",
+ "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n",
+ "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n",
+ "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n",
+ "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n",
+ "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n",
+ "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n",
+ "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n",
+ "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n",
+ "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n",
+ ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n",
+ "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n",
+ "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n",
+ "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n",
+ "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n",
+ "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n",
+ "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " 'exact#squad': 65.70218772053634,\n",
+ " 'f1#squad': 80.33295482054824,\n",
+ " 'total#squad': 1417,\n",
+ " 'HasAns_exact#squad': 65.70218772053634,\n",
+ " 'HasAns_f1#squad': 80.33295482054824,\n",
+ " 'HasAns_total#squad': 1417\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[32m'exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n",
+ " \u001b[32m'f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n",
+ " \u001b[32m'total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
+ " \u001b[32m'HasAns_exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n",
+ " \u001b[32m'HasAns_f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n",
+ " \u001b[32m'HasAns_total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from fastNLP import Evaluator\n",
+ "evaluator = Evaluator(\n",
+ " model=model,\n",
+ " dataloaders=val_dataloader,\n",
+ " device=1,\n",
+ " metrics={\n",
+ " \"squad\": SquadEvaluateMetric(\n",
+ " val_dataloader.dataset.data,\n",
+ " val_dataloader.dataset.new_data,\n",
+ " testing=True,\n",
+ " ),\n",
+ " },\n",
+ ")\n",
+ "result = evaluator.run()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.7.13 ('fnlp-paddle')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/source/tutorials/figures/E1-fig-glue-benchmark.png b/docs/source/tutorials/figures/E1-fig-glue-benchmark.png
new file mode 100644
index 00000000..515db700
Binary files /dev/null and b/docs/source/tutorials/figures/E1-fig-glue-benchmark.png differ
diff --git a/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png b/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png
new file mode 100644
index 00000000..b5a9c1b8
Binary files /dev/null and b/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png differ
diff --git a/docs/source/tutorials/figures/E2-fig-pet-model.png b/docs/source/tutorials/figures/E2-fig-pet-model.png
new file mode 100644
index 00000000..c3c377c0
Binary files /dev/null and b/docs/source/tutorials/figures/E2-fig-pet-model.png differ
diff --git a/docs/source/tutorials/figures/T0-fig-parameter-matching.png b/docs/source/tutorials/figures/T0-fig-parameter-matching.png
new file mode 100644
index 00000000..24013cc1
Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-parameter-matching.png differ
diff --git a/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png b/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png
new file mode 100644
index 00000000..38222ee8
Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png differ
diff --git a/docs/source/tutorials/figures/T0-fig-training-structure.png b/docs/source/tutorials/figures/T0-fig-training-structure.png
new file mode 100644
index 00000000..edc2e2ff
Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-training-structure.png differ
diff --git a/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png b/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png
new file mode 100644
index 00000000..803cf34a
Binary files /dev/null and b/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png differ
diff --git a/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png b/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png
new file mode 100644
index 00000000..ff2519c4
Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png differ
diff --git a/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png b/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png
new file mode 100644
index 00000000..ed003a2f
Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png differ
diff --git a/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png b/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png
new file mode 100644
index 00000000..d45f65d8
Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png differ
diff --git a/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png b/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png
new file mode 100644
index 00000000..f50ddb1c
Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png differ
diff --git a/docs/source/tutorials/tutorial_1_data_preprocess.rst b/docs/source/tutorials/tutorial_1_data_preprocess.rst
deleted file mode 100644
index d9132546..00000000
--- a/docs/source/tutorials/tutorial_1_data_preprocess.rst
+++ /dev/null
@@ -1,172 +0,0 @@
-==============================
-fastNLP中的DataSet
-==============================
-
-:class:`~fastNLP.DataSet` 是fastNLP用于承载数据的类,一般训练集、验证集和测试集会被加载为三个单独的 :class:`~fastNLP.DataSet` 对象。
-
-:class:`~fastNLP.DataSet` 中的数据组织形式类似一个表格,比如下面 :class:`~fastNLP.DataSet` 一共有3列,列在fastNLP中被称为field。
-
-.. csv-table::
- :header: "raw_chars", "chars", "seq_len"
-
- "历任公司副总经理、总工程师,", "[历 任 公 司 副 总 经 理 、 总 工 程 师 ,]", 6
- "Third instance .", "[Third, instance, .]", 3
- "...", "[...]", "..."
-
-每一行是一个instance (在fastNLP中被称为 :mod:`~fastNLP.core.Instance` ),
-每一列是一个field (在fastNLP中称为 :mod:`~fastNLP.core.FieldArray` )。
-
-DataSet的构建
------------------------------
-
-我们使用传入字典的方式初始化一个DataSet,这是 :class:`~fastNLP.DataSet` 初始化的最基础的方式
-
-.. code-block:: python
-
- from fastNLP import DataSet
- data = {'raw_words':["This is the first instance .", "Second instance .", "Third instance ."],
- 'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.']],
- 'seq_len': [6, 3, 3]}
- dataset = DataSet(data)
- # 传入的dict的每个key的value应该为具有相同长度的list
- print(dataset)
-
-输出为::
-
- +------------------------------+------------------------------------------------+---------+
- | raw_words | words | seq_len |
- +------------------------------+------------------------------------------------+---------+
- | This is the first instance . | ['this', 'is', 'the', 'first', 'instance', ... | 6 |
- | Second instance . | ['Second', 'instance', '.'] | 3 |
- | Third instance . | ['Third', 'instance', '.'] | 3 |
- +------------------------------+------------------------------------------------+---------+
-
-
-我们还可以使用 :func:`~fastNLP.DataSet.append` 方法向DataSet增加数据
-
-.. code-block:: python
-
- from fastNLP import DataSet
- from fastNLP import Instance
- dataset = DataSet()
- instance = Instance(raw_words="This is the first instance",
- words=['this', 'is', 'the', 'first', 'instance', '.'],
- seq_len=6)
- dataset.append(instance)
- # 可以继续append更多内容,但是append的instance应该和前面的instance拥有完全相同的field
-
-另外,我们还可以用 :class:`~fastNLP.Instance` 数组的方式构建DataSet
-
-.. code-block:: python
-
- from fastNLP import DataSet
- from fastNLP import Instance
- dataset = DataSet([
- Instance(raw_words="This is the first instance",
- words=['this', 'is', 'the', 'first', 'instance', '.'],
- seq_len=6),
- Instance(raw_words="Second instance .",
- words=['Second', 'instance', '.'],
- seq_len=3)
- ])
-
-在初步构建完DataSet之后,我们可以通过 `for` 循环遍历 :class:`~fastNLP.DataSet` 中的内容。
-
-.. code-block:: python
-
- for instance in dataset:
- # do something
-
-DataSet的删除
------------------------------
-
-FastNLP 同样提供了多种删除数据的方法 :func:`~fastNLP.DataSet.drop` 、 :func:`~fastNLP.DataSet.delete_instance` 和 :func:`~fastNLP.DataSet.delete_field`
-我们先用下面的代码生成一个只有两列的样例DataSet,第一列的值分别为 -5 ~ 4,第二列的值均为 0.
-
-.. code-block:: python
-
- from fastNLP import DataSet
- dataset = DataSet({'a': range(-5, 5), 'c': [0]*10})
-
-然后我们使用三种方法进行删除,删除后的DataSet仅包含名为 c 的一列,包含4个值为0 的数据。
-
-.. code-block:: python
-
- # 不改变dataset,生成一个删除了满足条件的instance的新 DataSet
- dropped_dataset = dataset.drop(lambda ins:ins['a']<0, inplace=False)
- # 在dataset中删除满足条件的instance
- dataset.drop(lambda ins:ins['a']<0)
- # 删除第3个instance
- dataset.delete_instance(2)
- # 删除名为'a'的field
- dataset.delete_field('a')
-
-
-简单的数据预处理
------------------------------
-
-因为 fastNLP 中的数据是按列存储的,所以大部分的数据预处理操作是以列( :mod:`~fastNLP.core.field` )为操作对象的。
-首先,我们可以检查特定名称的 :mod:`~fastNLP.core.field` 是否存在,并对其进行改名。
-
-.. code-block:: python
-
- # 检查是否存在名为'a'的field
- dataset.has_field('a') # 或 ('a' in dataset)
- # 将名为'c'的field改名为'b'
- dataset.rename_field('c', 'b')
- # DataSet的长度
- len(dataset)
-
-其次,我们可以使用 :func:`~fastNLP.DataSet.apply` 或 :func:`~fastNLP.DataSet.apply_field` 进行数据预处理操作操作。
-使用以上的两个方法需要传入一个函数,函数可以是 lambda 匿名函数,也可以是完整定义的函数,fastNLP将对DataSet遍历地应用该函数。
-同时,你还可以用 ``new_field_name`` 参数指定函数返回值组成的新 :mod:`~fastNLP.core.field` 的名称。
-
-.. code-block:: python
-
- from fastNLP import DataSet
- data = {'raw_words':["This is the first instance .", "Second instance .", "Third instance ."]}
- dataset = DataSet(data)
-
- # 将句子分成单词形式, 详见DataSet.apply()方法
- dataset.apply(lambda ins: ins['raw_words'].split(), new_field_name='words')
-
- # 或使用DataSet.apply_field()
- dataset.apply_field(lambda sent:sent.split(), field_name='raw_words', new_field_name='words')
-
- # 除了匿名函数,也可以定义函数传递进去
- def get_words(instance):
- sentence = instance['raw_words']
- words = sentence.split()
- return words
- dataset.apply(get_words, new_field_name='words')
-
-除了手动处理数据集之外,你还可以使用 fastNLP 提供的各种 :class:`~fastNLP.io.Loader` 和 :class:`~fastNLP.io.Pipe` 来进行数据处理。
-详细请参考这篇教程 :doc:`使用Loader和Pipe处理数据 ` 。
-
-
-fastNLP中field的命名习惯
------------------------------
-
-在英文任务中,fastNLP常用的field名称有:
-
- - **raw_words**: 表示的是原始的str。例如"This is a demo sentence ."。存在多个raw_words的情况,例如matching任务,它们会被定义为raw_words0, raw_words1。但在conll格式下,raw_words列也可能为["This", "is", "a", "demo", "sentence", "."]的形式。
- - **words**: 表示的是已经tokenize后的词语。例如["This", "is", "a", "demo", "sentence"], 但由于str并不能直接被神经网络所使用,所以words中的内容往往被转换为int,如[3, 10, 4, 2, 7, ...]等。多列words的情况,会被命名为words0, words1
- - **target**: 表示目标值。分类场景下,只有一个值;序列标注场景下是一个序列。
- - **seq_len**: 一般用于表示words列的长度
-
-在中文任务中,fastNLP常用的field名称有:
-
- - **raw_words**: 如果原始汉字序列中已经包含了词语的边界,则该列称为raw_words。如"上海 浦东 开发 与 法制 建设 同步"。
- - **words**: 表示单独的汉字词语序列。例如["上海", "", "浦东", "开发", "与", "法制", "建设", ...]或[2, 3, 4, ...]
- - **raw_chars**: 表示的是原始的连续汉字序列。例如"这是一个示例。"
- - **chars**: 表示已经切分为单独的汉字的序列。例如["这", "是", "一", "个", "示", "例", "。"]。但由于神经网络不能识别汉字,所以一般该列会被转为int形式,如[3, 4, 5, 6, ...]。
- - **target**: 表示目标值。分类场景下,只有一个值;序列标注场景下是一个序列
- - **seq_len**: 表示输入序列的长度
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/tutorial_2_vocabulary.rst b/docs/source/tutorials/tutorial_2_vocabulary.rst
deleted file mode 100644
index e8855d99..00000000
--- a/docs/source/tutorials/tutorial_2_vocabulary.rst
+++ /dev/null
@@ -1,140 +0,0 @@
-==============================
-fastNLP中的Vocabulary
-==============================
-
-:class:`~fastNLP.Vocabulary` 是包含字或词与index关系的类,用于将文本转换为index。
-
-
-构建Vocabulary
------------------------------
-
-.. code-block:: python
-
- from fastNLP import Vocabulary
-
- vocab = Vocabulary()
- vocab.add_word_lst(['复', '旦', '大', '学']) # 加入新的字
- vocab.add_word('上海') # `上海`会作为一个整体
- vocab.to_index('复') # 应该会为3
- vocab.to_index('我') # 会输出1,Vocabulary中默认pad的index为0, unk(没有找到的词)的index为1
-
- # 在构建target的Vocabulary时,词表中应该用不上pad和unk,可以通过以下的初始化
- vocab = Vocabulary(unknown=None, padding=None)
- vocab.add_word_lst(['positive', 'negative'])
- vocab.to_index('positive') # 输出0
- vocab.to_index('neutral') # 会报错,因为没有unk这种情况
-
-除了通过以上的方式建立词表,Vocabulary还可以通过使用下面的函数直接从 :class:`~fastNLP.DataSet` 中的某一列建立词表以及将该列转换为index
-
-.. code-block:: python
-
- from fastNLP import Vocabulary
- from fastNLP import DataSet
-
- dataset = DataSet({'chars': [
- ['今', '天', '天', '气', '很', '好', '。'],
- ['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。']
- ],
- 'target': ['neutral', 'negative']
- })
-
- vocab = Vocabulary()
- # 从该dataset中的chars列建立词表
- vocab.from_dataset(dataset, field_name='chars')
- # 使用vocabulary将chars列转换为index
- vocab.index_dataset(dataset, field_name='chars')
-
- target_vocab = Vocabulary(padding=None, unknown=None)
- target_vocab.from_dataset(dataset, field_name='target')
- target_vocab.index_dataset(dataset, field_name='target')
- print(dataset)
-
-输出内容为::
-
- +---------------------------------------------------+--------+
- | chars | target |
- +---------------------------------------------------+--------+
- | [4, 2, 2, 5, 6, 7, 3] | 0 |
- | [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 3] | 1 |
- +---------------------------------------------------+--------+
-
-
-一些使用tips
------------------------------
-
-在使用from_dataset()函数建立词表时,将测试集和验证集放入参数no_create_entry_dataset中,如下所示
-
-.. code-block:: python
-
- from fastNLP import Vocabulary
- from fastNLP import DataSet
-
- tr_data = DataSet({'chars': [
- ['今', '天', '心', '情', '很', '好', '。'],
- ['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。']
- ],
- 'target': ['positive', 'negative']
- })
- dev_data = DataSet({'chars': [
- ['住', '宿', '条', '件', '还', '不', '错'],
- ['糟', '糕', '的', '天', '气', ',', '无', '法', '出', '行', '。']
- ],
- 'target': ['positive', 'negative']
- })
-
- vocab = Vocabulary()
- # 将验证集或者测试集在建立词表是放入no_create_entry_dataset这个参数中。
- vocab.from_dataset(tr_data, field_name='chars', no_create_entry_dataset=[dev_data])
-
-:class:`~fastNLP.Vocabulary` 中的 `no_create_entry` ,如果您并不关心具体的原理,您可以直接采取以下的建议:在添加来自于非训练集的词的时候将该参数置为True, 或将非训练集数据
-传入 `no_create_entry_dataset` 参数。它们的意义是在接下来的模型会使用pretrain的embedding(包括glove, word2vec, elmo与bert)且会finetune的
-情况下,如果仅使用来自于train的数据建立vocabulary,会导致只出现在test与dev中的词语无法充分利用到来自于预训练embedding的信息(因为他们
-会被认为是unk),所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。
-
-通过与fastNLP中的各种Embedding配合使用,会有如下的效果,
-如果一个词出现在了train中,但是没在预训练模型中,embedding会为随机初始化,且它单独的一个vector,如果finetune embedding的话,
-这个词在更新之后可能会有更好的表示; 而如果这个词仅出现在了dev或test中,那么就不能为它们单独建立vector,而应该让它指向unk这个vector的
-值(当unk的值更新时,这个词也使用的是更新之后的vector)。所以被认为是no_create_entry的token,将首先从预训练的词表中寻找它的表示,如
-果找到了,就使用该表示; 如果没有找到,则认为该词的表示应该为unk的表示。
-
-下面我们结合部分 :class:`~fastNLP.embeddings.StaticEmbedding` 的例子来说明下该值造成的影响,如果您对 :class:`~fastNLP.embeddings.StaticEmbedding` 不太了解,您可以先参考 :doc:`使用Embedding模块将文本转成向量 ` 部分再来阅读该部分
-
-.. code-block:: python
-
- import torch
- from fastNLP.embeddings import StaticEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary()
- vocab.add_word('train')
- vocab.add_word('only_in_train') # 仅在train出现,但肯定在预训练词表中不存在
- vocab.add_word('test', no_create_entry=True) # 该词只在dev或test中出现
- vocab.add_word('only_in_test', no_create_entry=True) # 这个词在预训练的词表中找不到
-
- embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')
- print(embed(torch.LongTensor([vocab.to_index('train')])))
- print(embed(torch.LongTensor([vocab.to_index('only_in_train')])))
- print(embed(torch.LongTensor([vocab.to_index('test')])))
- print(embed(torch.LongTensor([vocab.to_index('only_in_test')])))
- print(embed(torch.LongTensor([vocab.unknown_idx])))
-
-输出结果(只截取了部分vector)::
-
- tensor([[ 0.9497, 0.3433, 0.8450, -0.8852, ...]], grad_fn=) # train,en-glove-6b-50d,找到了该词
- tensor([[ 0.0540, -0.0557, -0.0514, -0.1688, ...]], grad_fn=) # only_in_train,en-glove-6b-50d,使用了随机初始化
- tensor([[ 0.1318, -0.2552, -0.0679, 0.2619, ...]], grad_fn=) # test,在en-glove-6b-50d中找到了这个词
- tensor([[0., 0., 0., 0., 0., ...]], grad_fn=) # only_in_test, en-glove-6b-50d中找不到这个词,使用unk的vector
- tensor([[0., 0., 0., 0., 0., ...]], grad_fn=) # unk,使用zero初始化
-
-首先train和test都能够从预训练中找到对应的vector,所以它们是各自的vector表示; only_in_train在预训练中找不到,StaticEmbedding为它
-新建了一个entry,所以它有一个单独的vector; 而only_in_test在预训练中找不到改词,因此被指向了unk的值(fastNLP用零向量初始化unk),与最后一行unk的
-表示相同。
-
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/tutorial_3_embedding.rst b/docs/source/tutorials/tutorial_3_embedding.rst
deleted file mode 100644
index 147fbb8c..00000000
--- a/docs/source/tutorials/tutorial_3_embedding.rst
+++ /dev/null
@@ -1,462 +0,0 @@
-=========================================
-使用Embedding模块将文本转成向量
-=========================================
-
-这一部分是一个关于在fastNLP当中使用embedding的教程。
-
-教程目录:
-
- - `Part I: embedding介绍`_
- - `Part II: 使用预训练的静态embedding`_
- - `Part III: 使用随机初始化的embedding`_
- - `Part IV: ELMo Embedding`_
- - `Part V: Bert Embedding`_
- - `Part VI: 使用character-level的embedding`_
- - `Part VII: 叠加使用多个embedding`_
- - `Part VIII: Embedding的其它说明`_
- - `Part IX: StaticEmbedding的使用建议`_
-
-
-
-Part I: embedding介绍
----------------------------------------
-
-Embedding是一种词嵌入技术,可以将字或者词转换为实向量。目前使用较多的预训练词嵌入有word2vec, fasttext, glove, character embedding,
-elmo以及bert。
-但使用这些词嵌入方式的时候都需要做一些加载上的处理,比如预训练的word2vec, fasttext以及glove都有着超过几十万个词语的表示,但一般任务大概
-只会用到其中的几万个词,如果直接加载所有的词汇,会导致内存占用变大以及训练速度变慢,需要从预训练文件中抽取本次实验的用到的词汇;而对于英文的
-elmo和character embedding, 需要将word拆分成character才能使用;Bert的使用更是涉及到了Byte pair encoding(BPE)相关的内容。为了方便
-大家的使用,fastNLP通过 :class:`~fastNLP.Vocabulary` 统一了不同embedding的使用。下面我们将讲述一些例子来说明一下
-
-
-
-Part II: 使用预训练的静态embedding
----------------------------------------
-
-在fastNLP中,加载预训练的word2vec, glove以及fasttext都使用的是 :class:`~fastNLP.embeddings.StaticEmbedding` 。另外,为了方便大家的
-使用,fastNLP提供了多种静态词向量的自动下载并缓存(默认缓存到~/.fastNLP/embeddings文件夹下)的功能,支持自动下载的预训练向量可以在
-`下载文档 `_ 查看。
-
-.. code-block:: python
-
- import torch
- from fastNLP.embeddings import StaticEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary()
- vocab.add_word_lst("this is a demo .".split())
-
- embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')
-
- words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]]) # 将文本转为index
- print(embed(words).size()) # StaticEmbedding的使用和pytorch的nn.Embedding是类似的
-
-输出为::
-
- torch.Size([1, 5, 50])
-
-fastNLP的StaticEmbedding在初始化之后,就和pytorch中的Embedding是类似的了。 :class:`~fastNLP.embeddings.StaticEmbedding` 的初始化
-主要是从model_dir_or_name提供的词向量中抽取出 :class:`~fastNLP.Vocabulary` 中词语的vector。
-
-除了可以通过使用预先提供的Embedding, :class:`~fastNLP.embeddings.StaticEmbedding` 也支持加载本地的预训练词向量,glove, word2vec以及
-fasttext格式的。通过将model_dir_or_name修改为本地的embedding文件路径,即可使用本地的embedding。
-
-
-Part III: 使用随机初始化的embedding
----------------------------------------
-
-有时候需要使用随机初始化的Embedding,也可以通过使用 :class:`~fastNLP.embeddings.StaticEmbedding` 获得。只需要将model_dir_or_name
-置为None,且传入embedding_dim,如下例所示
-
-.. code-block:: python
-
- from fastNLP.embeddings import StaticEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary()
- vocab.add_word_lst("this is a demo .".split())
-
- embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=30)
-
- words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]])
- print(embed(words).size())
-
-输出为::
-
- torch.Size([1, 5, 30])
-
-
-
-Part IV: ELMo Embedding
------------------------------------------------------------
-
-在fastNLP中,我们提供了ELMo和BERT的embedding: :class:`~fastNLP.embeddings.ElmoEmbedding`
-和 :class:`~fastNLP.embeddings.BertEmbedding` 。可自动下载的ElmoEmbedding可以
-从 `下载文档 `_ 找到。
-
-与静态embedding类似,ELMo的使用方法如下:
-
-.. code-block:: python
-
- from fastNLP.embeddings import ElmoEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary()
- vocab.add_word_lst("this is a demo .".split())
-
- embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False)
- words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]])
- print(embed(words).size())
-
-输出为::
-
- torch.Size([1, 5, 256])
-
-也可以输出多层的ELMo结果,fastNLP将在不同层的结果在最后一维上拼接,下面的代码需要在上面的代码执行结束之后执行
-
-.. code-block:: python
-
- embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False, layers='1,2')
- print(embed(words).size())
-
-输出为::
-
- torch.Size([1, 5, 512])
-
-另外,根据 `Deep contextualized word representations `_ ,不同层之间使用可学习的权重可以使得ELMo的效果更好,在fastNLP中可以通过以下的初始化
-实现3层输出的结果通过可学习的权重进行加法融合。
-
-.. code-block:: python
-
- embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=True, layers='mix')
- print(embed(words).size()) # 三层输出按照权重element-wise的加起来
-
-输出为::
-
- torch.Size([1, 5, 256])
-
-
-
-Part V: Bert Embedding
------------------------------------------------------------
-
-虽然Bert并不算严格意义上的Embedding,但通过将Bert封装成Embedding的形式将极大减轻使用的复杂程度。可自动下载的Bert Embedding可以
-从 `下载文档 `_ 找到。我们将使用下面的例子讲述一下
-BertEmbedding的使用
-
-.. code-block:: python
-
- from fastNLP.embeddings import BertEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary()
- vocab.add_word_lst("this is a demo .".split())
-
- embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased')
- words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]])
- print(embed(words).size())
-
-输出为::
-
- torch.Size([1, 5, 768])
-
-可以通过申明使用指定层数的output也可以使用多层的output,下面的代码需要在上面的代码执行结束之后执行
-
-.. code-block:: python
-
- # 使用后面两层的输出
- embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='10,11')
- print(embed(words).size()) # 结果将是在最后一维做拼接
-
-输出为::
-
- torch.Size([1, 5, 1536])
-
-在Bert中还存在两个特殊的字符[CLS]和[SEP],默认情况下这两个字符是自动加入并且在计算结束之后会自动删除,以使得输入的序列长度和输出的序列
-长度是一致的,但是有些分类的情况,必须需要使用[CLS]的表示,这种情况可以通过在初始化时申明一下需要保留[CLS]的表示,如下例所示
-
-.. code-block:: python
-
- embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', include_cls_sep=True)
- print(embed(words).size()) # 结果将在序列维度上增加2
- # 取出句子的cls表示
- cls_reps = embed(words)[:, 0] # shape: [batch_size, 768]
-
-输出为::
-
- torch.Size([1, 7, 768])
-
-在英文Bert模型中,一个英文单词可能会被切分为多个subword,例如"fairness"会被拆分为 ``["fair", "##ness"]`` ,这样一个word对应的将有两个输出,
-:class:`~fastNLP.embeddings.BertEmbedding` 会使用pooling方法将一个word的subword的表示合并成一个vector,通过pool_method可以控制
-该pooling方法,支持的有"first"(即使用fair的表示作为fairness的表示), "last"(使用##ness的表示作为fairness的表示), "max"(对fair和
-##ness在每一维上做max),"avg"(对fair和##ness每一维做average)。
-
-.. code-block:: python
-
- embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max')
- print(embed(words).size())
-
-输出为::
-
- torch.Size([1, 5, 768])
-
-另外,根据 `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding `_ ,
-Bert在针对具有两句话的任务时(如matching,Q&A任务),句子之间通过[SEP]拼接起来,前一句话的token embedding为0,
-后一句话的token embedding为1。BertEmbedding能够自动识别句子中间的[SEP]来正确设置对应的token_type_id的。
-
-.. code-block:: python
-
- vocab = Vocabulary()
- vocab.add_word_lst("this is a demo . [SEP] another sentence .".split())
-
- embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max')
- words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo . [SEP] another sentence .".split()]])
- print(embed(words).size())
-
-输出为::
-
- torch.Size([1, 9, 768])
-
-在多个[SEP]的情况下,将会使token_type_id不断0,1循环。比如"first sentence [SEP] second sentence [SEP] third sentence", 它们的
-token_type_id将是[0, 0, 0, 1, 1, 1, 0, 0]。但请注意[SEP]一定要大写的,不能是[sep],否则无法识别。
-
-更多 :class:`~fastNLP.embedding.BertEmbedding` 的使用,请参考 :doc:`/tutorials/extend_1_bert_embedding`
-
-
-Part VI: 使用character-level的embedding
------------------------------------------------------
-
-除了预训练的embedding以外,fastNLP还提供了两种Character Embedding: :class:`~fastNLP.embeddings.CNNCharEmbedding` 和
-:class:`~fastNLP.embeddings.LSTMCharEmbedding` 。一般在使用character embedding时,需要在预处理的时候将word拆分成character,这
-会使得预处理过程变得非常繁琐。在fastNLP中,使用character embedding也只需要传入 :class:`~fastNLP.Vocabulary` 即可,而且该
-Vocabulary与其它Embedding使用的Vocabulary是一致的,下面我们看两个例子。
-
-CNNCharEmbedding的使用例子如下:
-
-.. code-block:: python
-
- from fastNLP.embeddings import CNNCharEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary()
- vocab.add_word_lst("this is a demo .".split())
-
- # character的embedding维度大小为50,返回的embedding结果维度大小为64。
- embed = CNNCharEmbedding(vocab, embed_size=64, char_emb_size=50)
- words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]])
- print(embed(words).size())
-
-输出为::
-
- torch.Size([1, 5, 64])
-
-与CNNCharEmbedding类似,LSTMCharEmbedding的使用例子如下:
-
-.. code-block:: python
-
- from fastNLP.embeddings import LSTMCharEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary()
- vocab.add_word_lst("this is a demo .".split())
-
- # character的embedding维度大小为50,返回的embedding结果维度大小为64。
- embed = LSTMCharEmbedding(vocab, embed_size=64, char_emb_size=50)
- words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]])
- print(embed(words).size())
-
-输出为::
-
- torch.Size([1, 5, 64])
-
-
-Part VII: 叠加使用多个embedding
------------------------------------------------------
-
-单独使用Character Embedding往往效果并不是很好,需要同时结合word embedding。在fastNLP中可以通过 :class:`~fastNLP.embeddings.StackEmbedding`
-来叠加embedding,具体的例子如下所示
-
-.. code-block:: python
-
- from fastNLP.embeddings import StaticEmbedding, StackEmbedding, CNNCharEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary()
- vocab.add_word_lst("this is a demo .".split())
-
- word_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')
- char_embed = CNNCharEmbedding(vocab, embed_size=64, char_emb_size=50)
- embed = StackEmbedding([word_embed, char_embed])
-
- words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]])
- print(embed(words).size()) # 输出embedding的维度为50+64=114
-
-输出为::
-
- torch.Size([1, 5, 114])
-
-:class:`~fastNLP.embeddings.StaticEmbedding` , :class:`~fastNLP.embeddings.ElmoEmbedding` ,
-:class:`~fastNLP.embeddings.CNNCharEmbedding` , :class:`~fastNLP.embeddings.BertEmbedding` 等都可以互相拼接。
-:class:`~fastNLP.embeddings.StackEmbedding` 的使用也是和其它Embedding是一致的,即输出index返回对应的表示。但能够拼接起来的Embedding
-必须使用同样的 :class:`~fastNLP.Vocabulary` ,因为只有使用同样的 :class:`~fastNLP.Vocabulary` 才能保证同一个index指向的是同一个词或字
-
-
-
-Part VIII: Embedding的其它说明
------------------------------------------------------------
-
-(1) 获取各种Embedding的dimension
-
-.. code-block:: python
-
- from fastNLP.embeddings import *
-
- vocab = Vocabulary()
- vocab.add_word_lst("this is a demo .".split())
-
- static_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')
- print(static_embed.embedding_dim) # 50
- char_embed = CNNCharEmbedding(vocab, embed_size=30)
- print(char_embed.embedding_dim) # 30
- elmo_embed_1 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='2')
- print(elmo_embed_1.embedding_dim) # 256
- elmo_embed_2 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='1,2')
- print(elmo_embed_2.embedding_dim) # 512
- bert_embed_1 = BertEmbedding(vocab, layers='-1', model_dir_or_name='en-base-cased')
- print(bert_embed_1.embedding_dim) # 768
- bert_embed_2 = BertEmbedding(vocab, layers='2,-1', model_dir_or_name='en-base-cased')
- print(bert_embed_2.embedding_dim) # 1536
- stack_embed = StackEmbedding([static_embed, char_embed])
- print(stack_embed.embedding_dim) # 80
-
-(2) 设置Embedding的权重是否更新
-
-.. code-block:: python
-
- from fastNLP.embeddings import *
-
- vocab = Vocabulary()
- vocab.add_word_lst("this is a demo .".split())
-
- embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', requires_grad=True) # 初始化时设定为需要更新
- embed.requires_grad = False # 修改BertEmbedding的权重为不更新
-
-(3) 各种Embedding中word_dropout与dropout的说明
-
-fastNLP中所有的Embedding都支持传入word_dropout和dropout参数,word_dropout指示的是以多大概率将输入的word置为unk的index,这样既可以
-是的unk得到训练,也可以有一定的regularize效果; dropout参数是在获取到word的表示之后,以多大概率将一些维度的表示置为0。
-
-如果使用 :class:`~fastNLP.embeddings.StackEmbedding` 且需要用到word_dropout,建议将word_dropout设置在 :class:`~fastNLP.embeddings.StackEmbedding` 上。
-
-
-
-Part IX: StaticEmbedding的使用建议
------------------------------------------------------------
-
-在英文的命名实体识别(NER)任务中,由 `Named Entity Recognition with Bidirectional LSTM-CNNs `_ 指出,同时使用cnn character embedding和word embedding
-会使得NER的效果有比较大的提升。正如你在上节中看到的那样,fastNLP支持将 :class:`~fastNLP.embeddings.CNNCharEmbedding`
-与 :class:`~fastNLP.embeddings.StaticEmbedding` 拼成一个 :class:`~fastNLP.embeddings.StackEmbedding` 。如果通过这种方式使用,需要
-在预处理文本时,不要将词汇小写化(因为Character Embedding需要利用词语中的大小写信息)且不要将出现频次低于某个阈值的word设置为unk(因为
-Character embedding需要利用字形信息);但 :class:`~fastNLP.embeddings.StaticEmbedding` 使用的某些预训练词嵌入的词汇表中只有小写的词
-语, 且某些低频词并未在预训练中出现需要被剔除。即(1) character embedding需要保留大小写,而预训练词向量不需要保留大小写。(2)
-character embedding需要保留所有的字形, 而static embedding需要设置一个最低阈值以学到更好的表示。
-
-(1) fastNLP如何解决关于大小写的问题
-
-fastNLP通过在 :class:`~fastNLP.embeddings.StaticEmbedding` 增加了一个lower参数解决该问题。如下面的例子所示
-
-.. code-block:: python
-
- from fastNLP.embeddings import StaticEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary().add_word_lst("The the a A".split())
- # 下面用随机的StaticEmbedding演示,但与使用预训练词向量时效果是一致的
- embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5)
- print(embed(torch.LongTensor([vocab.to_index('The')])))
- print(embed(torch.LongTensor([vocab.to_index('the')])))
-
-输出为::
-
- tensor([[-0.4685, 0.4572, 0.5159, -0.2618, -0.6871]], grad_fn=)
- tensor([[ 0.2615, 0.1490, -0.2491, 0.4009, -0.3842]], grad_fn=)
-
-可以看到"The"与"the"的vector是不一致的。但如果我们在初始化 :class:`~fastNLP.embeddings.StaticEmbedding` 将lower设置为True,效果将
-如下所示
-
-.. code-block:: python
-
- from fastNLP.embeddings import StaticEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary().add_word_lst("The the a A".split())
- # 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的
- embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, lower=True)
- print(embed(torch.LongTensor([vocab.to_index('The')])))
- print(embed(torch.LongTensor([vocab.to_index('the')])))
-
-输出为::
-
- tensor([[-0.2237, 0.6825, -0.3459, -0.1795, 0.7516]], grad_fn=)
- tensor([[-0.2237, 0.6825, -0.3459, -0.1795, 0.7516]], grad_fn=)
-
-可以看到"The"与"the"的vector是一致的。他们实际上也是引用的同一个vector。通过将lower设置为True,可以在 :class:`~fastNLP.embeddings.StaticEmbedding`
-实现类似具备相同小写结果的词语引用同一个vector。
-
-(2) fastNLP如何解决min_freq的问题
-
-fastNLP通过在 :class:`~fastNLP.embeddings.StaticEmbedding` 增加了一个min_freq参数解决该问题。如下面的例子所示
-
-.. code-block:: python
-
- from fastNLP.embeddings import StaticEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary().add_word_lst("the the the a".split())
- # 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的
- embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, min_freq=2)
- print(embed(torch.LongTensor([vocab.to_index('the')])))
- print(embed(torch.LongTensor([vocab.to_index('a')])))
- print(embed(torch.LongTensor([vocab.unknown_idx])))
-
-输出为::
-
- tensor([[ 0.0454, 0.3375, 0.6758, -0.2026, -0.4715]], grad_fn=)
- tensor([[-0.7602, 0.0149, 0.2733, 0.3974, 0.7371]], grad_fn=)
- tensor([[-0.7602, 0.0149, 0.2733, 0.3974, 0.7371]], grad_fn=)
-
-其中最后一行为unknown值的vector,可以看到a的vector表示与unknown是一样的,这是由于a的频次低于了2,所以被指向了unknown的表示;而the由于
-词频超过了2次,所以它是单独的表示。
-
-在计算min_freq时,也会考虑到lower的作用,比如
-
-.. code-block:: python
-
- from fastNLP.embeddings import StaticEmbedding
- from fastNLP import Vocabulary
-
- vocab = Vocabulary().add_word_lst("the the the a A".split())
- # 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的
- embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, min_freq=2, lower=True)
- print(embed(torch.LongTensor([vocab.to_index('the')])))
- print(embed(torch.LongTensor([vocab.to_index('a')])))
- print(embed(torch.LongTensor([vocab.to_index('A')])))
- print(embed(torch.LongTensor([vocab.unknown_idx])))
-
-输出为::
-
- tensor([[-0.7453, -0.5542, 0.5039, 0.6195, -0.4723]], grad_fn=) # the
- tensor([[ 0.0170, -0.0995, -0.5743, -0.2469, -0.2095]], grad_fn=) # a
- tensor([[ 0.0170, -0.0995, -0.5743, -0.2469, -0.2095]], grad_fn=) # A
- tensor([[ 0.6707, -0.5786, -0.6967, 0.0111, 0.1209]], grad_fn=) # unk
-
-可以看到a不再和最后一行的unknown共享一个表示了,这是由于a与A都算入了a的词频,且A的表示也是a的表示。
-
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/tutorial_4_load_dataset.rst b/docs/source/tutorials/tutorial_4_load_dataset.rst
deleted file mode 100644
index 4fb69d1b..00000000
--- a/docs/source/tutorials/tutorial_4_load_dataset.rst
+++ /dev/null
@@ -1,219 +0,0 @@
-=======================================
-使用Loader和Pipe加载并处理数据集
-=======================================
-
-这一部分是关于如何加载数据集的教程
-
-教程目录:
-
- - `Part I: 数据集容器DataBundle`_
- - `Part II: 加载的各种数据集的Loader`_
- - `Part III: 使用Pipe对数据集进行预处理`_
- - `Part IV: fastNLP封装好的Loader和Pipe`_
- - `Part V: 不同格式类型的基础Loader`_
-
-
-Part I: 数据集容器DataBundle
-------------------------------------
-
-而由于对于同一个任务,训练集,验证集和测试集会共用同一个词表以及具有相同的目标值,所以在fastNLP中我们使用了 :class:`~fastNLP.io.DataBundle`
-来承载同一个任务的多个数据集 :class:`~fastNLP.DataSet` 以及它们的词表 :class:`~fastNLP.Vocabulary` 。下面会有例子介绍 :class:`~fastNLP.io.DataBundle`
-的相关使用。
-
-:class:`~fastNLP.io.DataBundle` 在fastNLP中主要在各个 :class:`~fastNLP.io.Loader` 和 :class:`~fastNLP.io.Pipe` 中被使用。
-下面我们先介绍一下 :class:`~fastNLP.io.Loader` 和 :class:`~fastNLP.io.Pipe` 。
-
-Part II: 加载的各种数据集的Loader
--------------------------------------
-
-在fastNLP中,所有的 :class:`~fastNLP.io.Loader` 都可以通过其文档判断其支持读取的数据格式,以及读取之后返回的 :class:`~fastNLP.DataSet` 的格式,
-例如 :class:`~fastNLP.io.ChnSentiCorpLoader` 。
-
- - **download()** 函数:自动将该数据集下载到缓存地址,默认缓存地址为~/.fastNLP/datasets/。由于版权等原因,不是所有的Loader都实现了该方法。该方法会返回下载后文件所处的缓存地址。
- - **_load()** 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.DataSet` 。返回的DataSet的格式可从Loader文档判断。
- - **load()** 函数:从文件或者文件夹中读取数据为 :class:`~fastNLP.DataSet` 并将它们组装成 :class:`~fastNLP.io.DataBundle`。支持接受的参数类型有以下的几种
-
- - None, 将尝试读取自动缓存的数据,仅支持提供了自动下载数据的Loader
- - 文件夹路径, 默认将尝试在该文件夹下匹配文件名中含有 `train` , `test` , `dev` 的文件,如果有多个文件含有相同的关键字,将无法通过该方式读取
- - dict, 例如{'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"}。
-
-.. code-block:: python
-
- from fastNLP.io import CWSLoader
-
- loader = CWSLoader(dataset_name='pku')
- data_bundle = loader.load()
- print(data_bundle)
-
-输出内容为::
-
- In total 3 datasets:
- dev has 1831 instances.
- train has 17223 instances.
- test has 1944 instances.
-
-这里表示一共有3个数据集。其中:
-
- - 3个数据集的名称分别为train、dev、test,分别有17223、1831、1944个instance
-
-也可以取出DataSet,并打印DataSet中的具体内容
-
-.. code-block:: python
-
- tr_data = data_bundle.get_dataset('train')
- print(tr_data[:2])
-
-输出为::
-
- +--------------------------------------------------------------------------------------+
- | raw_words |
- +--------------------------------------------------------------------------------------+
- | 迈向 充满 希望 的 新 世纪 —— 一九九八年 新年 讲话 ( 附 图片 1 张 ) |
- | 中共中央 总书记 、 国家 主席 江 泽民 |
- +--------------------------------------------------------------------------------------+
-
-Part III: 使用Pipe对数据集进行预处理
-------------------------------------------
-通过 :class:`~fastNLP.io.Loader` 可以将文本数据读入,但并不能直接被神经网络使用,还需要进行一定的预处理。
-
-在fastNLP中,我们使用 :class:`~fastNLP.io.Pipe` 的子类作为数据预处理的类, :class:`~fastNLP.io.Loader` 和 :class:`~fastNLP.io.Pipe` 一般具备一一对应的关系,该关系可以从其名称判断,
-例如 :class:`~fastNLP.io.CWSLoader` 与 :class:`~fastNLP.io.CWSPipe` 是一一对应的。一般情况下Pipe处理包含以下的几个过程,(1)将raw_words或
-raw_chars进行tokenize以切分成不同的词或字; (2) 再建立词或字的 :class:`~fastNLP.Vocabulary` , 并将词或字转换为index; (3)将target
-列建立词表并将target列转为index;
-
-所有的Pipe都可通过其文档查看该Pipe支持处理的 :class:`~fastNLP.DataSet` 以及返回的 :class:`~fastNLP.io.DataBundle` 中的Vocabulary的情况;
-如 :class:`~fastNLP.io.OntoNotesNERPipe`
-
-各种数据集的Pipe当中,都包含了以下的两个函数:
-
- - process() 函数:对输入的 :class:`~fastNLP.io.DataBundle` 进行处理, 然后返回处理之后的 :class:`~fastNLP.io.DataBundle` 。process函数的文档中包含了该Pipe支持处理的DataSet的格式。
- - process_from_file() 函数:输入数据集所在文件夹,使用对应的Loader读取数据(所以该函数支持的参数类型是由于其对应的Loader的load函数决定的),然后调用相对应的process函数对数据进行预处理。相当于是把Load和process放在一个函数中执行。
-
-接着上面 :class:`~fastNLP.io.CWSLoader` 的例子,我们展示一下 :class:`~fastNLP.io.CWSPipe` 的功能:
-
-.. code-block:: python
-
- from fastNLP.io import CWSPipe
-
- data_bundle = CWSPipe().process(data_bundle)
- print(data_bundle)
-
-输出内容为::
-
- In total 3 datasets:
- dev has 1831 instances.
- train has 17223 instances.
- test has 1944 instances.
- In total 2 vocabs:
- chars has 4777 entries.
- target has 4 entries.
-
-表示一共有3个数据集和2个词表。其中:
-
- - 3个数据集的名称分别为train、dev、test,分别有17223、1831、1944个instance
- - 2个词表分别为chars词表与target词表。其中chars词表为句子文本所构建的词表,一共有4777个不同的字;target词表为目标标签所构建的词表,一共有4种标签。
-
-相较于之前CWSLoader读取的DataBundle,新增了两个Vocabulary。 我们可以打印一下处理之后的DataSet
-
-.. code-block:: python
-
- tr_data = data_bundle.get_dataset('train')
- print(tr_data[:2])
-
-输出为::
-
- +---------------------------------------------------+------------------------------------+------------------------------------+---------+
- | raw_words | chars | target | seq_len |
- +---------------------------------------------------+------------------------------------+------------------------------------+---------+
- | 迈向 充满 希望 的 新 世纪 —— 一九九八年... | [1224, 178, 674, 544, 573, 435,... | [0, 1, 0, 1, 0, 1, 2, 2, 0, 1, ... | 29 |
- | 中共中央 总书记 、 国家 主席 江 泽民 | [11, 212, 11, 335, 124, 256, 10... | [0, 3, 3, 1, 0, 3, 1, 2, 0, 1, ... | 15 |
- +---------------------------------------------------+------------------------------------+------------------------------------+---------+
-
-可以看到有两列为int的field: chars和target。这两列的名称同时也是DataBundle中的Vocabulary的名称。可以通过下列的代码获取并查看Vocabulary的
-信息
-
-.. code-block:: python
-
- vocab = data_bundle.get_vocab('target')
- print(vocab)
-
-输出为::
-
- Vocabulary(['B', 'E', 'S', 'M']...)
-
-
-Part IV: fastNLP封装好的Loader和Pipe
-------------------------------------------
-
-fastNLP封装了多种任务/数据集的 :class:`~fastNLP.io.Loader` 和 :class:`~fastNLP.io.Pipe` 并提供自动下载功能,具体参见文档
-`数据集 `_
-
-
-Part V: 不同格式类型的基础Loader
---------------------------------------------------------
-
-除了上面提到的针对具体任务的Loader,我们还提供了CSV格式和JSON格式的Loader
-
-:class:`~fastNLP.io.loader.CSVLoader` 读取CSV类型的数据集文件。例子如下:
-
- .. code-block:: python
-
- from fastNLP.io.loader import CSVLoader
- data_set_loader = CSVLoader(
- headers=('raw_words', 'target'), sep='\t'
- )
- # 表示将CSV文件中每一行的第一项将填入'raw_words' field,第二项填入'target' field。
- # 其中项之间由'\t'分割开来
-
- data_set = data_set_loader._load('path/to/your/file')
-
- 文件内容样例如下 ::
-
- But it does not leave you with much . 1
- You could hate it for the same reason . 1
- The performances are an absolute joy . 4
-
- 读取之后的DataSet具有以下的field
-
- .. csv-table::
- :header: raw_words, target
-
- "But it does not leave you with much .", "1"
- "You could hate it for the same reason .", "1"
- "The performances are an absolute joy .", "4"
-
-:class:`~fastNLP.io.JsonLoader` 读取Json类型的数据集文件,数据必须按行存储,每行是一个包含各类属性的Json对象。例子如下:
-
- .. code-block:: python
-
- from fastNLP.io.loader import JsonLoader
- loader = JsonLoader(
- fields={'sentence1': 'raw_words1', 'sentence2': 'raw_words2', 'gold_label': 'target'}
- )
- # 表示将Json对象中'sentence1'、'sentence2'和'gold_label'对应的值赋给'raw_words1'、'raw_words2'、'target'这三个fields
-
- data_set = loader._load('path/to/your/file')
-
- 数据集内容样例如下 ::
-
- {"annotator_labels": ["neutral"], "captionID": "3416050480.jpg#4", "gold_label": "neutral", "pairID": "3416050480.jpg#4r1n", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is training his horse for a competition.", "sentence2_binary_parse": "( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))"}
- {"annotator_labels": ["contradiction"], "captionID": "3416050480.jpg#4", "gold_label": "contradiction", "pairID": "3416050480.jpg#4r1c", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is at a diner, ordering an omelette.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is ( at ( a diner ) ) ) , ) ( ordering ( an omelette ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (PP (IN at) (NP (DT a) (NN diner))) (, ,) (S (VP (VBG ordering) (NP (DT an) (NN omelette))))) (. .)))"}
- {"annotator_labels": ["entailment"], "captionID": "3416050480.jpg#4", "gold_label": "entailment", "pairID": "3416050480.jpg#4r1e", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is outdoors, on a horse.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .)))"}
-
- 读取之后的DataSet具有以下的field
-
- .. csv-table::
- :header: raw_words0, raw_words1, target
-
- "A person on a horse jumps over a broken down airplane.", "A person is training his horse for a competition.", "neutral"
- "A person on a horse jumps over a broken down airplane.", "A person is at a diner, ordering an omelette.", "contradiction"
- "A person on a horse jumps over a broken down airplane.", "A person is outdoors, on a horse.", "entailment"
-
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/tutorial_5_loss_optimizer.rst b/docs/source/tutorials/tutorial_5_loss_optimizer.rst
deleted file mode 100644
index 846a07a5..00000000
--- a/docs/source/tutorials/tutorial_5_loss_optimizer.rst
+++ /dev/null
@@ -1,248 +0,0 @@
-==============================================================================
-使用Trainer和Tester快速训练和测试
-==============================================================================
-
-我们使用前面介绍过的 :doc:`/tutorials/文本分类` 任务来进行详细的介绍。这里我们把数据集换成了SST2,使用 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester` 来进行快速训练和测试。
-
-.. note::
-
- 本教程中的代码没有使用 GPU 。读者可以自行修改代码,扩大数据量并使用 GPU 进行训练。
-
-数据读入和处理
------------------
-
-数据读入
- 我们可以使用 fastNLP :mod:`fastNLP.io` 模块中的 :class:`~fastNLP.io.SST2Pipe` 类,轻松地读取以及预处理SST2数据集。:class:`~fastNLP.io.SST2Pipe` 对象的
- :meth:`~fastNLP.io.SST2Pipe.process_from_file` 方法能够对读入的SST2数据集进行数据的预处理,方法的参数为paths, 指要处理的文件所在目录,如果paths为None,则会自动下载数据集,函数默认paths值为None。
- 此函数返回一个 :class:`~fastNLP.io.DataBundle`,包含SST2数据集的训练集、测试集、验证集以及source端和target端的字典。其训练、测试、验证数据集含有四个 :mod:`~fastNLP.core.field` :
-
- * raw_words: 原source句子
- * target: 标签值
- * words: index之后的raw_words
- * seq_len: 句子长度
-
- 读入数据代码如下:
-
- .. code-block:: python
-
- from fastNLP.io import SST2Pipe
-
- pipe = SST2Pipe()
- databundle = pipe.process_from_file()
- vocab = databundle.get_vocab('words')
- print(databundle)
- print(databundle.get_dataset('train')[0])
- print(databundle.get_vocab('words'))
-
-
- 输出数据如下::
-
- In total 3 datasets:
- test has 1821 instances.
- train has 67349 instances.
- dev has 872 instances.
- In total 2 vocabs:
- words has 16293 entries.
- target has 2 entries.
-
- +-------------------------------------------+--------+--------------------------------------+---------+
- | raw_words | target | words | seq_len |
- +-------------------------------------------+--------+--------------------------------------+---------+
- | hide new secretions from the parental ... | 1 | [4111, 98, 12010, 38, 2, 6844, 9042] | 7 |
- +-------------------------------------------+--------+--------------------------------------+---------+
-
- Vocabulary(['hide', 'new', 'secretions', 'from', 'the']...)
-
- 除了可以对数据进行读入的Pipe类,fastNLP还提供了读入和下载数据的Loader类,不同数据集的Pipe和Loader及其用法详见 :doc:`/tutorials/tutorial_4_load_dataset` 。
-
-数据集分割
- 由于SST2数据集的测试集并不带有标签数值,故我们分割出一部分训练集作为测试集。下面这段代码展示了 :meth:`~fastNLP.DataSet.split` 的使用方法,
- 为了能让读者快速运行完整个教程,我们只取了训练集的前5000个数据。
-
- .. code-block:: python
-
- train_data = databundle.get_dataset('train')[:5000]
- train_data, test_data = train_data.split(0.015)
- dev_data = databundle.get_dataset('dev')
- print(len(train_data),len(dev_data),len(test_data))
-
- 输出结果为::
-
- 4925 872 75
-
-数据集 :meth:`~fastNLP.DataSet.set_input` 和 :meth:`~fastNLP.DataSet.set_target` 函数
- :class:`~fastNLP.io.SST2Pipe` 类的 :meth:`~fastNLP.io.SST2Pipe.process_from_file` 方法在预处理过程中还将训练、测试、验证
- 集的 `words` 、`seq_len` :mod:`~fastNLP.core.field` 设定为input,同时将 `target` :mod:`~fastNLP.core.field` 设定
- 为target。我们可以通过 :class:`~fastNLP.core.Dataset` 类的 :meth:`~fastNLP.core.Dataset.print_field_meta` 方法查看各个 :mod:`~fastNLP.core.field` 的设定情况,代码如下:
-
- .. code-block:: python
-
- train_data.print_field_meta()
-
- 输出结果为::
-
- +-------------+-----------+--------+-------+---------+
- | field_names | raw_words | target | words | seq_len |
- +-------------+-----------+--------+-------+---------+
- | is_input | False | False | True | True |
- | is_target | False | True | False | False |
- | ignore_type | | False | False | False |
- | pad_value | | 0 | 0 | 0 |
- +-------------+-----------+--------+-------+---------+
-
- 其中is_input和is_target分别表示是否为input和target。ignore_type为true时指使用 :class:`~fastNLP.DataSetIter` 取出batch数
- 据时fastNLP不会进行自动padding,pad_value指对应 :mod:`~fastNLP.core.field` padding所用的值,这两者只有
- 当 :mod:`~fastNLP.core.field` 设定为input或者target的时候才有存在的意义。
-
- is_input为true的 :mod:`~fastNLP.core.field` 在 :class:`~fastNLP.DataSetIter` 迭代取出的batch_x 中,而is_target为true
- 的 :mod:`~fastNLP.core.field` 在 :class:`~fastNLP.DataSetIter` 迭代取出的 batch_y 中。
- 具体分析见 :doc:`使用DataSetIter实现自定义训练过程 ` 。
-
-使用内置模型训练
----------------------
-模型定义和初始化
- 我们可以导入 fastNLP 内置的文本分类模型 :class:`~fastNLP.models.CNNText` 来对模型进行定义,代码如下:
-
- .. code-block:: python
-
- from fastNLP.models import CNNText
-
- #词嵌入的维度
- EMBED_DIM = 100
-
- #使用CNNText的时候第一个参数输入一个tuple,作为模型定义embedding的参数
- #还可以传入 kernel_nums, kernel_sizes, padding, dropout的自定义值
- model_cnn = CNNText((len(vocab),EMBED_DIM), num_classes=2, dropout=0.1)
-
- 使用fastNLP快速搭建自己的模型详见 :doc:`/tutorials/tutorial_8_modules_models` 。
-
-评价指标
- 训练模型需要提供一个评价指标。这里使用准确率做为评价指标。
-
- * ``pred`` 参数对应的是模型的 forward 方法返回的 dict 中的一个 key 的名字。
- * ``target`` 参数对应的是 :class:`~fastNLP.DataSet` 中作为标签的 :mod:`~fastNLP.core.field` 的名字。
-
- 这里我们用 :class:`~fastNLP.Const` 来辅助命名,如果你自己编写模型中 forward 方法的返回值或
- 数据集中 :mod:`~fastNLP.core.field` 的名字与本例不同, 你可以把 ``pred`` 参数和 ``target`` 参数设定符合自己代码的值。代码如下:
-
- .. code-block:: python
-
- from fastNLP import AccuracyMetric
- from fastNLP import Const
-
- # metrics=AccuracyMetric() 在本例中与下面这行代码等价
- metrics=AccuracyMetric(pred=Const.OUTPUT, target=Const.TARGET)
-
-
-损失函数
- 训练模型需要提供一个损失函数
- ,fastNLP中提供了直接可以导入使用的四种loss,分别为:
-
- * :class:`~fastNLP.CrossEntropyLoss`:包装了torch.nn.functional.cross_entropy()函数,返回交叉熵损失(可以运用于多分类场景)
- * :class:`~fastNLP.BCELoss`:包装了torch.nn.functional.binary_cross_entropy()函数,返回二分类的交叉熵
- * :class:`~fastNLP.L1Loss`:包装了torch.nn.functional.l1_loss()函数,返回L1 损失
- * :class:`~fastNLP.NLLLoss`:包装了torch.nn.functional.nll_loss()函数,返回负对数似然损失
-
- 下面提供了一个在分类问题中常用的交叉熵损失。注意它的 **初始化参数** 。
-
- * ``pred`` 参数对应的是模型的 forward 方法返回的 dict 中的一个 key 的名字。
- * ``target`` 参数对应的是 :class:`~fastNLP.DataSet` 中作为标签的 :mod:`~fastNLP.core.field` 的名字。
-
- 这里我们用 :class:`~fastNLP.Const` 来辅助命名,如果你自己编写模型中 forward 方法的返回值或
- 数据集中 :mod:`~fastNLP.core.field` 的名字与本例不同, 你可以把 ``pred`` 参数和 ``target`` 参数设定符合自己代码的值。
-
- .. code-block:: python
-
- from fastNLP import CrossEntropyLoss
-
- # loss = CrossEntropyLoss() 在本例中与下面这行代码等价
- loss = CrossEntropyLoss(pred=Const.OUTPUT, target=Const.TARGET)
-
- 除了使用fastNLP已经包装好的了损失函数,也可以通过fastNLP中的LossFunc类来构建自己的损失函数,方法如下:
-
- .. code-block:: python
-
- # 这表示构建了一个损失函数类,由func计算损失函数,其中将从模型返回值或者DataSet的target=True的field
- # 当中找到一个参数名为`pred`的参数传入func一个参数名为`input`的参数;找到一个参数名为`label`的参数
- # 传入func作为一个名为`target`的参数
- #下面自己构建了一个交叉熵函数,和之后直接使用fastNLP中的交叉熵函数是一个效果
- import torch
- from fastNLP import LossFunc
- func = torch.nn.functional.cross_entropy
- loss_func = LossFunc(func, input=Const.OUTPUT, target=Const.TARGET)
-
-优化器
- 定义模型运行的时候使用的优化器,可以直接使用torch.optim.Optimizer中的优化器,并在实例化 :class:`~fastNLP.Trainer` 类的时候传入优化器实参
-
- .. code-block:: python
-
- import torch.optim as optim
-
- #使用 torch.optim 定义优化器
- optimizer=optim.RMSprop(model_cnn.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
-
-快速训练
- 现在我们对上面定义的模型使用 :class:`~fastNLP.Trainer` 进行训练。
- 除了使用 :class:`~fastNLP.Trainer`进行训练,我们也可以通过使用 :class:`~fastNLP.DataSetIter` 来编写自己的训练过程,具体见 :doc:`/tutorials/tutorial_6_datasetiter`
-
- .. code-block:: python
-
- from fastNLP import Trainer
-
- #训练的轮数和batch size
- N_EPOCHS = 10
- BATCH_SIZE = 16
-
- #如果在定义trainer的时候没有传入optimizer参数,模型默认的优化器为torch.optim.Adam且learning rate为lr=4e-3
- #这里只使用了loss作为损失函数输入,感兴趣可以尝试其他损失函数(如之前自定义的loss_func)作为输入
- trainer = Trainer(model=model_cnn, train_data=train_data, dev_data=dev_data, loss=loss, metrics=metrics,
- optimizer=optimizer,n_epochs=N_EPOCHS, batch_size=BATCH_SIZE)
- trainer.train()
-
- 训练过程的输出如下::
-
- input fields after batch(if batch size is 2):
- words: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13])
- seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- target fields after batch(if batch size is 2):
- target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
-
- training epochs started 2020-02-26-16-45-40
- Evaluate data in 0.5 seconds!
- Evaluation on dev at Epoch 1/10. Step:308/3080:
- AccuracyMetric: acc=0.677752
-
- ......
-
- Evaluate data in 0.44 seconds!
- Evaluation on dev at Epoch 10/10. Step:3080/3080:
- AccuracyMetric: acc=0.725917
-
-
- In Epoch:5/Step:1540, got best dev performance:
- AccuracyMetric: acc=0.740826
- Reloaded the best model.
-
-快速测试
- 与 :class:`~fastNLP.Trainer` 对应,fastNLP 也提供了 :class:`~fastNLP.Tester` 用于快速测试,用法如下
-
- .. code-block:: python
-
- from fastNLP import Tester
-
- tester = Tester(test_data, model_cnn, metrics=AccuracyMetric())
- tester.test()
-
- 训练过程输出如下::
-
- Evaluate data in 0.43 seconds!
- [tester]
- AccuracyMetric: acc=0.773333
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/tutorial_6_datasetiter.rst b/docs/source/tutorials/tutorial_6_datasetiter.rst
deleted file mode 100644
index eab14301..00000000
--- a/docs/source/tutorials/tutorial_6_datasetiter.rst
+++ /dev/null
@@ -1,423 +0,0 @@
-==============================================================================
-使用DataSetIter实现自定义训练过程
-==============================================================================
-
-我们使用前面介绍过的 :doc:`/tutorials/文本分类` 任务来进行详细的介绍。这里我们把数据集换成了SST2,使用 :class:`~fastNLP.DataSetIter` 类来编写自己的训练过程。
-DataSetIter初探之前的内容与 :doc:`/tutorials/tutorial_5_loss_optimizer` 中的完全一样,如已经阅读过可以跳过。
-
-.. note::
-
- 本教程中的代码没有使用 GPU 。读者可以自行修改代码,扩大数据量并使用 GPU 进行训练。
-
-数据读入和预处理
---------------------
-
-数据读入
- 我们可以使用 fastNLP :mod:`fastNLP.io` 模块中的 :class:`~fastNLP.io.SST2Pipe` 类,轻松地读取以及预处理SST2数据集。:class:`~fastNLP.io.SST2Pipe` 对象的
- :meth:`~fastNLP.io.SST2Pipe.process_from_file` 方法能够对读入的SST2数据集进行数据的预处理,方法的参数为paths, 指要处理的文件所在目录,如果paths为None,则会自动下载数 据集,函数默认paths值为None。
- 此函数返回一个 :class:`~fastNLP.io.DataBundle`,包含SST2数据集的训练集、测试集、验证集以及source端和target端的字典。其训练、测试、验证数据集含有四个 :mod:`~fastNLP.core.field` :
-
- * raw_words: 原source句子
- * target: 标签值
- * words: index之后的raw_words
- * seq_len: 句子长度
-
- 读入数据代码如下:
-
- .. code-block:: python
-
- from fastNLP.io import SST2Pipe
-
- pipe = SST2Pipe()
- databundle = pipe.process_from_file()
- vocab = databundle.get_vocab('words')
- print(databundle)
- print(databundle.get_dataset('train')[0])
- print(databundle.get_vocab('words'))
-
-
- 输出数据如下::
-
- In total 3 datasets:
- test has 1821 instances.
- train has 67349 instances.
- dev has 872 instances.
- In total 2 vocabs:
- words has 16293 entries.
- target has 2 entries.
-
- +-------------------------------------------+--------+--------------------------------------+---------+
- | raw_words | target | words | seq_len |
- +-------------------------------------------+--------+--------------------------------------+---------+
- | hide new secretions from the parental ... | 1 | [4111, 98, 12010, 38, 2, 6844, 9042] | 7 |
- +-------------------------------------------+--------+--------------------------------------+---------+
-
- Vocabulary(['hide', 'new', 'secretions', 'from', 'the']...)
-
- 除了可以对数据进行读入的Pipe类,fastNLP还提供了读入和下载数据的Loader类,不同数据集的Pipe和Loader及其用法详见 :doc:`/tutorials/tutorial_4_load_dataset` 。
-
-数据集分割
- 由于SST2数据集的测试集并不带有标签数值,故我们分割出一部分训练集作为测试集。下面这段代码展示了 :meth:`~fastNLP.DataSet.split` 的使用方法,
- 为了能让读者快速运行完整个教程,我们只取了训练集的前5000个数据。
-
- .. code-block:: python
-
- train_data = databundle.get_dataset('train')[:5000]
- train_data, test_data = train_data.split(0.015)
- dev_data = databundle.get_dataset('dev')
- print(len(train_data),len(dev_data),len(test_data))
-
- 输出结果为::
-
- 4925 872 75
-
-数据集 :meth:`~fastNLP.DataSet.set_input` 和 :meth:`~fastNLP.DataSet.set_target` 函数
- :class:`~fastNLP.io.SST2Pipe` 类的 :meth:`~fastNLP.io.SST2Pipe.process_from_file` 方法在预处理过程中还将训练、测试、验证集
- 的 `words` 、`seq_len` :mod:`~fastNLP.core.field` 设定为input,同时将`target` :mod:`~fastNLP.core.field` 设定为target。
- 我们可以通过 :class:`~fastNLP.core.Dataset` 类的 :meth:`~fastNLP.core.Dataset.print_field_meta` 方法查看各个
- :mod:`~fastNLP.core.field` 的设定情况,代码如下:
-
- .. code-block:: python
-
- train_data.print_field_meta()
-
- 输出结果为::
-
- +-------------+-----------+--------+-------+---------+
- | field_names | raw_words | target | words | seq_len |
- +-------------+-----------+--------+-------+---------+
- | is_input | False | False | True | True |
- | is_target | False | True | False | False |
- | ignore_type | | False | False | False |
- | pad_value | | 0 | 0 | 0 |
- +-------------+-----------+--------+-------+---------+
-
- 其中is_input和is_target分别表示是否为input和target。ignore_type为true时指使用 :class:`~fastNLP.DataSetIter` 取出batch数
- 据时fastNLP不会进行自动padding,pad_value指对应 :mod:`~fastNLP.core.field` padding所用的值,这两者只有当
- :mod:`~fastNLP.core.field` 设定为input或者target的时候才有存在的意义。
-
- is_input为true的 :mod:`~fastNLP.core.field` 在 :class:`~fastNLP.DataSetIter` 迭代取出的 batch_x 中,
- 而 is_target为true的 :mod:`~fastNLP.core.field` 在 :class:`~fastNLP.DataSetIter` 迭代取出的 batch_y 中。
- 具体分析见下面DataSetIter的介绍过程。
-
-
-评价指标
- 训练模型需要提供一个评价指标。这里使用准确率做为评价指标。
-
- * ``pred`` 参数对应的是模型的 forward 方法返回的 dict 中的一个 key 的名字。
- * ``target`` 参数对应的是 :class:`~fastNLP.DataSet` 中作为标签的 :mod:`~fastNLP.core.field` 的名字。
-
- 这里我们用 :class:`~fastNLP.Const` 来辅助命名,如果你自己编写模型中 forward 方法的返回值或
- 数据集中 :mod:`~fastNLP.core.field` 的名字与本例不同, 你可以把 ``pred`` 参数和 ``target`` 参数设定符合自己代码的值。代码如下:
-
- .. code-block:: python
-
- from fastNLP import AccuracyMetric
- from fastNLP import Const
-
- # metrics=AccuracyMetric() 在本例中与下面这行代码等价
- metrics=AccuracyMetric(pred=Const.OUTPUT, target=Const.TARGET)
-
-
-DataSetIter初探
---------------------------
-
-DataSetIter
- fastNLP定义的 :class:`~fastNLP.DataSetIter` 类,用于定义一个batch,并实现batch的多种功能,在初始化时传入的参数有:
-
- * dataset: :class:`~fastNLP.DataSet` 对象, 数据集
- * batch_size: 取出的batch大小
- * sampler: 规定使用的 :class:`~fastNLP.Sampler` 若为 None, 使用 :class:`~fastNLP.RandomSampler` (Default: None)
- * as_numpy: 若为 True, 输出batch为 `numpy.array`. 否则为 `torch.Tensor` (Default: False)
- * prefetch: 若为 True使用多进程预先取出下一batch. (Default: False)
-
-sampler
- fastNLP 实现的采样器有:
-
- * :class:`~fastNLP.BucketSampler` 可以随机地取出长度相似的元素 【初始化参数: num_buckets:bucket的数量; batch_size:batch大小; seq_len_field_name:dataset中对应序列长度的 :mod:`~fastNLP.core.field` 的名字】
- * SequentialSampler: 顺序取出元素的采样器【无初始化参数】
- * RandomSampler:随机化取元素的采样器【无初始化参数】
-
-Padder
- 在fastNLP里,pad是与一个 :mod:`~fastNLP.core.field` 绑定的。即不同的 :mod:`~fastNLP.core.field` 可以使用不同的pad方式,比如在英文任务中word需要的pad和
- character的pad方式往往是不同的。fastNLP是通过一个叫做 :class:`~fastNLP.Padder` 的子类来完成的。
- 默认情况下,所有field使用 :class:`~fastNLP.AutoPadder`
- 。大多数情况下直接使用 :class:`~fastNLP.AutoPadder` 就可以了。
- 如果 :class:`~fastNLP.AutoPadder` 或 :class:`~fastNLP.EngChar2DPadder` 无法满足需求,
- 也可以自己写一个 :class:`~fastNLP.Padder` 。
-
-DataSetIter自动padding
- 以下代码展示了DataSetIter的简单使用:
-
- .. code-block:: python
-
- from fastNLP import BucketSampler
- from fastNLP import DataSetIter
-
- tmp_data = dev_data[:10]
- # 定义一个Batch,传入DataSet,规定batch_size和去batch的规则。
- # 顺序(Sequential),随机(Random),相似长度组成一个batch(Bucket)
- sampler = BucketSampler(batch_size=2, seq_len_field_name='seq_len')
- batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)
- for batch_x, batch_y in batch:
- print("batch_x: ",batch_x)
- print("batch_y: ", batch_y)
-
- 输出结果如下::
-
- batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,
- 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,
- 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,
- 1323, 4398, 7],
- [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,
- 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,
- 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0]]), 'seq_len': tensor([33, 21])}
- batch_y: {'target': tensor([1, 0])}
- batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7],
- [ 14, 10, 437, 32, 78, 3, 78, 437, 7]]), 'seq_len': tensor([9, 9])}
- batch_y: {'target': tensor([0, 1])}
- batch_x: {'words': tensor([[ 4, 277, 685, 18, 7],
- [15618, 3204, 5, 1675, 0]]), 'seq_len': tensor([5, 4])}
- batch_y: {'target': tensor([1, 1])}
- batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,
- 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7],
- [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,
- 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7]]), 'seq_len': tensor([20, 20])}
- batch_y: {'target': tensor([0, 0])}
- batch_x: {'words': tensor([[ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,
- 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7],
- [ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,
- 1217, 7, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([20, 12])}
- batch_y: {'target': tensor([0, 1])}
-
- 可以看到那些设定为input的 :mod:`~fastNLP.core.field` 都出现在batch_x中,而设定为target的 :mod:`~fastNLP.core.field` 则出现在batch_y中。同时对于同一个batch_x中的两个数据,长度偏短的那个会被自动padding到和长度偏长的句子长度一致,默认的padding值为0。
-
-Dataset改变padding值
- 可以通过 :meth:`~fastNLP.core.Dataset.set_pad_val` 方法修改默认的pad值,代码如下:
-
- .. code-block:: python
-
- tmp_data.set_pad_val('words',-1)
- batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)
- for batch_x, batch_y in batch:
- print("batch_x: ",batch_x)
- print("batch_y: ", batch_y)
-
- 输出结果如下::
-
- batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,
- 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,
- 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,
- 1323, 4398, 7],
- [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,
- 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,
- 7, -1, -1, -1, -1, -1, -1, -1, -1, -1,
- -1, -1, -1]]), 'seq_len': tensor([33, 21])}
- batch_y: {'target': tensor([1, 0])}
- batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7],
- [ 14, 10, 437, 32, 78, 3, 78, 437, 7]]), 'seq_len': tensor([9, 9])}
- batch_y: {'target': tensor([0, 1])}
- batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,
- 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7],
- [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,
- 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7]]), 'seq_len': tensor([20, 20])}
- batch_y: {'target': tensor([0, 0])}
- batch_x: {'words': tensor([[ 4, 277, 685, 18, 7],
- [15618, 3204, 5, 1675, -1]]), 'seq_len': tensor([5, 4])}
- batch_y: {'target': tensor([1, 1])}
- batch_x: {'words': tensor([[ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,
- 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7],
- [ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,
- 1217, 7, -1, -1, -1, -1, -1, -1, -1, -1]]), 'seq_len': tensor([20, 12])}
- batch_y: {'target': tensor([0, 1])}
-
- 可以看到使用了-1进行padding。
-
-Dataset个性化padding
- 如果我们希望对某一些 :mod:`~fastNLP.core.field` 进行个性化padding,可以自己构造Padder类,并使用 :meth:`~fastNLP.core.Dataset.set_padder` 函数修改padder来实现。下面通过构造一个将数据padding到固定长度的padder进行展示:
-
- .. code-block:: python
-
- from fastNLP.core.field import Padder
- import numpy as np
- class FixLengthPadder(Padder):
- def __init__(self, pad_val=0, length=None):
- super().__init__(pad_val=pad_val)
- self.length = length
- assert self.length is not None, "Creating FixLengthPadder with no specific length!"
-
- def __call__(self, contents, field_name, field_ele_dtype, dim):
- #计算当前contents中的最大长度
- max_len = max(map(len, contents))
- #如果当前contents中的最大长度大于指定的padder length的话就报错
- assert max_len <= self.length, "Fixed padder length smaller than actual length! with length {}".format(max_len)
- array = np.full((len(contents), self.length), self.pad_val, dtype=field_ele_dtype)
- for i, content_i in enumerate(contents):
- array[i, :len(content_i)] = content_i
- return array
-
- #设定FixLengthPadder的固定长度为40
- tmp_padder = FixLengthPadder(pad_val=0,length=40)
- #利用dataset的set_padder函数设定words field的padder
- tmp_data.set_padder('words',tmp_padder)
- batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)
- for batch_x, batch_y in batch:
- print("batch_x: ",batch_x)
- print("batch_y: ", batch_y)
-
- 输出结果如下::
-
- batch_x: {'words': tensor([[ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,
- 1217, 7, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,
- 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([12, 20])}
- batch_y: {'target': tensor([1, 0])}
- batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,
- 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,
- 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,
- 1323, 4398, 7, 0, 0, 0, 0, 0, 0, 0],
- [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,
- 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,
- 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([33, 21])}
- batch_y: {'target': tensor([1, 0])}
- batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0],
- [ 14, 10, 437, 32, 78, 3, 78, 437, 7, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0]]), 'seq_len': tensor([9, 9])}
- batch_y: {'target': tensor([0, 1])}
- batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,
- 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,
- 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([20, 20])}
- batch_y: {'target': tensor([0, 0])}
- batch_x: {'words': tensor([[ 4, 277, 685, 18, 7, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [15618, 3204, 5, 1675, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([5, 4])}
- batch_y: {'target': tensor([1, 1])}
-
- 在这里所有的 `words` 都被pad成了长度为40的list。
-
-
-使用DataSetIter自己编写训练过程
-------------------------------------
- 如果你想用类似 PyTorch 的使用方法,自己编写训练过程,可以参考下面这段代码。
- 其中使用了 fastNLP 提供的 :class:`~fastNLP.DataSetIter` 来获得小批量训练的小批量数据,
- 使用 :class:`~fastNLP.BucketSampler` 做为 :class:`~fastNLP.DataSetIter` 的参数来选择采样的方式。
-
- 以下代码使用BucketSampler作为 :class:`~fastNLP.DataSetIter` 初始化的输入,运用 :class:`~fastNLP.DataSetIter` 自己写训练程序
-
- .. code-block:: python
-
- from fastNLP import BucketSampler
- from fastNLP import DataSetIter
- from fastNLP.models import CNNText
- from fastNLP import Tester
- import torch
- import time
-
- embed_dim = 100
- model = CNNText((len(vocab),embed_dim), num_classes=2, dropout=0.1)
-
- def train(epoch, data, devdata):
- optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
- lossfunc = torch.nn.CrossEntropyLoss()
- batch_size = 32
-
- # 定义一个Batch,传入DataSet,规定batch_size和去batch的规则。
- # 顺序(Sequential),随机(Random),相似长度组成一个batch(Bucket)
- train_sampler = BucketSampler(batch_size=batch_size, seq_len_field_name='seq_len')
- train_batch = DataSetIter(batch_size=batch_size, dataset=data, sampler=train_sampler)
-
- start_time = time.time()
- print("-"*5+"start training"+"-"*5)
- for i in range(epoch):
- loss_list = []
- for batch_x, batch_y in train_batch:
- optimizer.zero_grad()
- output = model(batch_x['words'])
- loss = lossfunc(output['pred'], batch_y['target'])
- loss.backward()
- optimizer.step()
- loss_list.append(loss.item())
-
- #这里verbose如果为0,在调用Tester对象的test()函数时不输出任何信息,返回评估信息; 如果为1,打印出验证结果,返回评估信息
- #在调用过Tester对象的test()函数后,调用其_format_eval_results(res)函数,结构化输出验证结果
- tester_tmp = Tester(devdata, model, metrics=AccuracyMetric(), verbose=0)
- res=tester_tmp.test()
-
- print('Epoch {:d} Avg Loss: {:.2f}'.format(i, sum(loss_list) / len(loss_list)),end=" ")
- print(tester_tmp._format_eval_results(res),end=" ")
- print('{:d}ms'.format(round((time.time()-start_time)*1000)))
- loss_list.clear()
-
- train(10, train_data, dev_data)
- #使用tester进行快速测试
- tester = Tester(test_data, model, metrics=AccuracyMetric())
- tester.test()
-
- 这段代码的输出如下::
-
- -----start training-----
-
- Evaluate data in 2.68 seconds!
- Epoch 0 Avg Loss: 0.66 AccuracyMetric: acc=0.708716 29307ms
-
- Evaluate data in 0.38 seconds!
- Epoch 1 Avg Loss: 0.41 AccuracyMetric: acc=0.770642 52200ms
-
- Evaluate data in 0.51 seconds!
- Epoch 2 Avg Loss: 0.16 AccuracyMetric: acc=0.747706 70268ms
-
- Evaluate data in 0.96 seconds!
- Epoch 3 Avg Loss: 0.06 AccuracyMetric: acc=0.741972 90349ms
-
- Evaluate data in 1.04 seconds!
- Epoch 4 Avg Loss: 0.03 AccuracyMetric: acc=0.740826 114250ms
-
- Evaluate data in 0.8 seconds!
- Epoch 5 Avg Loss: 0.02 AccuracyMetric: acc=0.738532 134742ms
-
- Evaluate data in 0.65 seconds!
- Epoch 6 Avg Loss: 0.01 AccuracyMetric: acc=0.731651 154503ms
-
- Evaluate data in 0.8 seconds!
- Epoch 7 Avg Loss: 0.01 AccuracyMetric: acc=0.738532 175397ms
-
- Evaluate data in 0.36 seconds!
- Epoch 8 Avg Loss: 0.01 AccuracyMetric: acc=0.733945 192384ms
-
- Evaluate data in 0.84 seconds!
- Epoch 9 Avg Loss: 0.01 AccuracyMetric: acc=0.744266 214417ms
-
- Evaluate data in 0.04 seconds!
- [tester]
- AccuracyMetric: acc=0.786667
-
-
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/tutorial_7_metrics.rst b/docs/source/tutorials/tutorial_7_metrics.rst
deleted file mode 100644
index 5ab09c24..00000000
--- a/docs/source/tutorials/tutorial_7_metrics.rst
+++ /dev/null
@@ -1,135 +0,0 @@
-===============================
-使用Metric快速评测你的模型
-===============================
-
-在进行训练时,fastNLP提供了各种各样的 :mod:`~fastNLP.core.metrics` 。
-如前面的教程中所介绍,:class:`~fastNLP.AccuracyMetric` 类的对象被直接传到 :class:`~fastNLP.Trainer` 中用于训练
-
-.. code-block:: python
-
- trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,
- loss=loss, device=device, metrics=metric)
- trainer.train()
-
-除了 :class:`~fastNLP.AccuracyMetric` 之外,:class:`~fastNLP.SpanFPreRecMetric` 也是一种非常见的评价指标,
-例如在序列标注问题中,常以span的方式计算 F-measure, precision, recall。
-
-另外,fastNLP 还实现了用于抽取式QA(如SQuAD)的metric :class:`~fastNLP.ExtractiveQAMetric`。
-用户可以参考下面这个表格,点击第一列查看各个 :mod:`~fastNLP.core.metrics` 的详细文档。
-
-.. csv-table::
- :header: 名称, 介绍
-
- :class:`~fastNLP.core.metrics.MetricBase` , 自定义metrics需继承的基类
- :class:`~fastNLP.core.metrics.AccuracyMetric` , 简单的正确率metric
- :class:`~fastNLP.core.metrics.SpanFPreRecMetric` , "同时计算 F-measure, precision, recall 值的 metric"
- :class:`~fastNLP.core.metrics.ExtractiveQAMetric` , 用于抽取式QA任务 的metric
-
-更多的 :mod:`~fastNLP.core.metrics` 正在被添加到 fastNLP 当中,敬请期待。
-
-------------------------------
-定义自己的metrics
-------------------------------
-
-在定义自己的metrics类时需继承 fastNLP 的 :class:`~fastNLP.core.metrics.MetricBase`,
-并覆盖写入 ``evaluate`` 和 ``get_metric`` 方法。
-
- evaluate(xxx) 中传入一个批次的数据,将针对一个批次的预测结果做评价指标的累计
-
- get_metric(xxx) 当所有数据处理完毕时调用该方法,它将根据 evaluate函数累计的评价指标统计量来计算最终的评价结果
-
-以分类问题中,accuracy 计算为例,假设 model 的 `forward` 返回 dict 中包含 `pred` 这个 key , 并且该 key 需要用于 accuracy::
-
- class Model(nn.Module):
- def __init__(xxx):
- # do something
- def forward(self, xxx):
- # do something
- return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes
-
-假设dataset中 `target` 这个 field 是需要预测的值,并且该 field 被设置为了 target 对应的 `AccMetric` 可以按如下的定义( Version 1, 只使用这一次)::
-
- from fastNLP import MetricBase
-
- class AccMetric(MetricBase):
-
- def __init__(self):
- super().__init__()
- # 根据你的情况自定义指标
- self.total = 0
- self.acc_count = 0
-
- # evaluate的参数需要和DataSet 中 field 名以及模型输出的结果 field 名一致,不然找不到对应的value
- # pred, target 的参数是 fastNLP 的默认配置
- def evaluate(self, pred, target):
- # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric
- self.total += target.size(0)
- self.acc_count += target.eq(pred).sum().item()
-
- def get_metric(self, reset=True): # 在这里定义如何计算metric
- acc = self.acc_count/self.total
- if reset: # 是否清零以便重新计算
- self.acc_count = 0
- self.total = 0
- return {'acc': acc}
- # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中
-
-
-如果需要复用 metric,比如下一次使用 `AccMetric` 时,dataset中目标field不叫 `target` 而叫 `y` ,或者model的输出不是 `pred` (Version 2)::
-
- class AccMetric(MetricBase):
- def __init__(self, pred=None, target=None):
- """
- 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时,
- acc_metric = AccMetric(pred='pred_y', target='y')即可。
- 当初始化为acc_metric = AccMetric() 时,fastNLP会直接使用 'pred', 'target' 作为key去索取对应的的值
- """
-
- super().__init__()
-
- # 如果没有注册该则效果与 Version 1 就是一样的
- self._init_param_map(pred=pred, target=target) # 该方法会注册 pred 和 target . 仅需要注册evaluate()方法会用到的参数名即可
-
- # 根据你的情况自定义指标
- self.total = 0
- self.acc_count = 0
-
- # evaluate的参数需要和DataSet 中 field 名以及模型输出的结果 field 名一致,不然找不到对应的value
- # pred, target 的参数是 fastNLP 的默认配置
- def evaluate(self, pred, target):
- # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric
- self.total += target.size(0)
- self.acc_count += target.eq(pred).sum().item()
-
- def get_metric(self, reset=True): # 在这里定义如何计算metric
- acc = self.acc_count/self.total
- if reset: # 是否清零以便重新计算
- self.acc_count = 0
- self.total = 0
- return {'acc': acc}
- # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中
-
-``MetricBase`` 将会在输入的字典 ``pred_dict`` 和 ``target_dict`` 中进行检查.
-``pred_dict`` 是模型当中 ``forward()`` 函数或者 ``predict()`` 函数的返回值.
-``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的 ``is_target`` 被设置为True.
-
-``MetricBase`` 会进行以下的类型检测:
-
-1. self.evaluate当中是否有 varargs, 这是不支持的.
-2. self.evaluate当中所需要的参数是否既不在 ``pred_dict`` 也不在 ``target_dict`` .
-3. self.evaluate当中所需要的参数是否既在 ``pred_dict`` 也在 ``target_dict`` .
-
-除此以外,在参数被传入self.evaluate以前,这个函数会检测 ``pred_dict`` 和 ``target_dict`` 当中没有被用到的参数
-如果kwargs是self.evaluate的参数,则不会检测
-
-self.evaluate将计算一个批次(batch)的评价指标,并累计。 没有返回值
-self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值
-
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/tutorial_8_modules_models.rst b/docs/source/tutorials/tutorial_8_modules_models.rst
deleted file mode 100644
index 226c3be5..00000000
--- a/docs/source/tutorials/tutorial_8_modules_models.rst
+++ /dev/null
@@ -1,193 +0,0 @@
-======================================
-使用Modules和Models快速搭建自定义模型
-======================================
-
-:mod:`~fastNLP.modules` 和 :mod:`~fastNLP.models` 用于构建 fastNLP 所需的神经网络模型,它可以和 torch.nn 中的模型一起使用。
-下面我们会分三节介绍编写构建模型的具体方法。
-
-
-使用 models 中的模型
-----------------------
-
-fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models.CNNText` 、
-:class:`~fastNLP.models.SeqLabeling` 等完整的模型,以供用户直接使用。
-以文本分类的任务为例,我们从 models 中导入 :class:`~fastNLP.models.CNNText` 模型,用它进行训练。
-
-.. code-block:: python
-
- from fastNLP.models import CNNText
-
- model_cnn = CNNText((len(vocab),100), num_classes=2, dropout=0.1)
-
- trainer = Trainer(train_data=train_data, dev_data=dev_data, metrics=metric,
- loss=loss, device=device, model=model_cnn)
- trainer.train()
-
-在 iPython 环境输入 `model_cnn` ,我们可以看到 ``model_cnn`` 的网络结构
-
-.. parsed-literal::
-
- CNNText(
- (embed): Embedding(
- (embed): Embedding(16292, 100)
- (dropout): Dropout(p=0.0, inplace=False)
- )
- (conv_pool): ConvMaxpool(
- (convs): ModuleList(
- (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)
- (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
- (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
- )
- )
- (dropout): Dropout(p=0.1, inplace=False)
- (fc): Linear(in_features=120, out_features=2, bias=True)
- )
-
-FastNLP 中内置的 models 如下表所示,您可以点击具体的名称查看详细的 API:
-
-.. csv-table::
- :header: 名称, 介绍
-
- :class:`~fastNLP.models.CNNText` , 使用 CNN 进行文本分类的模型
- :class:`~fastNLP.models.SeqLabeling` , 简单的序列标注模型
- :class:`~fastNLP.models.AdvSeqLabel` , 更大网络结构的序列标注模型
- :class:`~fastNLP.models.ESIM` , ESIM 模型的实现
- :class:`~fastNLP.models.StarTransEnc` , 带 word-embedding的Star-Transformer模 型
- :class:`~fastNLP.models.STSeqLabel` , 用于序列标注的 Star-Transformer 模型
- :class:`~fastNLP.models.STNLICls` ,用于自然语言推断 (NLI) 的 Star-Transformer 模型
- :class:`~fastNLP.models.STSeqCls` , 用于分类任务的 Star-Transformer 模型
- :class:`~fastNLP.models.BiaffineParser` , Biaffine 依存句法分析网络的实现
- :class:`~fastNLP.models.BiLSTMCRF`, 使用BiLSTM与CRF进行序列标注
-
-
-使用 nn.torch 编写模型
-----------------------------
-
-FastNLP 完全支持使用 pyTorch 编写的模型,但与 pyTorch 中编写模型的常见方法不同,
-用于 fastNLP 的模型中 forward 函数需要返回一个字典,字典中至少需要包含 ``pred`` 这个字段。
-
-下面是使用 pyTorch 中的 torch.nn 模块编写的文本分类,注意观察代码中标注的向量维度。
-由于 pyTorch 使用了约定俗成的维度设置,使得 forward 中需要多次处理维度顺序
-
-.. code-block:: python
-
- import torch
- import torch.nn as nn
-
- class LSTMText(nn.Module):
- def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):
- super().__init__()
-
- self.embedding = nn.Embedding(vocab_size, embedding_dim)
- self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True, dropout=dropout)
- self.fc = nn.Linear(hidden_dim * 2, output_dim)
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, words):
- # (input) words : (batch_size, seq_len)
- words = words.permute(1,0)
- # words : (seq_len, batch_size)
-
- embedded = self.dropout(self.embedding(words))
- # embedded : (seq_len, batch_size, embedding_dim)
- output, (hidden, cell) = self.lstm(embedded)
- # output: (seq_len, batch_size, hidden_dim * 2)
- # hidden: (num_layers * 2, batch_size, hidden_dim)
- # cell: (num_layers * 2, batch_size, hidden_dim)
-
- hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
- hidden = self.dropout(hidden)
- # hidden: (batch_size, hidden_dim * 2)
-
- pred = self.fc(hidden.squeeze(0))
- # result: (batch_size, output_dim)
- return {"pred":pred}
-
-我们同样可以在 iPython 环境中查看这个模型的网络结构
-
-.. parsed-literal::
-
- LSTMText(
- (embedding): Embedding(16292, 100)
- (lstm): LSTM(100, 64, num_layers=2, dropout=0.5, bidirectional=True)
- (fc): Linear(in_features=128, out_features=2, bias=True)
- (dropout): Dropout(p=0.5, inplace=False)
- )
-
-
-使用 modules 编写模型
-----------------------------
-
-下面我们使用 :mod:`fastNLP.modules` 中的组件来构建同样的网络。由于 fastNLP 统一把 ``batch_size`` 放在第一维,
-在编写代码的过程中会有一定的便利。
-
-.. code-block:: python
-
- from fastNLP.modules import Embedding, LSTM, MLP
-
- class MyText(nn.Module):
- def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):
- super().__init__()
-
- self.embedding = Embedding((vocab_size, embedding_dim))
- self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)
- self.mlp = MLP([hidden_dim*2,output_dim], dropout=dropout)
-
- def forward(self, words):
- embedded = self.embedding(words)
- _,(hidden,_) = self.lstm(embedded)
- pred = self.mlp(torch.cat((hidden[-1],hidden[-2]),dim=1))
- return {"pred":pred}
-
-我们自己编写模型的网络结构如下
-
-.. parsed-literal::
-
- MyText(
- (embedding): Embedding(
- (embed): Embedding(16292, 100)
- (dropout): Dropout(p=0.0, inplace=False)
- )
- (lstm): LSTM(
- (lstm): LSTM(100, 64, num_layers=2, batch_first=True, bidirectional=True)
- )
- (mlp): MLP(
- (hiddens): ModuleList()
- (output): Linear(in_features=128, out_features=2, bias=True)
- (dropout): Dropout(p=0.5, inplace=False)
- )
- )
-
-FastNLP 中包含的各种模块如下表,您可以点击具体的名称查看详细的 API,也可以通过 :doc:`/fastNLP.modules` 进行了解。
-
-.. csv-table::
- :header: 名称, 介绍
-
- :class:`~fastNLP.modules.ConvolutionCharEncoder` , char级别的卷积 encoder
- :class:`~fastNLP.modules.LSTMCharEncoder` , char级别基于LSTM的 encoder
- :class:`~fastNLP.modules.ConvMaxpool` , 结合了Convolution和Max-Pooling于一体的模块
- :class:`~fastNLP.modules.LSTM` , LSTM模块, 轻量封装了PyTorch的LSTM
- :class:`~fastNLP.modules.StarTransformer` , Star-Transformer 的encoder部分
- :class:`~fastNLP.modules.TransformerEncoder` , Transformer的encoder模块,不包含embedding层
- :class:`~fastNLP.modules.VarRNN` , Variational Dropout RNN 模块
- :class:`~fastNLP.modules.VarLSTM` , Variational Dropout LSTM 模块
- :class:`~fastNLP.modules.VarGRU` , Variational Dropout GRU 模块
- :class:`~fastNLP.modules.MaxPool` , Max-pooling模块
- :class:`~fastNLP.modules.MaxPoolWithMask` , 带mask矩阵的max pooling。在做 max-pooling的时候不会考虑mask值为0的位置。
- :class:`~fastNLP.modules.AvgPool` , Average-pooling模块
- :class:`~fastNLP.modules.AvgPoolWithMask` , 带mask矩阵的average pooling。在做 average-pooling的时候不会考虑mask值为0的位置。
- :class:`~fastNLP.modules.MultiHeadAttention` , MultiHead Attention 模块
- :class:`~fastNLP.modules.MLP` , 简单的多层感知器模块
- :class:`~fastNLP.modules.ConditionalRandomField` , 条件随机场模块
- :class:`~fastNLP.modules.viterbi_decode` , 给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 (与 :class:`~fastNLP.modules.ConditionalRandomField` 配合使用)
- :class:`~fastNLP.modules.allowed_transitions` , 给定一个id到label的映射表,返回所有可以跳转的列表(与 :class:`~fastNLP.modules.ConditionalRandomField` 配合使用)
- :class:`~fastNLP.modules.TimestepDropout` , 简单包装过的Dropout 组件
-
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/tutorial_9_callback.rst b/docs/source/tutorials/tutorial_9_callback.rst
deleted file mode 100644
index 5ecdb88d..00000000
--- a/docs/source/tutorials/tutorial_9_callback.rst
+++ /dev/null
@@ -1,140 +0,0 @@
-===================================================
-使用 Callback 自定义你的训练过程
-===================================================
-
-- `什么是Callback`_
-- `使用 Callback`_
-- `fastNLP 中的 Callback`_
-- `自定义 Callback`_
-
-
-什么是Callback
----------------------
-
-:class:`~fastNLP.core.callback.Callback` 是与 :class:`~fastNLP.core.trainer.Trainer` 紧密结合的模块,利用 Callback 可以在 :class:`~fastNLP.core.trainer.Trainer` 训练时,加入自定义的操作,比如梯度裁剪,学习率调节,测试模型的性能等。定义的 Callback 会在训练的特定阶段被调用。
-
-fastNLP 中提供了很多常用的 :class:`~fastNLP.core.callback.Callback` ,开箱即用。
-
-
-使用 Callback
----------------------
-
-使用 Callback 很简单,将需要的 callback 按 list 存储,以对应参数 ``callbacks`` 传入对应的 Trainer。Trainer 在训练时就会自动执行这些 Callback 指定的操作了。
-
-
-.. code-block:: python
-
- from fastNLP import (Callback, EarlyStopCallback,
- Trainer, CrossEntropyLoss, AccuracyMetric)
- from fastNLP.models import CNNText
- import torch.cuda
-
- # prepare data
- def get_data():
- from fastNLP.io import ChnSentiCorpPipe as pipe
- data = pipe().process_from_file()
- print(data)
- data.rename_field('chars', 'words')
- train_data = data.get_dataset('train')
- dev_data = data.get_dataset('dev')
- test_data = data.get_dataset('test')
- vocab = data.get_vocab('words')
- tgt_vocab = data.get_vocab('target')
- return train_data, dev_data, test_data, vocab, tgt_vocab
-
- # prepare model
- train_data, dev_data, _, vocab, tgt_vocab = get_data()
- device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
- model = CNNText((len(vocab),50), num_classes=len(tgt_vocab))
-
- # define callback
- callbacks=[EarlyStopCallback(5)]
-
- # pass callbacks to Trainer
- def train_with_callback(cb_list):
- trainer = Trainer(
- device=device,
- n_epochs=3,
- model=model,
- train_data=train_data,
- dev_data=dev_data,
- loss=CrossEntropyLoss(),
- metrics=AccuracyMetric(),
- callbacks=cb_list,
- check_code_level=-1
- )
- trainer.train()
-
- train_with_callback(callbacks)
-
-
-
-fastNLP 中的 Callback
----------------------
-
-fastNLP 中提供了很多常用的 Callback,如梯度裁剪,训练时早停和测试验证集,fitlog 等等。具体 Callback 请参考 :mod:`fastNLP.core.callback`
-
-.. code-block:: python
-
- from fastNLP import EarlyStopCallback, GradientClipCallback, EvaluateCallback
- callbacks = [
- EarlyStopCallback(5),
- GradientClipCallback(clip_value=5, clip_type='value'),
- EvaluateCallback(dev_data)
- ]
-
- train_with_callback(callbacks)
-
-自定义 Callback
----------------------
-
-这里我们以一个简单的 Callback作为例子,它的作用是打印每一个 Epoch 平均训练 loss。
-
-1. 创建 Callback
-
- 要自定义 Callback,我们要实现一个类,继承 :class:`~fastNLP.core.callback.Callback` 。这里我们定义 ``MyCallBack`` ,继承 fastNLP.Callback 。
-
-2. 指定 Callback 调用的阶段
-
- Callback 中所有以 `on_` 开头的类方法会在 Trainer 的训练中在特定阶段调用。 如 on_train_begin() 会在训练开始时被调用,on_epoch_end()
- 会在每个 epoch 结束时调用。 具体有哪些类方法,参见 :class:`~fastNLP.core.callback.Callback` 文档。这里, MyCallBack 在求得loss时调用 on_backward_begin() 记录
- 当前 loss,在每一个 epoch 结束时调用 on_epoch_end() ,求当前 epoch 平均loss并输出。
-
-3. 使用 Callback 的属性访问 Trainer 的内部信息
-
- 为了方便使用,可以使用 :class:`~fastNLP.core.callback.Callback` 的属性,访问 :class:`~fastNLP.core.trainer.Trainer` 中的对应信息,如 optimizer, epoch, n_epochs,分别对应训练时的优化器,
- 当前 epoch 数,和总 epoch 数。 具体可访问的属性,参见 :class:`~fastNLP.core.callback.Callback` 。这里, MyCallBack 为了求平均 loss ,需要知道当前 epoch 的总步
- 数,可以通过 self.step 属性得到当前训练了多少步。
-
-.. code-block:: python
-
- from fastNLP import Callback
- from fastNLP import logger
-
- class MyCallBack(Callback):
- """Print average loss in each epoch"""
- def __init__(self):
- super().__init__()
- self.total_loss = 0
- self.start_step = 0
-
- def on_backward_begin(self, loss):
- self.total_loss += loss.item()
-
- def on_epoch_end(self):
- n_steps = self.step - self.start_step
- avg_loss = self.total_loss / n_steps
- logger.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)
- self.start_step = self.step
-
- callbacks = [MyCallBack()]
- train_with_callback(callbacks)
-
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/序列标注.rst b/docs/source/tutorials/序列标注.rst
deleted file mode 100644
index 3fec110d..00000000
--- a/docs/source/tutorials/序列标注.rst
+++ /dev/null
@@ -1,208 +0,0 @@
-=====================
-序列标注
-=====================
-
-这一部分的内容主要展示如何使用fastNLP实现序列标注(Sequence labeling)任务。您可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。
-在阅读这篇教程前,希望您已经熟悉了fastNLP的基础使用,尤其是数据的载入以及模型的构建。通过这个小任务,能让您进一步熟悉fastNLP的使用。
-
-.. note::
-
- 本教程推荐使用 GPU 进行实验
-
-命名实体识别(name entity recognition, NER)
-------------------------------------------
-
-命名实体识别任务是从文本中抽取出具有特殊意义或者指代性非常强的实体,通常包括人名、地名、机构名和时间等。
-如下面的例子中
-
- 我来自复旦大学。
-
-其中“复旦大学”就是一个机构名,命名实体识别就是要从中识别出“复旦大学”这四个字是一个整体,且属于机构名这个类别。这个问题在实际做的时候会被
-转换为序列标注问题
-
-针对"我来自复旦大学"这句话,我们的预测目标将是[O, O, O, B-ORG, I-ORG, I-ORG, I-ORG],其中O表示out,即不是一个实体,B-ORG是ORG(
-organization的缩写)这个类别的开头(Begin),I-ORG是ORG类别的中间(Inside)。
-
-在本tutorial中我们将通过fastNLP尝试写出一个能够执行以上任务的模型。
-
-载入数据
-------------------------------------------
-fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的,您可以通过 :doc:`使用Loader和Pipe处理数据 `
-了解如何使用fastNLP提供的数据加载函数。下面我们以微博命名实体任务来演示一下在fastNLP进行序列标注任务。
-
-.. code-block:: python
-
- from fastNLP.io import WeiboNERPipe
- data_bundle = WeiboNERPipe().process_from_file()
- print(data_bundle.get_dataset('train')[:2])
-
-打印的数据如下 ::
-
- +-------------------------------------------------+------------------------------------------+------------------------------------------+---------+
- | raw_chars | target | chars | seq_len |
- +-------------------------------------------------+------------------------------------------+------------------------------------------+---------+
- | ['一', '节', '课', '的', '时', '间', '真', '... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, ... | [8, 211, 775, 3, 49, 245, 89, 26, 101... | 16 |
- | ['回', '复', '支', '持', ',', '赞', '成', '... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [116, 480, 127, 109, 2, 446, 134, 2, ... | 59 |
- +-------------------------------------------------+------------------------------------------+------------------------------------------+---------+
-
-
-模型构建
---------------------------------
-
-首先选择需要使用的Embedding类型。关于Embedding的相关说明可以参见 :doc:`使用Embedding模块将文本转成向量 ` 。
-在这里我们使用通过word2vec预训练的中文汉字embedding。
-
-.. code-block:: python
-
- from fastNLP.embeddings import StaticEmbedding
-
- embed = StaticEmbedding(vocab=data_bundle.get_vocab('chars'), model_dir_or_name='cn-char-fastnlp-100d')
-
-选择好Embedding之后,我们可以使用fastNLP中自带的 :class:`fastNLP.models.BiLSTMCRF` 作为模型。
-
-.. code-block:: python
-
- from fastNLP.models import BiLSTMCRF
-
- data_bundle.rename_field('chars', 'words') # 这是由于BiLSTMCRF模型的forward函数接受的words,而不是chars,所以需要把这一列重新命名
- model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5,
- target_vocab=data_bundle.get_vocab('target'))
-
-进行训练
---------------------------------
-
-下面我们选择用来评估模型的metric,以及优化用到的优化函数。
-
-.. code-block:: python
-
- from fastNLP import SpanFPreRecMetric
- from torch.optim import Adam
- from fastNLP import LossInForward
-
- metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))
- optimizer = Adam(model.parameters(), lr=1e-2)
- loss = LossInForward()
-
-使用Trainer进行训练, 您可以通过修改 device 的值来选择显卡。
-
-.. code-block:: python
-
- from fastNLP import Trainer
- import torch
-
- device= 0 if torch.cuda.is_available() else 'cpu'
- trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,
- dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device)
- trainer.train()
-
-训练过程输出为::
-
- input fields after batch(if batch size is 2):
- target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26])
- seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- words: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26])
- target fields after batch(if batch size is 2):
- target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26])
- seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
-
- training epochs started 2019-09-25-10-43-09
- Evaluate data in 0.62 seconds!
- Evaluation on dev at Epoch 1/10. Step:43/430:
- SpanFPreRecMetric: f=0.070352, pre=0.100962, rec=0.053985
-
- ...
-
- Evaluate data in 0.61 seconds!
- Evaluation on dev at Epoch 10/10. Step:430/430:
- SpanFPreRecMetric: f=0.51223, pre=0.581699, rec=0.457584
-
-
- In Epoch:7/Step:301, got best dev performance:
- SpanFPreRecMetric: f=0.515528, pre=0.65098, rec=0.426735
- Reloaded the best model.
-
-进行测试
---------------------------------
-
-训练结束之后过,可以通过 :class:`~fastNLP.Tester` 测试其在测试集上的性能
-
-.. code-block:: python
-
- from fastNLP import Tester
-
- tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)
- tester.test()
-
-输出为::
-
- [tester]
- SpanFPreRecMetric: f=0.482399, pre=0.530086, rec=0.442584
-
-
-使用更强的Bert做序列标注
---------------------------------
-
-在fastNLP使用Bert进行任务,您只需要把 :class:`fastNLP.embeddings.StaticEmbedding` 切换为 :class:`fastNLP.embeddings.BertEmbedding` (可修改 device 选择显卡)。
-
-.. code-block:: python
-
- from fastNLP.io import WeiboNERPipe
- from fastNLP.models import BiLSTMCRF
-
- data_bundle = WeiboNERPipe().process_from_file()
- data_bundle.rename_field('chars', 'words')
-
- from fastNLP.embeddings import BertEmbedding
- embed = BertEmbedding(vocab=data_bundle.get_vocab('words'), model_dir_or_name='cn')
- model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5,
- target_vocab=data_bundle.get_vocab('target'))
-
- from fastNLP import SpanFPreRecMetric
- from torch.optim import Adam
- from fastNLP import LossInForward
- metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))
- optimizer = Adam(model.parameters(), lr=2e-5)
- loss = LossInForward()
-
- from fastNLP import Trainer
- import torch
- device= 0 if torch.cuda.is_available() else 'cpu'
- trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer, batch_size=12,
- dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device)
- trainer.train()
-
- from fastNLP import Tester
- tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)
- tester.test()
-
-输出为::
-
- training epochs started 2019-09-25-07-15-43
- Evaluate data in 2.02 seconds!
- Evaluation on dev at Epoch 1/10. Step:113/1130:
- SpanFPreRecMetric: f=0.0, pre=0.0, rec=0.0
-
- ...
-
- Evaluate data in 2.17 seconds!
- Evaluation on dev at Epoch 10/10. Step:1130/1130:
- SpanFPreRecMetric: f=0.647332, pre=0.589852, rec=0.717224
-
- In Epoch:6/Step:678, got best dev performance:
- SpanFPreRecMetric: f=0.669963, pre=0.645238, rec=0.696658
- Reloaded the best model.
-
- Evaluate data in 1.82 seconds!
- [tester]
- SpanFPreRecMetric: f=0.641774, pre=0.626424, rec=0.657895
-
-可以看出通过使用Bert,效果有明显的提升,从48.2提升到了64.1。
-
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
diff --git a/docs/source/tutorials/文本分类.rst b/docs/source/tutorials/文本分类.rst
deleted file mode 100644
index 30f6cf4f..00000000
--- a/docs/source/tutorials/文本分类.rst
+++ /dev/null
@@ -1,542 +0,0 @@
-文本分类
-=============================
-
-文本分类(Text classification)任务是将一句话或一段话划分到某个具体的类别。比如垃圾邮件识别,文本情绪分类等。这篇教程可以带你从零开始了解 fastNLP 的使用
-
-.. note::
-
- 本教程推荐使用 GPU 进行实验
-
-.. code-block:: text
-
- 1, 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!
-
-其中开头的1是只这条评论的标签,表示是正面的情绪。我们将使用到的数据可以通过 `此链接 `_
-下载并解压,当然也可以通过fastNLP自动下载该数据。
-
-数据中的内容如下图所示。接下来,我们将用fastNLP在这个数据上训练一个分类网络。
-
-.. figure:: ./cn_cls_example.png
- :alt: jupyter
-
-步骤
-----
-
-一共有以下的几个步骤:
-
-1. `读取数据 <#id4>`_
-
-2. `预处理数据 <#id5>`_
-
-3. `选择预训练词向量 <#id6>`_
-
-4. `创建模型 <#id7>`_
-
-5. `训练模型 <#id8>`_
-
-(1) 读取数据
-~~~~~~~~~~~~~~~~~~~~
-
-fastNLP提供多种数据的自动下载与自动加载功能,对于这里我们要用到的数据,我们可以用 :class:`~fastNLP.io.Loader` 自动下载并加载该数据。
-更多有关Loader的使用可以参考 :mod:`~fastNLP.io.loader`
-
-.. code-block:: python
-
- from fastNLP.io import ChnSentiCorpLoader
-
- loader = ChnSentiCorpLoader() # 初始化一个中文情感分类的loader
- data_dir = loader.download() # 这一行代码将自动下载数据到默认的缓存地址, 并将该地址返回
- data_bundle = loader.load(data_dir) # 这一行代码将从{data_dir}处读取数据至DataBundle
-
-
-DataBundle的相关介绍,可以参考 :class:`~fastNLP.io.DataBundle` 。我们可以打印该data\_bundle的基本信息。
-
-.. code-block:: python
-
- print(data_bundle)
-
-
-.. code-block:: text
-
- In total 3 datasets:
- dev has 1200 instances.
- train has 9600 instances.
- test has 1200 instances.
- In total 0 vocabs:
-
-
-
-可以看出,该data\_bundle中一个含有三个 :class:`~fastNLP.DataSet` 。通过下面的代码,我们可以查看DataSet的基本情况
-
-.. code-block:: python
-
- print(data_bundle.get_dataset('train')[:2]) # 查看Train集前两个sample
-
-
-.. code-block:: text
-
- +-----------------------------+--------+
- | raw_chars | target |
- +-----------------------------+--------+
- | 选择珠江花园的原因就是方... | 1 |
- | 15.4寸笔记本的键盘确实爽... | 1 |
- +-----------------------------+--------+
-
-(2) 预处理数据
-~~~~~~~~~~~~~~~~~~~~
-
-在NLP任务中,预处理一般包括:
-
-(a) 将一整句话切分成汉字或者词;
-
-(b) 将文本转换为index
-
-fastNLP中也提供了多种数据集的处理类,这里我们直接使用fastNLP的ChnSentiCorpPipe。更多关于Pipe的说明可以参考 :mod:`~fastNLP.io.pipe` 。
-
-.. code-block:: python
-
- from fastNLP.io import ChnSentiCorpPipe
-
- pipe = ChnSentiCorpPipe()
- data_bundle = pipe.process(data_bundle) # 所有的Pipe都实现了process()方法,且输入输出都为DataBundle类型
-
- print(data_bundle) # 打印data_bundle,查看其变化
-
-
-.. code-block:: text
-
- In total 3 datasets:
- dev has 1200 instances.
- train has 9600 instances.
- test has 1200 instances.
- In total 2 vocabs:
- chars has 4409 entries.
- target has 2 entries.
-
-
-
-可以看到除了之前已经包含的3个 :class:`~fastNLP.DataSet` ,还新增了两个 :class:`~fastNLP.Vocabulary` 。我们可以打印DataSet中的内容
-
-.. code-block:: python
-
- print(data_bundle.get_dataset('train')[:2])
-
-
-.. code-block:: text
-
- +-----------------+--------+-----------------+---------+
- | raw_chars | target | chars | seq_len |
- +-----------------+--------+-----------------+---------+
- | 选择珠江花园... | 0 | [338, 464, 1... | 106 |
- | 15.4寸笔记本... | 0 | [50, 133, 20... | 56 |
- +-----------------+--------+-----------------+---------+
-
-
-新增了一列为数字列表的chars,以及变为数字的target列。可以看出这两列的名称和刚好与data\_bundle中两个Vocabulary的名称是一致的,我们可以打印一下Vocabulary看一下里面的内容。
-
-.. code-block:: python
-
- char_vocab = data_bundle.get_vocab('chars')
- print(char_vocab)
-
-
-.. code-block:: text
-
- Vocabulary(['选', '择', '珠', '江', '花']...)
-
-
-Vocabulary是一个记录着词语与index之间映射关系的类,比如
-
-.. code-block:: python
-
- index = char_vocab.to_index('选')
- print("'选'的index是{}".format(index)) # 这个值与上面打印出来的第一个instance的chars的第一个index是一致的
- print("index:{}对应的汉字是{}".format(index, char_vocab.to_word(index)))
-
-
-.. code-block:: text
-
- '选'的index是338
- index:338对应的汉字是选
-
-
-(3) 选择预训练词向量
-~~~~~~~~~~~~~~~~~~~~
-
-由于Word2vec, Glove, Elmo, Bert等预训练模型可以增强模型的性能,所以在训练具体任务前,选择合适的预训练词向量非常重要。
-在fastNLP中我们提供了多种Embedding使得加载这些预训练模型的过程变得更加便捷。
-这里我们先给出一个使用word2vec的中文汉字预训练的示例,之后再给出一个使用Bert的文本分类。
-这里使用的预训练词向量为'cn-fastnlp-100d',fastNLP将自动下载该embedding至本地缓存,
-fastNLP支持使用名字指定的Embedding以及相关说明可以参见 :mod:`fastNLP.embeddings`
-
-.. code-block:: python
-
- from fastNLP.embeddings import StaticEmbedding
-
- word2vec_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d')
-
-
-.. code-block:: text
-
- Found 4321 out of 4409 compound in the pre-training embedding.
-
-(4) 创建模型
-~~~~~~~~~~~~
-
-.. code-block:: python
-
- from torch import nn
- from fastNLP.modules import LSTM
- import torch
-
- # 定义模型
- class BiLSTMMaxPoolCls(nn.Module):
- def __init__(self, embed, num_classes, hidden_size=400, num_layers=1, dropout=0.3):
- super().__init__()
- self.embed = embed
-
- self.lstm = LSTM(self.embed.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers,
- batch_first=True, bidirectional=True)
- self.dropout_layer = nn.Dropout(dropout)
- self.fc = nn.Linear(hidden_size, num_classes)
-
- def forward(self, chars, seq_len): # 这里的名称必须和DataSet中相应的field对应,比如之前我们DataSet中有chars,这里就必须为chars
- # chars:[batch_size, max_len]
- # seq_len: [batch_size, ]
- chars = self.embed(chars)
- outputs, _ = self.lstm(chars, seq_len)
- outputs = self.dropout_layer(outputs)
- outputs, _ = torch.max(outputs, dim=1)
- outputs = self.fc(outputs)
-
- return {'pred':outputs} # [batch_size,], 返回值必须是dict类型,且预测值的key建议设为pred
-
- # 初始化模型
- model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))
-
-(5) 训练模型
-~~~~~~~~~~~~
-
-fastNLP提供了Trainer对象来组织训练过程,包括完成loss计算(所以在初始化Trainer的时候需要指定loss类型),梯度更新(所以在初始化Trainer的时候需要提供优化器optimizer)以及在验证集上的性能验证(所以在初始化时需要提供一个Metric)
-
-.. code-block:: python
-
- from fastNLP import Trainer
- from fastNLP import CrossEntropyLoss
- from torch.optim import Adam
- from fastNLP import AccuracyMetric
-
- loss = CrossEntropyLoss()
- optimizer = Adam(model.parameters(), lr=0.001)
- metric = AccuracyMetric()
- device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快
-
- trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss,
- optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),
- metrics=metric, device=device)
- trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型
-
- # 在测试集上测试一下模型的性能
- from fastNLP import Tester
- print("Performance on test is:")
- tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)
- tester.test()
-
-
-.. code-block:: text
-
- input fields after batch(if batch size is 2):
- target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- chars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106])
- seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- target fields after batch(if batch size is 2):
- target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
-
- Evaluate data in 0.01 seconds!
- training epochs started 2019-09-03-23-57-10
-
- Evaluate data in 0.43 seconds!
- Evaluation on dev at Epoch 1/10. Step:300/3000:
- AccuracyMetric: acc=0.81
-
- Evaluate data in 0.44 seconds!
- Evaluation on dev at Epoch 2/10. Step:600/3000:
- AccuracyMetric: acc=0.8675
-
- Evaluate data in 0.44 seconds!
- Evaluation on dev at Epoch 3/10. Step:900/3000:
- AccuracyMetric: acc=0.878333
-
- ....
-
- Evaluate data in 0.48 seconds!
- Evaluation on dev at Epoch 9/10. Step:2700/3000:
- AccuracyMetric: acc=0.8875
-
- Evaluate data in 0.43 seconds!
- Evaluation on dev at Epoch 10/10. Step:3000/3000:
- AccuracyMetric: acc=0.895833
-
- In Epoch:7/Step:2100, got best dev performance:
- AccuracyMetric: acc=0.8975
- Reloaded the best model.
-
- Evaluate data in 0.34 seconds!
- [tester]
- AccuracyMetric: acc=0.8975
-
- {'AccuracyMetric': {'acc': 0.8975}}
-
-
-
-PS: 使用Bert进行文本分类
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. code-block:: python
-
- # 只需要切换一下Embedding即可
- from fastNLP.embeddings import BertEmbedding
-
- # 这里为了演示一下效果,所以默认Bert不更新权重
- bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False)
- model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')))
-
-
- import torch
- from fastNLP import Trainer
- from fastNLP import CrossEntropyLoss
- from torch.optim import Adam
- from fastNLP import AccuracyMetric
-
- loss = CrossEntropyLoss()
- optimizer = Adam(model.parameters(), lr=2e-5)
- metric = AccuracyMetric()
- device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快
-
- trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss,
- optimizer=optimizer, batch_size=16, dev_data=data_bundle.get_dataset('test'),
- metrics=metric, device=device, n_epochs=3)
- trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型
-
- # 在测试集上测试一下模型的性能
- from fastNLP import Tester
- print("Performance on test is:")
- tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)
- tester.test()
-
-
-.. code-block:: text
-
- loading vocabulary file ~/.fastNLP/embedding/bert-chinese-wwm/vocab.txt
- Load pre-trained BERT parameters from file ~/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.
- Start to generating word pieces for word.
- Found(Or segment into word pieces) 4286 words out of 4409.
- input fields after batch(if batch size is 2):
- target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- chars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106])
- seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- target fields after batch(if batch size is 2):
- target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
-
- Evaluate data in 0.05 seconds!
- training epochs started 2019-09-04-00-02-37
-
- Evaluate data in 15.89 seconds!
- Evaluation on dev at Epoch 1/3. Step:1200/3600:
- AccuracyMetric: acc=0.9
-
- Evaluate data in 15.92 seconds!
- Evaluation on dev at Epoch 2/3. Step:2400/3600:
- AccuracyMetric: acc=0.904167
-
- Evaluate data in 15.91 seconds!
- Evaluation on dev at Epoch 3/3. Step:3600/3600:
- AccuracyMetric: acc=0.918333
-
- In Epoch:3/Step:3600, got best dev performance:
- AccuracyMetric: acc=0.918333
- Reloaded the best model.
- Performance on test is:
-
- Evaluate data in 29.24 seconds!
- [tester]
- AccuracyMetric: acc=0.919167
-
- {'AccuracyMetric': {'acc': 0.919167}}
-
-
-PS: 基于词进行文本分类
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-由于汉字中没有显示的字与字的边界,一般需要通过分词器先将句子进行分词操作。
-下面的例子演示了如何不基于fastNLP已有的数据读取、预处理代码进行文本分类。
-
-(1) 读取数据
-~~~~~~~~~~~~~~~~~~~~
-
-这里我们继续以之前的数据为例,但这次我们不使用fastNLP自带的数据读取代码
-
-.. code-block:: python
-
- from fastNLP.io import ChnSentiCorpLoader
-
- loader = ChnSentiCorpLoader() # 初始化一个中文情感分类的loader
- data_dir = loader.download() # 这一行代码将自动下载数据到默认的缓存地址, 并将该地址返回
-
-获取到的data_dir下应该有类似以下的文件
-
-.. code-block:: text
-
- - chn_senti_corp
- - train.tsv
- - dev.tsv
- - test.tsv
-
-如果打开任何一个文件查看,会发现里面的格式均为
-
-.. code-block:: text
-
- target raw_chars
- 1 这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般
- 0 怀着十分激动的心情放映...
-
-下面我们先定义一个read_file_to_dataset的函数, 即给定一个文件路径,读取其中的内容,并返回一个DataSet。然后我们将所有的DataSet放入到DataBundle对象中来方便接下来的预处理
-
-.. code-block:: python
-
- import os
- from fastNLP import DataSet, Instance
- from fastNLP.io import DataBundle
-
-
- def read_file_to_dataset(fp):
- ds = DataSet()
- with open(fp, 'r') as f:
- f.readline() # 第一行是title名称,忽略掉
- for line in f:
- line = line.strip()
- target, chars = line.split('\t')
- ins = Instance(target=target, raw_chars=chars)
- ds.append(ins)
- return ds
-
- data_bundle = DataBundle()
- for name in ['train.tsv', 'dev.tsv', 'test.tsv']:
- fp = os.path.join(data_dir, name)
- ds = read_file_to_dataset(fp)
- data_bundle.set_dataset(name=name.split('.')[0], dataset=ds)
-
- print(data_bundle) # 查看以下数据集的情况
- # In total 3 datasets:
- # train has 9600 instances.
- # dev has 1200 instances.
- # test has 1200 instances.
-
-(2) 数据预处理
-~~~~~~~~~~~~~~~~~~~~
-
-在这里,我们首先把句子通过 fastHan_ 进行分词操作,然后创建词表,并将词语转换为序号。
-
-.. _fastHan: https://gitee.com/fastnlp/fastHan
-
-.. code-block:: python
-
- from fastHan import FastHan
- from fastNLP import Vocabulary
-
- model=FastHan()
- # model.set_device('cuda') # 可以注视掉这一行增加速度
-
- # 定义分词处理操作
- def word_seg(ins):
- raw_chars = ins['raw_chars']
- # 由于有些句子比较长,我们只截取前128个汉字
- raw_words = model(raw_chars[:128], target='CWS')[0]
- return raw_words
-
- for name, ds in data_bundle.iter_datasets():
- # apply函数将对内部的instance依次执行word_seg操作,并把其返回值放入到raw_words这个field
- ds.apply(word_seg, new_field_name='raw_words')
- # 除了apply函数,fastNLP还支持apply_field, apply_more(可同时创建多个field)等操作
- # 同时我们增加一个seq_len的field
- ds.add_seq_len('raw_words')
-
- vocab = Vocabulary()
-
- # 对raw_words列创建词表, 建议把非训练集的dataset放在no_create_entry_dataset参数中
- # 也可以通过add_word(), add_word_lst()等建立词表,请参考http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html
- vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_words',
- no_create_entry_dataset=[data_bundle.get_dataset('dev'),
- data_bundle.get_dataset('test')])
-
- # 将建立好词表的Vocabulary用于对raw_words列建立词表,并把转为序号的列存入到words列
- vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'),
- data_bundle.get_dataset('test'), field_name='raw_words', new_field_name='words')
-
- # 建立target的词表,target的词表一般不需要padding和unknown
- target_vocab = Vocabulary(padding=None, unknown=None)
- # 一般情况下我们可以只用训练集建立target的词表
- target_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='target')
- # 如果没有传递new_field_name, 则默认覆盖原词表
- target_vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'),
- data_bundle.get_dataset('test'), field_name='target')
-
- # 我们可以把词表保存到data_bundle中,方便之后使用
- data_bundle.set_vocab(field_name='words', vocab=vocab)
- data_bundle.set_vocab(field_name='target', vocab=target_vocab)
-
- # 我们把words和target分别设置为input和target,这样它们才会在训练循环中被取出并自动padding, 有关这部分更多的内容参考
- # http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html
- data_bundle.set_target('target')
- data_bundle.set_input('words') # DataSet也有这两个接口
- # 如果某些field,您希望它被设置为target或者input,但是不希望fastNLP自动padding或需要使用特定的padding方式,请参考
- # http://www.fastnlp.top/docs/fastNLP/fastNLP.core.dataset.html
-
- print(data_bundle.get_dataset('train')[:2]) # 我们可以看一下当前dataset的内容
-
- # +--------+-----------------------+-----------------------+----------------------+
- # | target | raw_chars | raw_words | words |
- # +--------+-----------------------+-----------------------+----------------------+
- # | 0 | 选择珠江花园的原因... | ['选择', '珠江', ... | [2, 3, 4, 5, 6, 7... |
- # | 0 | 15.4寸笔记本的键盘... | ['15.4', '寸', '笔... | [71, 72, 73, 74, ... |
- # +--------+-----------------------+-----------------------+----------------------+
-
- # 由于之后需要使用之前定义的BiLSTMMaxPoolCls模型,所以需要将words这个field修改为chars
- data_bundle.rename_field('words', 'chars')
-
-我们可以打印一下vocab看一下当前的词表内容
-
-.. code-block:: python
-
- print(data_bundle.get_vocab('chars'))
- # Vocabulary([选择, 珠江, 花园, 的, 原因]...)
-
-(3) 选择预训练词向量
-~~~~~~~~~~~~~~~~~~~~
-
-这里我们选择腾讯的预训练中文词向量,可以在 腾讯词向量_ 处下载并解压。这里我们不能直接使用BERT,因为BERT是基于中文字进行预训练的。
-
-.. _腾讯词向量: https://ai.tencent.com/ailab/nlp/en/embedding.html
-
-下面我们使用 :mod:`fastNLP.embeddings` 加载该词向量,fastNLP会抽取vocabulary中包含的词的向量,并随机初始化不包含在文件中的词语的词向量。
-
-.. code-block:: python
-
- from fastNLP.embeddings import StaticEmbedding
-
- word2vec_embed = StaticEmbedding(data_bundle.get_vocab('chars'), model_dir_or_name='/path/to/Tencent_AILab_ChineseEmbedding.txt')
-
-再之后的模型定义与训练过程与上面是一致的,这里就不再赘述了。
-
-
-
-----------------------------------
-代码下载
-----------------------------------
-
-.. raw:: html
-
- 点击下载 IPython Notebook 文件
-
diff --git a/docs/source/user/api_update.rst b/docs/source/user/api_update.rst
deleted file mode 100644
index 08a6bdbe..00000000
--- a/docs/source/user/api_update.rst
+++ /dev/null
@@ -1,15 +0,0 @@
-===========================
-API变动列表
-===========================
-
-2020.4.14
-========================
-
-修改了 :class:`fastNLP.core.callback.ControlC` 的 API。
-
-原来的参数 ``quit_all`` 修改为 ``quit_and_do`` ,仍然接收一个 bool 值。新增可选参数 ``action`` ,接收一个待执行的函数,
-在 ``quit_and_do`` 的值为 ``True`` 时,退出训练过程后执行该函数。 ``action`` 的默认值是退出整个程序,与原有功能一致。
-
-.. note::
- 原有用法 `ControlC(True)` 和 `ControlC(False)` 均可以继续正确执行,但 `ControlC(quit_all=True/False)` 需要修改为
- `ControlC(quit_and_do=True/False)` 。
\ No newline at end of file
diff --git a/docs/source/user/example.rst b/docs/source/user/example.rst
index 63535058..cc83d578 100644
--- a/docs/source/user/example.rst
+++ b/docs/source/user/example.rst
@@ -144,7 +144,7 @@ csv 表格
\<\>内表示的是链接地址,\<\>外的是显示到外面的文字
-:doc:`根据文件名链接 `
+:doc:`根据文件名链接 `
:mod:`~fastNLP.core.batch`
diff --git a/docs/source/user/installation.rst b/docs/source/user/installation.rst
deleted file mode 100644
index b4156f6a..00000000
--- a/docs/source/user/installation.rst
+++ /dev/null
@@ -1,24 +0,0 @@
-===============
-安装指南
-===============
-
-.. contents::
- :local:
-
-fastNLP 依赖如下包::
-
- numpy>=1.14.2
- torch>=1.0.0
- tqdm>=4.28.1
- nltk>=3.4.1
- requests
- spacy
- prettytable>=0.7.2
-
-其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 `PyTorch 官网 `_ 。
-在依赖包安装完成的情况,您可以在命令行执行如下指令完成安装
-
-.. code:: shell
-
- >>> pip install fastNLP
- >>> python -m spacy download en
diff --git a/docs/source/user/quickstart.rst b/docs/source/user/quickstart.rst
deleted file mode 100644
index 40039af6..00000000
--- a/docs/source/user/quickstart.rst
+++ /dev/null
@@ -1,14 +0,0 @@
-===============
-快速入门
-===============
-
-如果你想用 fastNLP 来快速地解决某类 NLP 问题,你可以参考以下教程之一:
-
-.. toctree::
- :maxdepth: 1
-
- /tutorials/文本分类
- /tutorials/序列标注
-
-这些教程是简单地介绍了 fastNLP 的使用流程,其中文本分类相对简单,序列标注则较为复杂。更多的教程分析见 :doc:`/user/tutorials`
-
diff --git a/docs/source/user/tutorials.rst b/docs/source/user/tutorials.rst
deleted file mode 100644
index 7296ea72..00000000
--- a/docs/source/user/tutorials.rst
+++ /dev/null
@@ -1,25 +0,0 @@
-========================
-fastNLP 详细使用教程
-========================
-
-这里是更详细的使用教程。对于大部分的用户,我们建议你从第一篇开始顺序阅读;如果你只想了解其中的一部分,也可以进行选读。
-
-.. toctree::
- :maxdepth: 1
-
- 使用DataSet预处理文本
- 使用Vocabulary转换文本与index
- 使用Embedding模块将文本转成向量
- 使用Loader和Pipe加载并处理数据集
- 使用Trainer和Tester快速训练和测试
- 使用DataSetIter实现自定义训练过程
- 使用Metric快速评测你的模型
- 使用Modules和Models快速搭建自定义模型
- 使用Callback自定义你的训练过程
-
-.. toctree::
- :maxdepth: 1
-
- 拓展阅读1:BertEmbedding的各种用法
- 拓展阅读2:分布式训练简介
- 拓展阅读3:使用fitlog 辅助 fastNLP 进行科研
diff --git a/docs/transfer.ipynb b/docs/transfer.ipynb
new file mode 100644
index 00000000..ff897acd
--- /dev/null
+++ b/docs/transfer.ipynb
@@ -0,0 +1,101 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'应当为一个字符串,其值应当为以下之一:``[None, \"dist\", \"unrepeatdist\"]``;为 ``None`` 时,表示不需要考虑当前 ``dataloader`` 切换为分布式状态;为 ``\"dist\"`` 时,表示该 ``dataloader`` 应该保证每个 ``gpu`` 上返回的 ``batch`` 的数量是一样多的,允许出现少量 ``sample`` ,在 不同 ``gpu`` 上出现重复;为 ``\"unrepeatdist\"`` 时,表示该 ``dataloader`` 应该保证所有 ``gpu`` 上迭代出来的数据合并起来应该刚好等于原始的 数据,允许不同 ``gpu`` 上 ``batch`` 的数量不一致。其中 ``trainer`` 中 ``kwargs`` 的参数 ``use_dist_sampler`` 为 ``True`` 时,该值为 ``\"dist\"``; 否则为 ``None`` ,``evaluator`` 中的 ``kwargs`` 的参数 ``use_dist_sampler`` 为 ``True`` 时,该值为 ``\"unrepeatdist\"``,否则为 ``None``; 注意当 ``dist`` 为 ``ReproducibleSampler, ReproducibleBatchSampler`` 时,是断点重训加载时 ``driver.load`` 函数在调用; 当 ``dist`` 为 ``str`` 或者 ``None`` 时,是 ``trainer`` 在初始化时调用该函数;'"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import re\n",
+ "import sys\n",
+ "sys.path.append(\"../\")\n",
+ "# import fastNLP\n",
+ "\n",
+ "def get_class(text):\n",
+ " return f\":class:`~{text}`\"\n",
+ "\n",
+ "def get_meth(text):\n",
+ " return f\":meth:`~{text}`\"\n",
+ "\n",
+ "def get_module(text):\n",
+ " return f\":mod:`~{text}`\"\n",
+ "\n",
+ "def replace(matched):\n",
+ " \"\"\"\n",
+ " \"\"\"\n",
+ " text = matched.group()\n",
+ " non_space = text.strip()\n",
+ " if non_space == \"\":\n",
+ " return text\n",
+ " # 如果原本就添加了 `,那么只加一个\n",
+ " if non_space.startswith(\"`\"):\n",
+ " res = \"`\" + non_space\n",
+ " else:\n",
+ " res = \"``\" + non_space\n",
+ " if non_space.endswith(\"`\"):\n",
+ " res += \"`\"\n",
+ " else:\n",
+ " res += \"``\"\n",
+ " return text.replace(non_space, f\"{res}\")\n",
+ "\n",
+ "def transfer(text):\n",
+ " \"\"\"\n",
+ " 将输入的 ``text`` 中的英文单词添加 \"``\"。在得到结果后最好手动检查一下,\n",
+ " \"\"\"\n",
+ " res = re.sub(\n",
+ " # 匹配字母、下划线、点、逗号、引号、中括号和`\n",
+ " pattern=r\"[a-zA-Z_ \\.,\\\"\\'\\[\\]`]+\",\n",
+ " repl=replace,\n",
+ " string=text\n",
+ " )\n",
+ " return res\n",
+ "\n",
+ "\n",
+ "text = '应当为一个字符串,其值应当为以下之一:[None, \"dist\", \"unrepeatdist\"];为 None 时,表示不需要考虑当前 dataloader \\\n",
+ " 切换为分布式状态;为 \"dist\" 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 \\\n",
+ " 不同 gpu 上出现重复;为 \"unrepeatdist\" 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 \\\n",
+ " 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 \"dist\"; \\\n",
+ " 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 \"unrepeatdist\",否则为 None; \\\n",
+ " 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; \\\n",
+ " 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数;'\n",
+ "transfer(text)"
+ ]
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "c79c3370938623706c2d55a7989cf7c7c31ff0346157477d22565bb370580b77"
+ },
+ "kernelspec": {
+ "display_name": "Python 3.7.13 ('fnlp')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py
index efc46888..38007c3f 100644
--- a/fastNLP/__init__.py
+++ b/fastNLP/__init__.py
@@ -1,98 +1,5 @@
-r"""
-fastNLP 由 :mod:`~fastNLP.core` 、 :mod:`~fastNLP.io` 、:mod:`~fastNLP.embeddings` 、 :mod:`~fastNLP.modules`、
-:mod:`~fastNLP.models` 等子模块组成,你可以查看每个模块的文档。
-- :mod:`~fastNLP.core` 是fastNLP 的核心模块,包括 DataSet、 Trainer、 Tester 等组件。详见文档 :mod:`fastNLP.core`
-- :mod:`~fastNLP.io` 是实现输入输出的模块,包括了数据集的读取,模型的存取等功能。详见文档 :mod:`fastNLP.io`
-- :mod:`~fastNLP.embeddings` 提供用于构建复杂网络模型所需的各种embedding。详见文档 :mod:`fastNLP.embeddings`
-- :mod:`~fastNLP.modules` 包含了用于搭建神经网络模型的诸多组件,可以帮助用户快速搭建自己所需的网络。详见文档 :mod:`fastNLP.modules`
-- :mod:`~fastNLP.models` 包含了一些使用 fastNLP 实现的完整网络模型,包括 :class:`~fastNLP.models.CNNText` 、 :class:`~fastNLP.models.SeqLabeling` 等常见模型。详见文档 :mod:`fastNLP.models`
+from fastNLP.envs import *
+from fastNLP.core import *
-fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的文档如下:
-"""
-__all__ = [
- "Instance",
- "FieldArray",
-
- "DataSetIter",
- "BatchIter",
- "TorchLoaderIter",
-
- "Vocabulary",
- "DataSet",
- "Const",
-
- "Trainer",
- "Tester",
-
- "DistTrainer",
- "get_local_rank",
-
- "Callback",
- "GradientClipCallback",
- "EarlyStopCallback",
- "FitlogCallback",
- "EvaluateCallback",
- "LRScheduler",
- "ControlC",
- "LRFinder",
- "TensorboardCallback",
- "WarmupCallback",
- 'SaveModelCallback',
- "CallbackException",
- "EarlyStopError",
- "CheckPointCallback",
-
- "Padder",
- "AutoPadder",
- "EngChar2DPadder",
-
- # "CollateFn",
- "ConcatCollateFn",
-
- "MetricBase",
- "AccuracyMetric",
- "SpanFPreRecMetric",
- "CMRC2018Metric",
- "ClassifyFPreRecMetric",
- "ConfusionMatrixMetric",
-
- "Optimizer",
- "SGD",
- "Adam",
- "AdamW",
-
- "Sampler",
- "SequentialSampler",
- "BucketSampler",
- "RandomSampler",
- "SortedSampler",
- "ConstantTokenNumSampler",
-
- "LossFunc",
- "CrossEntropyLoss",
- "MSELoss",
- "L1Loss",
- "BCELoss",
- "NLLLoss",
- "LossInForward",
- "LossBase",
- "CMRC2018Loss",
-
- "cache_results",
-
- 'logger',
- "init_logger_dist",
-]
-__version__ = '0.6.0'
-
-import sys
-
-from . import embeddings
-from . import models
-from . import modules
-from .core import *
-from .doc_utils import doc_process
-from .io import loader, pipe
-
-doc_process(sys.modules[__name__])
+__version__ = '1.0.0beta'
diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index d2963d13..10472917 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -1,112 +1,124 @@
-r"""
-core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fastNLP 包中直接 import。当然你也同样可以从 core 模块的子模块中 import,
-例如 :class:`~fastNLP.DataSetIter` 组件有两种 import 的方式::
-
- # 直接从 fastNLP 中 import
- from fastNLP import DataSetIter
-
- # 从 core 模块的子模块 batch 中 import DataSetIter
- from fastNLP.core.batch import DataSetIter
-
-对于常用的功能,你只需要在 :mod:`fastNLP` 中查看即可。如果想了解各个子模块的具体作用,您可以在下面找到每个子模块的具体文档。
-
-"""
__all__ = [
- "DataSet",
-
- "Instance",
-
- "FieldArray",
+ # callbacks
+ 'Callback',
+ 'Event',
+ 'Filter',
+ 'CheckpointCallback',
+ 'ProgressCallback',
+ 'RichCallback',
+ 'TqdmCallback',
+ 'RawTextCallback',
+ "LRSchedCallback",
+ 'LoadBestModelCallback',
+ "EarlyStopCallback",
+ 'MoreEvaluateCallback',
+ "TorchWarmupCallback",
+ "TorchGradClipCallback",
+ "ResultsMonitor",
+ 'HasMonitorCallback',
+ "FitlogCallback",
+ "TimerCallback",
+
+ # collators
+ 'Collator',
+ 'NumpyNumberPadder',
+ 'NumpySequencePadder',
+ "NumpyTensorPadder",
"Padder",
- "AutoPadder",
- "EngChar2DPadder",
+ "NullPadder",
+ "RawNumberPadder",
+ "RawSequencePadder",
+ 'TorchNumberPadder',
+ 'TorchSequencePadder',
+ 'TorchTensorPadder',
+ "PaddleNumberPadder",
+ "PaddleTensorPadder",
+ "PaddleSequencePadder",
+ "get_padded_numpy_array",
- "ConcatCollateFn",
-
- "Vocabulary",
-
- "DataSetIter",
- "BatchIter",
- "TorchLoaderIter",
-
- "Const",
-
- "Tester",
- "Trainer",
+ # controllers
+ 'Loop',
+ 'EvaluateBatchLoop',
+ 'TrainBatchLoop',
+ 'Evaluator',
+ 'Trainer',
- "DistTrainer",
- "get_local_rank",
+ # dataloaders TODO 需要把 mix_dataloader 的搞定
+ 'TorchDataLoader',
+ 'PaddleDataLoader',
+ 'JittorDataLoader',
+ 'OneflowDataLoader',
+ 'prepare_jittor_dataloader',
+ 'prepare_paddle_dataloader',
+ 'prepare_torch_dataloader',
+ 'prepare_oneflow_dataloader',
+ "prepare_dataloader",
- "cache_results",
- "seq_len_to_mask",
- "get_seq_len",
- "logger",
- "init_logger_dist",
+ # dataset
+ 'DataSet',
+ 'FieldArray',
+ 'Instance',
- "Callback",
- "GradientClipCallback",
- "EarlyStopCallback",
- "FitlogCallback",
- "EvaluateCallback",
- "LRScheduler",
- "ControlC",
- "LRFinder",
- "TensorboardCallback",
- "WarmupCallback",
- 'SaveModelCallback',
- "CallbackException",
- "EarlyStopError",
- "CheckPointCallback",
+ # drivers
+ "TorchSingleDriver",
+ "TorchDDPDriver",
+ "DeepSpeedDriver",
+ "PaddleSingleDriver",
+ "PaddleFleetDriver",
+ "JittorSingleDriver",
+ "JittorMPIDriver",
+ "OneflowSingleDriver",
+ "OneflowDDPDriver",
+ "torch_seed_everything",
+ "paddle_seed_everything",
+ "oneflow_seed_everything",
+ "torch_move_data_to_device",
+ 'paddle_move_data_to_device',
+ 'oneflow_move_data_to_device',
- "LossFunc",
- "CrossEntropyLoss",
- "L1Loss",
- "BCELoss",
- "BCEWithLogits",
- "NLLLoss",
- "LossInForward",
- "CMRC2018Loss",
- "MSELoss",
- "LossBase",
+ # log
+ "logger",
+ "print",
- "MetricBase",
- "AccuracyMetric",
- "SpanFPreRecMetric",
- "CMRC2018Metric",
- "ClassifyFPreRecMetric",
- "ConfusionMatrixMetric",
+ # metrics
+ "Metric",
+ "Accuracy",
+ "TransformersAccuracy",
+ 'SpanFPreRecMetric',
+ 'ClassifyFPreRecMetric',
- "Optimizer",
- "SGD",
- "Adam",
- "AdamW",
-
+ # samplers
+ 'ReproducibleSampler',
+ 'RandomSampler',
"SequentialSampler",
- "BucketSampler",
- "RandomSampler",
- "Sampler",
"SortedSampler",
- "ConstantTokenNumSampler"
-]
+ 'UnrepeatedSampler',
+ 'UnrepeatedRandomSampler',
+ "UnrepeatedSortedSampler",
+ "UnrepeatedSequentialSampler",
+ "ReproduceBatchSampler",
+ "BucketedBatchSampler",
+ "ReproducibleBatchSampler",
+ "RandomBatchSampler",
-from ._logger import logger, init_logger_dist
-from .batch import DataSetIter, BatchIter, TorchLoaderIter
-from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
- LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \
- EarlyStopError, CheckPointCallback
-from .const import Const
-from .dataset import DataSet
-from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
-from .instance import Instance
-from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
- LossInForward, CMRC2018Loss, LossBase, MSELoss, BCEWithLogits
-from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\
- ConfusionMatrixMetric
-from .optimizer import Optimizer, SGD, Adam, AdamW
-from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler, ConstantTokenNumSampler
-from .tester import Tester
-from .trainer import Trainer
-from .utils import cache_results, seq_len_to_mask, get_seq_len
-from .vocabulary import Vocabulary
-from .collate_fn import ConcatCollateFn
-from .dist_trainer import DistTrainer, get_local_rank
+ # utils
+ "cache_results",
+ "f_rich_progress",
+ "auto_param_call",
+ "f_tqdm_progress",
+ "seq_len_to_mask",
+
+ # vocabulary.py
+ 'Vocabulary'
+]
+from .callbacks import *
+from .collators import *
+from .controllers import *
+from .dataloaders import *
+from .dataset import *
+from .drivers import *
+from .log import *
+from .metrics import *
+from .samplers import *
+from .utils import *
+from .vocabulary import Vocabulary
\ No newline at end of file
diff --git a/fastNLP/core/_logger.py b/fastNLP/core/_logger.py
deleted file mode 100644
index 9051f700..00000000
--- a/fastNLP/core/_logger.py
+++ /dev/null
@@ -1,179 +0,0 @@
-r"""
-Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger,
-具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API
-使用方式:
-from fastNLP import logger
-#
-# logger 可以和 logging.Logger 一样使用
-logger.info('your msg')
-logger.error('your msg')
-
-# logger 新增的API
-# 将日志输出到文件,以及输出的日志等级
-logger.add_file('/path/to/log', level='INFO')
-# 定义在命令行中的显示格式和日志等级
-logger.set_stdout('tqdm', level='WARN')
-
-"""
-
-__all__ = [
- 'logger',
- 'init_logger_dist'
-]
-
-import logging
-import logging.config
-import os
-import sys
-import warnings
-from torch import distributed as dist
-
-ROOT_NAME = 'fastNLP'
-
-try:
- import fitlog
-except ImportError:
- fitlog = None
-try:
- from tqdm.auto import tqdm
-except ImportError:
- tqdm = None
-
-if tqdm is not None:
- class TqdmLoggingHandler(logging.Handler):
- def __init__(self, level=logging.INFO):
- super().__init__(level)
-
- def emit(self, record):
- try:
- msg = self.format(record)
- tqdm.write(msg)
- self.flush()
- except (KeyboardInterrupt, SystemExit):
- raise
- except:
- self.handleError(record)
-else:
- class TqdmLoggingHandler(logging.StreamHandler):
- def __init__(self, level=logging.INFO):
- super().__init__(sys.stdout)
- self.setLevel(level)
-
-
-def _get_level(level):
- if isinstance(level, int):
- pass
- else:
- level = level.lower()
- level = {'info': logging.INFO, 'debug': logging.DEBUG,
- 'warn': logging.WARN, 'warning': logging.WARN,
- 'error': logging.ERROR}[level]
- return level
-
-
-def _add_file_handler(logger, path, level='INFO'):
- for h in logger.handlers:
- if isinstance(h, logging.FileHandler):
- if os.path.abspath(path) == h.baseFilename:
- # file path already added
- return
-
- # File Handler
- if os.path.exists(path):
- assert os.path.isfile(path)
- warnings.warn('log already exists in {}'.format(path))
- dirname = os.path.abspath(os.path.dirname(path))
- os.makedirs(dirname, exist_ok=True)
-
- file_handler = logging.FileHandler(path, mode='a')
- file_handler.setLevel(_get_level(level))
- file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
- datefmt='%Y/%m/%d %H:%M:%S')
- file_handler.setFormatter(file_formatter)
- logger.addHandler(file_handler)
-
-
-def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
- level = _get_level(level)
- if stdout not in ['none', 'plain', 'tqdm']:
- raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))
- # make sure to initialize logger only once
- stream_handler = None
- for i, h in enumerate(logger.handlers):
- if isinstance(h, (logging.StreamHandler, TqdmLoggingHandler)):
- stream_handler = h
- break
- if stream_handler is not None:
- logger.removeHandler(stream_handler)
-
- # Stream Handler
- if stdout == 'plain':
- stream_handler = logging.StreamHandler(sys.stdout)
- elif stdout == 'tqdm':
- stream_handler = TqdmLoggingHandler(level)
- else:
- stream_handler = None
-
- if stream_handler is not None:
- stream_formatter = logging.Formatter('%(message)s')
- stream_handler.setLevel(level)
- stream_handler.setFormatter(stream_formatter)
- logger.addHandler(stream_handler)
-
-
-class FastNLPLogger(logging.getLoggerClass()):
- def __init__(self, name):
- super().__init__(name)
-
- def add_file(self, path='./log.txt', level='INFO'):
- r"""add log output file and the output level"""
- _add_file_handler(self, path, level)
-
- def set_stdout(self, stdout='tqdm', level='INFO'):
- r"""set stdout format and the output level"""
- _set_stdout_handler(self, stdout, level)
-
-
-logging.setLoggerClass(FastNLPLogger)
-
-
-# print(logging.getLoggerClass())
-# print(logging.getLogger())
-
-def _init_logger(path=None, stdout='tqdm', level='INFO'):
- r"""initialize logger"""
- level = _get_level(level)
-
- # logger = logging.getLogger()
- logger = logging.getLogger(ROOT_NAME)
- logger.propagate = False
- logger.setLevel(1) # make the logger the lowest level
-
- _set_stdout_handler(logger, stdout, level)
-
- # File Handler
- if path is not None:
- _add_file_handler(logger, path, level)
-
- return logger
-
-
-def _get_logger(name=None, level='INFO'):
- level = _get_level(level)
- if name is None:
- name = ROOT_NAME
- assert isinstance(name, str)
- if not name.startswith(ROOT_NAME):
- name = '{}.{}'.format(ROOT_NAME, name)
- logger = logging.getLogger(name)
- logger.setLevel(level)
- return logger
-
-
-logger = _init_logger(path=None, level='INFO')
-
-
-def init_logger_dist():
- global logger
- rank = dist.get_rank()
- logger.setLevel(logging.INFO if rank == 0 else logging.WARNING)
diff --git a/fastNLP/core/_parallel_utils.py b/fastNLP/core/_parallel_utils.py
deleted file mode 100644
index bcfd3b59..00000000
--- a/fastNLP/core/_parallel_utils.py
+++ /dev/null
@@ -1,107 +0,0 @@
-r"""undocumented"""
-
-__all__ = []
-
-import threading
-
-import torch
-from torch import nn
-from torch.nn.parallel.parallel_apply import get_a_var
-from torch.nn.parallel.replicate import replicate
-from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
-
-
-def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
- r"""Applies each `module` in :attr:`modules` in parallel on arguments
- contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
- on each of :attr:`devices`.
-
- :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
- :attr:`devices` (if given) should all have same length. Moreover, each
- element of :attr:`inputs` can either be a single object as the only argument
- to a module, or a collection of positional arguments.
- """
- assert len(modules) == len(inputs)
- if kwargs_tup is not None:
- assert len(modules) == len(kwargs_tup)
- else:
- kwargs_tup = ({},) * len(modules)
- if devices is not None:
- assert len(modules) == len(devices)
- else:
- devices = [None] * len(modules)
-
- lock = threading.Lock()
- results = {}
- grad_enabled = torch.is_grad_enabled()
-
- def _worker(i, module, input, kwargs, device=None):
- torch.set_grad_enabled(grad_enabled)
- if device is None:
- device = get_a_var(input).get_device()
- try:
- with torch.cuda.device(device):
- # this also avoids accidental slicing of `input` if it is a Tensor
- if not isinstance(input, (list, tuple)):
- input = (input,)
- output = getattr(module, func_name)(*input, **kwargs)
- with lock:
- results[i] = output
- except Exception as e:
- with lock:
- results[i] = e
-
- if len(modules) > 1:
- threads = [threading.Thread(target=_worker,
- args=(i, module, input, kwargs, device))
- for i, (module, input, kwargs, device) in
- enumerate(zip(modules, inputs, kwargs_tup, devices))]
-
- for thread in threads:
- thread.start()
- for thread in threads:
- thread.join()
- else:
- _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
-
- outputs = []
- for i in range(len(inputs)):
- output = results[i]
- if isinstance(output, Exception):
- raise output
- outputs.append(output)
- return outputs
-
-
-def _data_parallel_wrapper(func_name, device_ids, output_device):
- r"""
- 这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数
-
- :param str, func_name: 对network中的这个函数进行多卡运行
- :param device_ids: nn.DataParallel中的device_ids
- :param output_device: nn.DataParallel中的output_device
- :return:
- """
-
- def wrapper(network, *inputs, **kwargs):
- inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
- if len(device_ids) == 1:
- return getattr(network, func_name)(*inputs[0], **kwargs[0])
- replicas = replicate(network, device_ids[:len(inputs)])
- outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
- return gather(outputs, output_device)
-
- return wrapper
-
-
-def _model_contains_inner_module(model):
- r"""
-
- :param nn.Module model: 模型文件,判断是否内部包含model.module, 多用于check模型是否是nn.DataParallel,
- nn.parallel.DistributedDataParallel。主要是在做形参匹配的时候需要使用最内部的model的function。
- :return: bool
- """
- if isinstance(model, nn.Module):
- if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
- return True
- return False
diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
deleted file mode 100644
index 94942f09..00000000
--- a/fastNLP/core/batch.py
+++ /dev/null
@@ -1,465 +0,0 @@
-r"""
-batch 模块实现了 fastNLP 所需的 :class:`~fastNLP.core.batch.DataSetIter` 类。
-
-"""
-__all__ = [
- "BatchIter",
- "DataSetIter",
- "TorchLoaderIter",
-]
-
-import atexit
-import abc
-
-from numbers import Number
-import numpy as np
-import torch
-import torch.utils.data
-from collections import defaultdict
-
-from .dataset import DataSet
-from .sampler import SequentialSampler, Sampler
-from ._logger import logger
-
-
-_python_is_exit = False
-
-
-def _set_python_is_exit():
- global _python_is_exit
- _python_is_exit = True
-
-
-atexit.register(_set_python_is_exit)
-
-
-def _pad(batch_dict, dataset, as_numpy):
- result = {}
- for n, vlist in batch_dict.items():
- f = dataset.field_arrays[n]
- if f.padder is None:
- result[n] = np.array(vlist)
- else:
- res = f.pad(vlist)
- if not as_numpy:
- res, _ = _to_tensor(res, field_dtype=f.dtype)
- result[n] = res
-
- return result
-
-
-class DataSetGetter:
- r"""
- 传递给torch.utils.data.DataLoader获取数据,DataLoder会传入int的idx获取数据(调用这里的__getitem__()函数)。
- """
- def __init__(self, dataset: DataSet, as_numpy=False):
- self.dataset = dataset
- self.as_numpy = as_numpy
- self.idx_list = list(range(len(dataset)))
-
- self.x_names = {n for n, f in dataset.get_all_fields().items() if f.is_input}
- self.y_names = {n for n, f in dataset.get_all_fields().items() if f.is_target}
-
- def __getitem__(self, idx: int):
- # mapping idx to sampled idx
- idx = self.idx_list[idx]
- ins = self.dataset[idx]
- return idx, ins
-
- def __len__(self):
- return len(self.dataset)
-
- def collate_fn(self, ins_list: list):
- r"""
-
- :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]]
- :return:
- """
- indices = []
- sin_x, sin_y = defaultdict(list), defaultdict(list)
- # 收集需要关注的field的数据
- for idx, ins in ins_list:
- indices.append(idx)
- for n, v in ins.items():
- if n in self.x_names:
- sin_x[n].append(v)
- if n in self.y_names:
- sin_y[n].append(v)
- # 根据情况,进行pad
- sin_x = _pad(sin_x, dataset=self.dataset, as_numpy=self.as_numpy)
- sin_y = _pad(sin_y, dataset=self.dataset, as_numpy=self.as_numpy)
-
- if not self.dataset.collater.is_empty():
- bx, by = self.dataset._collate_batch(ins_list)
- sin_x.update(bx)
- sin_y.update(by)
-
- return indices, sin_x, sin_y
-
- def __getattr__(self, item):
- if hasattr(self.dataset, item):
- return getattr(self.dataset, item)
- else:
- raise AttributeError("'DataSetGetter' object has no attribute '{}'".format(item))
-
-
-class SamplerAdapter(torch.utils.data.Sampler):
- r"""
- 用于传入torch.utils.data.DataLoader中,DataLoader会调用__iter__()方法获取index(一次只取一个int)
-
- """
- def __init__(self, sampler, dataset):
- super().__init__(dataset)
- self.sampler = sampler
- self.dataset = dataset
-
- def __len__(self):
- return len(self.dataset)
-
- def __iter__(self):
- return iter(self.sampler(self.dataset))
-
-
-class BatchIter:
- r"""
- Trainer用于迭代数据的类。继承该类,并实现get_num_batches(), get_batch_indices(), num_batches(), __iter__()方法以及dataset属性。
-
- """
- def __init__(self, dataset, batch_size=1, sampler=None,
- num_workers=0, pin_memory=False, drop_last=False,
- timeout=0, worker_init_fn=None, collate_fn=None,
- batch_sampler=None):
- if isinstance(sampler, Sampler): # 如果时fastNLP的sampler需要adapt一下
- sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
- self.sampler = sampler
- self.batch_sampler = batch_sampler
-
- # DataLoader的collate_fn输入是List[],里面的元素是dataset[index]返回的结果
- if collate_fn is None:
- # pytoch <= 1.1 中不能设置collate_fn=None
- self.dataiter = torch.utils.data.DataLoader(
- dataset=dataset, batch_size=batch_size, sampler=self.sampler,
- num_workers=num_workers,
- pin_memory=pin_memory, drop_last=drop_last,
- timeout=timeout, worker_init_fn=worker_init_fn,
- batch_sampler=batch_sampler)
- else:
- self.dataiter = torch.utils.data.DataLoader(
- dataset=dataset, batch_size=batch_size, sampler=self.sampler,
- collate_fn=collate_fn, num_workers=num_workers,
- pin_memory=pin_memory, drop_last=drop_last,
- timeout=timeout, worker_init_fn=worker_init_fn,
- batch_sampler=batch_sampler)
-
- # 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了
- if self.batch_sampler is None:
- self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
- else:
- self._num_batches = len(self.batch_sampler)
- self.batch_size = batch_size
- self.cur_batch_indices = None
-
- @property
- def num_batches(self):
- return self._num_batches
-
- @num_batches.setter
- def num_batches(self, value):
- self._num_batches = value
-
- def init_iter(self):
- pass
-
- @staticmethod
- def get_num_batches(num_samples, batch_size, drop_last):
- r"""
- 计算batch的数量。用于前端显示进度
-
- :param int num_samples:
- :param int batch_size:
- :param bool drop_last: 如果最后一个batch没有batch_size这么多,是否就丢掉。
- :return:
- """
- num_batches = num_samples // batch_size
- if not drop_last and (num_samples % batch_size > 0):
- num_batches += 1
- return num_batches
-
- def get_batch_indices(self):
- r"""
- 获取最近输出的batch的index。用于溯源当前batch的数据
-
- :return:
- """
- return self.cur_batch_indices
-
- def __len__(self):
- return self.num_batches
-
- @property
- def dataset(self):
- r"""
- 获取正在参与iterate的dataset
-
- :return:
- """
- return self.dataiter.dataset
-
- @abc.abstractmethod
- def __iter__(self):
- r"""
- 用于实际数据循环的类,返回值需要为两个dict, 第一个dict中的内容会认为是input, 第二个dict中的内容会认为是target
-
- :return:
- """
- raise NotImplemented
-
-
-class DataSetIter(BatchIter):
- r"""
- DataSetIter 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出,通过使用DataSetIter,可以不需要考虑
- 输入的padding(由DataSet中每列的Padder决定了)以及不需要考虑将数据转为tensor。
- 组成 `x` 和 `y`::
-
- batch = DataSetIter(data_set, batch_size=16, sampler=SequentialSampler())
- num_batch = len(batch)
- for batch_x, batch_y in batch:
- # do stuff ...
-
- """
- def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False,
- drop_last=False, timeout=0, worker_init_fn=None, batch_sampler=None):
- r"""
-
- :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集
- :param int batch_size: 取出的batch大小
- :param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`.
-
- Default: ``None``
- :param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`.
-
- Default: ``False``
- :param int num_workers: 使用多少个进程来预处理数据
- :param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。
- :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个
- :param timeout: 生成一个batch的timeout值
- :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。
- :param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。
- 当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。
- """
- assert isinstance(dataset, DataSet)
- dataset = DataSetGetter(dataset, as_numpy)
- collate_fn = dataset.collate_fn
- if batch_sampler is not None:
- batch_size = 1
- sampler = None
- drop_last = False
- super().__init__(
- dataset=dataset, batch_size=batch_size, sampler=sampler,
- num_workers=num_workers, pin_memory=pin_memory,
- drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
- collate_fn=collate_fn, batch_sampler=batch_sampler
- )
-
- def __iter__(self):
- self.init_iter()
- for indices, batch_x, batch_y in self.dataiter:
- self.cur_batch_indices = indices
- yield batch_x, batch_y
-
-
-class TorchLoaderIter(BatchIter):
- r"""
- 与DataSetIter类似,但可以用于非fastNLP的数据容器对象,以及可以实现完全自定义的生成batch的方式,然后与Trainer,Tester可以实现
- 与DataSetIter一样的对接。
- 需要保证传入的数据容器实现了实现了以下的方法
-
- Example::
-
- import random
- from fastNLP import TorchLoaderIter
- import torch
- class UdfDataSet:
- def __init__(self, num_samples):
- self.num_samples = num_samples
-
- def __getitem__(self, idx): # 必须实现的方法,输入参数是一个int,范围为[0, len(self))
- x = [random.random() for _ in range(3)]
- y = random.random()
- return x,y
-
- def __len__(self): # 需要实现该方法返回值需要是一个int数据
- return self.num_samples
-
- # 需要实现collact_fn将数据转换为tensor
- def collate_fn(data_list):
- # [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list
- xs, ys = [], []
- for l in data_list:
- x, y = l
- xs.append(x)
- ys.append(y)
- # 不需要转移到gpu,Trainer或Tester会将其转移到model所在的device
- x,y = torch.FloatTensor(xs), torch.FloatTensor(ys)
- return {'x':x, 'y':y}, {'y':y} # 第一个dict中内容类似于DataSet中的input列,第二个dict的内容类似于target列
-
- udf_dataset = UdfDataSet(10)
- dataset = TorchLoaderIter(udf_dataset, collate_fn=collate_fn)
- class Model(nn.Module):
- def __init__(self):
- super().__init__()
- self.fc = nn.Linear(3, 1)
- def forward(self, x, y):
- return {'loss':torch.pow(self.fc(x).squeeze(-1)-y, 2).sum()}
- def predict(self, x):
- return {'pred':self.fc(x).squeeze(0)}
- model = Model()
- trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
- metrics=AccuracyMetric(target='y'), use_tqdm=False)
- trainer.train(load_best_model=False)
-
- 除此之外,还可以通过该方法实现OnTheFly的训练,如下面的代码所示
-
- Example::
-
- import tempfile
- import random
- import torch
- tmp_file_handler, tmp_file_path = tempfile.mkstemp(text=True)
- try:
- num_samples, data = 10, []
- for _ in range(num_samples):
- x, y = [random.random() for _ in range(3)], random.random()
- data.append(x + [y])
- with open(tmp_file_path, 'w') as f:
- for d in data:
- f.write(' '.join(map(str, d)) + '\n')
-
- class FileDataSet:
- def __init__(self, tmp_file):
- num_samples = 0
- line_pos = [0] # 对应idx是某一行对应的位置
- self.tmp_file_handler = open(tmp_file, 'r', encoding='utf-8')
- line = self.tmp_file_handler.readline()
- while line:
- if line.strip():
- num_samples += 1
- line_pos.append(self.tmp_file_handler.tell())
- line = self.tmp_file_handler.readline()
- self.tmp_file_handler.seek(0)
- self.num_samples = num_samples
- self.line_pos = line_pos
-
- def __getitem__(self, idx):
- line_start, line_end = self.line_pos[idx], self.line_pos[idx + 1]
- self.tmp_file_handler.seek(line_start)
- line = self.tmp_file_handler.read(line_end - line_start).strip()
- values = list(map(float, line.split()))
- x, y = values[:3], values[-1]
- return x, y
-
- def __len__(self):
- return self.num_samples
-
- def collate_fn(data_list):
- # [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list
- xs, ys = [], []
- for l in data_list:
- x, y = l
- xs.append(x)
- ys.append(y)
- x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
- return {'x': x, 'y': y}, {'y': y} # 第一个dict中内容类似于DataSet中的input列,第二个dict的内容类似于target列
-
- file_data = FileDataSet(tmp_file_path)
- dataset = TorchLoaderIter(file_data, collate_fn=collate_fn)
-
- class Model(nn.Module):
- def __init__(self):
- super().__init__()
- self.fc = nn.Linear(3, 1)
-
- def forward(self, x, y):
- return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()}
-
- def predict(self, x):
- return {'pred': self.fc(x).squeeze(0)}
-
- model = Model()
- trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
- metrics=AccuracyMetric(target='y'), use_tqdm=False, n_epochs=2)
- trainer.train(load_best_model=False)
-
- finally:
- import os
- if os.path.exists(tmp_file_path):
- os.remove(tmp_file_path)
-
- """
- def __init__(self, dataset, collate_fn, batch_size=1, sampler=None,
- num_workers=0, pin_memory=False, drop_last=False,
- timeout=0, worker_init_fn=None,
- batch_sampler=None):
- r"""
-
- :param dataset: 实现了__getitem__和__len__方法的数据容器。
- :param callable collate_fn: 用于将样本组合成batch的函数。输入为[dataset[idx1], dataset[idx2], ...], 即dataset中
- __getitem__返回值组成的list,返回值必须为两个dict,其中第一个dict会被认为是input,第二个dict中的内容被认为是target。
- 需要转换为tensor的数据,需要在collate_fn中转化,但不需要转移到对应device。
- :param int batch_size: 取出的batch大小
- :param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`.
- Default: ``None``
- :param int num_workers: 使用多少个进程来预处理数据
- :param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。
- :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个
- :param timeout: 生成一个batch的timeout值
- :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。
- :param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。
- 当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。
- """
- assert len(dataset) > 0
- assert collate_fn is not None, "You must pass collate_fn to pad the batch."
- if batch_sampler is not None:
- batch_size = 1
- sampler = None
- drop_last = False
-
- super().__init__(
- dataset=dataset, batch_size=batch_size, sampler=sampler,
- num_workers=num_workers, pin_memory=pin_memory,
- drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
- collate_fn=collate_fn, batch_sampler=batch_sampler
- )
-
- def __iter__(self):
- self.init_iter()
- for batch_x, batch_y in self.dataiter:
- self.cur_batch_indices = None
- yield batch_x, batch_y
-
-
-def _to_tensor(batch, field_dtype):
- r"""
-
- :param batch: np.array()
- :param field_dtype: 数据类型
- :return: batch, flag. 如果传入的数据支持转为tensor,返回的batch就是tensor,且flag为True;如果传入的数据不支持转为tensor,
- 返回的batch就是原来的数据,且flag为False
- """
- try:
- if field_dtype is not None and isinstance(field_dtype, type)\
- and issubclass(field_dtype, Number) \
- and not isinstance(batch, torch.Tensor):
- new_batch = torch.as_tensor(batch)
- flag = True
- else:
- new_batch = batch
- flag = False
- if torch.is_tensor(new_batch):
- if 'float' in new_batch.dtype.__repr__():
- new_batch = new_batch.float()
- elif 'int' in new_batch.dtype.__repr__():
- new_batch = new_batch.long()
- return new_batch, flag
- except Exception as e:
- raise e
diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
deleted file mode 100644
index 808ddbf5..00000000
--- a/fastNLP/core/callback.py
+++ /dev/null
@@ -1,1235 +0,0 @@
-r"""
-callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class:`~fastNLP.Trainer` 类。
-
-虽然Trainer本身已经集成了一些功能,但仍然不足以囊括训练过程中可能需要到的功能,
-比如负采样,learning rate decay 和 early stop等。
-为了解决这个问题,fastNLP引入了callback的机制,:class:`~fastNLP.Callback` 是一种在Trainer训练过程中特定阶段会运行的函数集合。
-关于 :class:`~fastNLP.Trainer` 的详细文档,请参见 :mod:`trainer 模块`
-
-我们将 :meth:`~fastNLP.Trainer.train` 这个函数内部分为以下的阶段,在对应阶段会触发相应的调用::
-
- callback.on_train_begin() # 开始进行训练
- for i in range(1, n_epochs+1):
- callback.on_epoch_begin() # 开始新的epoch
- for batch_x, batch_y in Batch:
- callback.on_batch_begin(batch_x, batch_y, indices) # batch_x是设置为input的field,batch_y是设置为target的field
- 获取模型输出
- callback.on_loss_begin()
- 计算loss
- callback.on_backward_begin() # 可以进行一些检查,比如loss是否为None
- 反向梯度回传
- callback.on_backward_end() # 进行梯度截断等
- 进行参数更新
- callback.on_step_end()
- callback.on_batch_end()
- # 根据设置进行evaluation,比如这是本epoch最后一个batch或者达到一定step
- if do evaluation:
- callback.on_valid_begin()
- 进行dev data上的验证
- callback.on_valid_end() # 可以进行在其它数据集上进行验证
- callback.on_epoch_end() # epoch结束调用
- callback.on_train_end() # 训练结束
- callback.on_exception() # 这是一个特殊的步骤,在训练过程中遭遇exception会跳转到这里。
-
-如下面的例子所示,我们可以使用内置的 callback 组件,或者继承 :class:`~fastNLP.core.callback.Callback`
-定义自己的 callback 组件::
-
- from fastNLP import Callback, EarlyStopCallback, Trainer, CrossEntropyLoss, AccuracyMetric
- from fastNLP.models import CNNText
-
- start_time = time.time()
-
- class MyCallback(Callback):
- def on_epoch_end(self):
- print('{:d}ms\n\n'.format(round((time.time()-start_time)*1000)))
-
- model = CNNText((len(vocab),50), num_classes=5, padding=2, dropout=0.1)
- trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, loss=CrossEntropyLoss(),
- metrics=AccuracyMetric(), callbacks=[MyCallback(),EarlyStopCallback(10)])
- trainer.train()
-
-"""
-__all__ = [
- "Callback",
-
- "GradientClipCallback",
- "EarlyStopCallback",
- "FitlogCallback",
- "EvaluateCallback",
- "LRScheduler",
- "ControlC",
- "LRFinder",
- "TensorboardCallback",
- "WarmupCallback",
- "SaveModelCallback",
-
- "CallbackException",
- "EarlyStopError",
- "CheckPointCallback"
-]
-
-import os
-import sys
-from copy import deepcopy
-
-import torch
-
-from .utils import _save_model
-
-try:
- from tensorboardX import SummaryWriter
-
- tensorboardX_flag = True
-except:
- tensorboardX_flag = False
-
-from .dataset import DataSet
-from .tester import Tester
-from ._logger import logger
-from ._parallel_utils import _model_contains_inner_module
-
-try:
- import fitlog
-except:
- pass
-
-
-class Callback(object):
- r"""
- Callback是fastNLP中被设计用于增强 :class:`~fastNLP.Trainer` 的类。
- 如果Callback被传递给了 Trainer , 则 Trainer 会在对应的阶段调用Callback的函数,
- 具体调用时机可以通过 :mod:`trainer 模块` 查看。
- 这是Callback的基类,所有的callback必须继承自这个类
-
- """
-
- def __init__(self):
- super(Callback, self).__init__()
- self._trainer = None # 在Trainer内部被重新赋值
- self._disabled = False
-
- def __repr__(self):
- return self.__class__.__name__
-
- @property
- def trainer(self):
- r"""
- 该属性可以通过self.trainer获取到,一般情况下不需要使用这个属性。
- """
- return self._trainer
-
- @property
- def grad_scaler(self):
- r"""
- float16的gradient scaler
- """
- return self._trainer.grad_scaler
-
- @property
- def auto_cast(self):
- r"""
- float16用的auto cast环境
- """
- return self._trainer.auto_cast
-
- @property
- def step(self):
- r"""当前运行到的step, 范围为[1, self.n_steps+1)"""
- return self._trainer.step
-
- @property
- def n_steps(self):
- r"""Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数"""
- return self._trainer.n_steps
-
- @property
- def batch_size(self):
- r"""train和evaluate时的batch_size为多大"""
- return self._trainer.batch_size
-
- @property
- def epoch(self):
- r"""当前运行的epoch数,范围是[1, self.n_epochs+1)"""
- return self._trainer.epoch
-
- @property
- def n_epochs(self):
- r"""一共会运行多少个epoch"""
- return self._trainer.n_epochs
-
- @property
- def optimizer(self):
- r"""初始化Trainer时传递的Optimizer"""
- return self._trainer.optimizer
-
- @property
- def model(self):
- r"""正在被Trainer训练的模型"""
- return self._trainer.model
-
- @property
- def pbar(self):
- r"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。在
- on_train_begin(), on_train_end(), on_exception()中请不要使用该属性,通过print输出即可。"""
- return self._trainer.pbar
-
- @property
- def update_every(self):
- r"""Trainer中的模型多少次反向传播才进行一次梯度更新,在Trainer初始化时传入的。"""
- return self._trainer.update_every
-
- @property
- def batch_per_epoch(self):
- r"""每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。"""
- return self._trainer.batch_per_epoch
-
- @property
- def is_master(self):
- return self._trainer.is_master
-
- @property
- def disabled(self):
- return self._disabled
-
- @property
- def logger(self):
- return getattr(self._trainer, 'logger', logger)
-
- def on_train_begin(self):
- r"""
- 在Train过程开始之前调用。
-
- :return:
- """
- pass
-
- def on_epoch_begin(self):
- r"""
- 在每个epoch开始之前调用一次
-
- :return:
- """
- pass
-
- def on_batch_begin(self, batch_x, batch_y, indices):
- r"""
- 每次采集到一个batch的数据则调用一次。这里对batch_x或batch_y删除添加内容是可以影响到Trainer中内容的。所以在这一步
- 可以进行一些负采样之类的操作。batch_x和batch_y中的tensor已经被放置到了模型所在的设备上。
-
- :param dict batch_x: DataSet中被设置为input的field的batch。
- :param dict batch_y: DataSet中被设置为target的field的batch。
- :param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些
- 情况下可以帮助定位是哪个Sample导致了错误。仅当num_workers=0时有效。
- :return:
- """
- pass
-
- def on_loss_begin(self, batch_y, predict_y):
- r"""
- 在计算loss前调用,即这里修改batch_y或predict_y的值是可以影响到loss计算的。
-
- :param dict batch_y: 在DataSet中被设置为target的field的batch集合。
- :param dict predict_y: 模型的forward()返回的结果。
- :return:
- """
- pass
-
- def on_backward_begin(self, loss):
- r"""
- 在loss得到之后,但在反向传播之前。可能可以进行loss是否为NaN的检查。
-
- :param torch.Tensor loss: 计算得到的loss值
- :return:
- """
- pass
-
- def on_backward_end(self):
- r"""
- 反向梯度传播已完成,但由于update_every的设置,可能并不是每一次调用都有梯度。到这一步,还没有更新参数。
-
- :return:
- """
- pass
-
- def on_step_end(self):
- r"""
- 到这里模型的参数已经按照梯度更新。但可能受update_every影响,并不是每次都更新了。
-
- :return:
- """
- pass
-
- def on_batch_end(self):
- r"""
- 这一步与on_step_end是紧接着的。只是为了对称性加上了这一步。
-
- """
- pass
-
- def on_valid_begin(self):
- r"""
- 如果Trainer中设置了验证,则发生验证前会调用该函数
-
- :return:
- """
- pass
-
- def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
- r"""
- 每次执行验证集的evaluation后会调用。
-
- :param Dict[str: Dict[str: float]] eval_result: , evaluation的结果。一个例子为{'AccuracyMetric':{'acc':1.0}},即
- 传入的dict是有两层,第一层是metric的名称,第二层是metric的具体指标。
- :param str metric_key: 初始化Trainer时传入的metric_key。
- :param torch.Optimizer optimizer: Trainer中使用的优化器。
- :param bool is_better_eval: 当前dev结果是否比之前的好。
- :return:
- """
- pass
-
- def on_epoch_end(self):
- r"""
- 每个epoch结束将会调用该方法
- """
- pass
-
- def on_train_end(self):
- r"""
- 训练结束,调用该方法
- """
- pass
-
- def on_exception(self, exception):
- r"""
- 当训练过程出现异常,会触发该方法
- :param exception: 某种类型的Exception,比如KeyboardInterrupt等
- """
- pass
-
-
-def _transfer(func):
- r"""装饰器,将对CallbackManager的调用转发到各个Callback子类.
-
- :param func:
- :return:
- """
-
- def wrapper(manager, *arg):
- returns = []
- for callback in manager.callbacks:
- if callback.disabled:
- continue
- returns.append(getattr(callback, func.__name__)(*arg))
- return returns
-
- return wrapper
-
-
-class CallbackManager(Callback):
- r"""
- 内部使用的Callback管理类
- """
- def __init__(self, env, callbacks=None):
- r"""
-
- :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself.
- :param List[Callback] callbacks:
- """
- super(CallbackManager, self).__init__()
- # set attribute of trainer environment
- self._env = env
- self.callbacks = []
- if callbacks:
- self.callbacks = self.prepare_callbacks(callbacks)
-
- def prepare_callbacks(self, callbacks):
- if not callbacks:
- return []
- if isinstance(callbacks, list):
- if all([isinstance(cb, Callback) for cb in callbacks]) is True:
- pass
- else:
- obj = [not isinstance(cb, Callback) for cb in callbacks][0]
- raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
- else:
- raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
-
- for env_name, env_val in self._env.items():
- for callback in callbacks:
- setattr(callback, '_' + env_name, env_val) # Callback.trainer
- return callbacks
-
- @_transfer
- def on_train_begin(self):
- pass
-
- @_transfer
- def on_epoch_begin(self):
- pass
-
- @_transfer
- def on_batch_begin(self, batch_x, batch_y, indices):
- pass
-
- @_transfer
- def on_loss_begin(self, batch_y, predict_y):
- pass
-
- @_transfer
- def on_backward_begin(self, loss):
- pass
-
- @_transfer
- def on_backward_end(self):
- pass
-
- @_transfer
- def on_step_end(self):
- pass
-
- @_transfer
- def on_batch_end(self):
- pass
-
- @_transfer
- def on_valid_begin(self):
- pass
-
- @_transfer
- def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
- pass
-
- @_transfer
- def on_validation(self):
- pass
-
- @_transfer
- def on_epoch_end(self):
- pass
-
- @_transfer
- def on_train_end(self):
- pass
-
- @_transfer
- def on_exception(self, exception):
- pass
-
-
-class DistCallbackManager(CallbackManager):
- def __init__(self, env, callbacks_all=None, callbacks_master=None):
- super(DistCallbackManager, self).__init__(env)
- assert 'trainer' in env
- self._trainer = env['trainer']
- self.callbacks_master = []
- self.callbacks_all = []
- self.add_callback(callbacks_all, master=False)
- self.add_callback(callbacks_master, master=True)
-
- def patch_callback(self, callbacks, disabled):
- if not callbacks:
- return
- if not isinstance(callbacks, (list, tuple)):
- callbacks = [callbacks]
- for cb in callbacks:
- cb._disabled = disabled
-
- def add_callback(self, cb, master=False):
- if master:
- self.patch_callback(cb, not self.is_master)
- self.callbacks_master += self.prepare_callbacks(cb)
- else:
- self.callbacks_all += self.prepare_callbacks(cb)
- self.callbacks = self.callbacks_all + self.callbacks_master
-
-
-class GradientClipCallback(Callback):
- r"""
- 每次backward前,将parameter的gradient clip到某个范围。
- """
-
- def __init__(self, parameters=None, clip_value=1, clip_type='norm'):
- r"""
-
- :param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。
- 如果为None则默认对Trainer的model中所有参数进行clip
- :param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数
- :param str clip_type: 支持'norm', 'value'
- 两种::
-
- 1 'norm', 将gradient的norm rescale到[-clip_value, clip_value]
-
- 2 'value', 将gradient限制在[-clip_value, clip_value],
- 小于-clip_value的gradient被赋值为-clip_value;
- 大于clip_value的gradient被赋值为clip_value.
- """
- super().__init__()
-
- from torch import nn
- if clip_type == 'norm':
- self.clip_fun = nn.utils.clip_grad_norm_
- elif clip_type == 'value':
- self.clip_fun = nn.utils.clip_grad_value_
- else:
- raise ValueError("Only supports `norm` or `value` right now.")
- if parameters is not None:
- self.parameters = list(parameters)
- else:
- self.parameters = None
- self.clip_value = clip_value
-
- def on_backward_end(self):
- if self.step%self.update_every==0:
- if self.trainer.fp16:
- self.grad_scaler.unscale_(self.optimizer)
- if self.parameters is not None:
- self.clip_fun(self.parameters, self.clip_value)
- else:
- self.clip_fun(self.model.parameters(), self.clip_value)
-
-
-class EarlyStopCallback(Callback):
- r"""
- 多少个epoch没有变好就停止训练,相关类 :class:`~fastNLP.core.callback.EarlyStopError`
- """
-
- def __init__(self, patience):
- r"""
-
- :param int patience: epoch的数量
- """
- super(EarlyStopCallback, self).__init__()
- self.patience = patience
- self.wait = 0
-
- def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
- if not is_better_eval:
- # current result is getting worse
- if self.wait == self.patience:
- raise EarlyStopError("Early stopping raised.")
- else:
- self.wait += 1
- else:
- self.wait = 0
-
- def on_exception(self, exception):
- if isinstance(exception, EarlyStopError):
- logger.info("Early Stopping triggered in epoch {}!".format(self.epoch))
- else:
- raise exception # 抛出陌生Error
-
-
-class FitlogCallback(Callback):
- r"""
- 该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入
- 一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。
- 并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则
- fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。
- """
-
- def __init__(self, data=None, tester=None, log_loss_every=0, verbose=1, log_exception=False):
- r"""
-
- :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要
- 传入多个DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。data的结果的名称以'data'开头。
- :param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: Tester对象,将在on_valid_end时调用。tester的结果的名称以'tester'开头
- :param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得
- 大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。
- :param int verbose: 是否在终端打印evaluation的结果,0不打印。
- :param bool log_exception: fitlog是否记录发生的exception信息
- """
- super().__init__()
- self.datasets = {}
- self.testers = {}
- self._log_exception = log_exception
- assert isinstance(log_loss_every, int) and log_loss_every>=0
- if tester is not None:
- if isinstance(tester, dict):
- for name, test in tester.items():
- if not isinstance(test, Tester):
- raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.")
- self.testers['tester-' + name] = test
- if isinstance(tester, Tester):
- self.testers['tester-test'] = tester
- for tester in self.testers.values():
- setattr(tester, 'verbose', 0)
-
- if isinstance(data, dict):
- for key, value in data.items():
- assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}."
- for key, value in data.items():
- self.datasets['data-' + key] = value
- elif isinstance(data, DataSet):
- self.datasets['data-test'] = data
- elif data is not None:
- raise TypeError("data receives dict[DataSet] or DataSet object.")
-
- self.verbose = verbose
- self._log_loss_every = log_loss_every
- self._avg_loss = 0
-
- def on_train_begin(self):
- if (len(self.datasets) > 0 or len(self.testers) > 0) and self.trainer.dev_data is None:
- raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.")
-
- if len(self.datasets) > 0:
- for key, data in self.datasets.items():
- tester = Tester(data=data, model=self.model,
- batch_size=self.trainer.kwargs.get('dev_batch_size', self.trainer.batch_size),
- metrics=self.trainer.metrics,
- verbose=0,
- use_tqdm=self.trainer.kwargs.get('test_use_tqdm', self.trainer.use_tqdm),
- sampler=self.trainer.kwargs.get('test_sampler', None))
- self.testers[key] = tester
- fitlog.add_progress(total_steps=self.n_steps)
-
- def on_backward_begin(self, loss):
- if self._log_loss_every>0:
- self._avg_loss += loss.item()
- if self.step%self._log_loss_every==0:
- fitlog.add_loss(self._avg_loss/self._log_loss_every*self.update_every, name='loss', step=self.step, epoch=self.epoch)
- self._avg_loss = 0
-
- def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
- if better_result:
- eval_result = deepcopy(eval_result)
- eval_result['step'] = self.step
- eval_result['epoch'] = self.epoch
- fitlog.add_best_metric(eval_result)
- fitlog.add_metric(eval_result, step=self.step, epoch=self.epoch)
- if len(self.testers) > 0:
- for key, tester in self.testers.items():
- try:
- eval_result = tester.test()
- if self.verbose != 0:
- self.pbar.write("FitlogCallback evaluation on {}:".format(key))
- self.pbar.write(tester._format_eval_results(eval_result))
- fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch)
- if better_result:
- fitlog.add_best_metric(eval_result, name=key)
- except Exception as e:
- self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))
- raise e
-
- def on_train_end(self):
- fitlog.finish()
-
- def on_exception(self, exception):
- fitlog.finish(status=1)
- if self._log_exception:
- fitlog.add_other(repr(exception), name='except_info')
-
-
-class EvaluateCallback(Callback):
- r"""
- 通过使用该Callback可以使得Trainer在evaluate dev之外还可以evaluate其它数据集,比如测试集。每一次验证dev之前都会先验证EvaluateCallback
- 中的数据。
- """
-
- def __init__(self, data=None, tester=None):
- r"""
- :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用Trainer中的metric对数据进行验证。如果需要传入多个
- DataSet请通过dict的方式传入。
- :param ~fastNLP.Tester,Dict[~fastNLP.DataSet] tester: Tester对象, 通过使用Tester对象,可以使得验证的metric与Trainer中
- 的metric不一样。
- """
- super().__init__()
- self.datasets = {}
- self.testers = {}
- if tester is not None:
- if isinstance(tester, dict):
- for name, test in tester.items():
- if not isinstance(test, Tester):
- raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.")
- self.testers['tester-' + name] = test
- if isinstance(tester, Tester):
- self.testers['tester-test'] = tester
- for tester in self.testers.values():
- setattr(tester, 'verbose', 0)
-
- if isinstance(data, dict):
- for key, value in data.items():
- assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}."
- for key, value in data.items():
- self.datasets['data-' + key] = value
- elif isinstance(data, DataSet):
- self.datasets['data-test'] = data
- elif data is not None:
- raise TypeError("data receives dict[DataSet] or DataSet object.")
-
- def on_train_begin(self):
- if len(self.datasets) > 0 and self.trainer.dev_data is None:
- raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.")
-
- if len(self.datasets) > 0:
- for key, data in self.datasets.items():
- tester = Tester(data=data, model=self.model,
- batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
- metrics=self.trainer.metrics, verbose=0,
- use_tqdm=self.trainer.test_use_tqdm)
- self.testers[key] = tester
-
- def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
- if len(self.testers) > 0:
- for key, tester in self.testers.items():
- try:
- eval_result = tester.test()
- self.logger.info("EvaluateCallback evaluation on {}:".format(key))
- self.logger.info(tester._format_eval_results(eval_result))
- except Exception as e:
- self.logger.error("Exception happens when evaluate on DataSet named `{}`.".format(key))
- raise e
-
-class LRScheduler(Callback):
- r"""
- 对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用
- """
-
- def __init__(self, lr_scheduler):
- r"""
- :param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler
- """
- super(LRScheduler, self).__init__()
- import torch.optim
- if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
- self.scheduler = lr_scheduler
- else:
- raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")
-
- def on_epoch_end(self):
- self.scheduler.step(self.epoch)
-
-
-class ControlC(Callback):
- r"""
- 检测到 control+C 时的反馈
- """
-
- @staticmethod
- def quit_all():
- import sys
- sys.exit(0) # 直接退出程序
-
- def __init__(self, quit_and_do, action=quit_all):
- r"""
- :param bool quit_and_do: 若为True,则检测到control+C 进行后续操作(默认值为:直接退出程序);否则只退出Trainer。
- """
-
- super(ControlC, self).__init__()
- if type(quit_and_do) != bool:
- raise ValueError("In KeyBoardInterrupt, quit_and_do arguemnt must be a bool.")
- self.quit_and_do = quit_and_do
- self.action = action
-
- def on_exception(self, exception):
- if isinstance(exception, KeyboardInterrupt):
- if self.quit_and_do is True:
- self.action()
- else:
- pass
- else:
- raise exception # 抛出陌生Error
-
-
-class SmoothValue(object):
- r"""work for LRFinder"""
-
- def __init__(self, beta: float):
- self.beta, self.n, self.mov_avg = beta, 0, 0
- self.smooth = None
-
- def add_value(self, val: float) -> None:
- r"""Add `val` to calculate updated smoothed value."""
- self.n += 1
- self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
- self.smooth = self.mov_avg / (1 - self.beta ** self.n)
-
-
-class LRFinder(Callback):
- r"""
- 用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它
- """
-
- def __init__(self, start_lr=1e-6, end_lr=10):
- r"""
-
- :param float start_lr: 学习率下界
- :param float end_lr: 学习率上界
- """
- super(LRFinder, self).__init__()
- self.start_lr, self.end_lr = start_lr, end_lr
-
- self.stop = False
- self.best_loss = 0.
- self.best_lr = None
- self.loss_history = []
- self.smooth_value = SmoothValue(0.8)
- self.opt = None
- self.find = None
-
- @property
- def lr_gen(self):
- scale = (self.end_lr - self.start_lr) / self.batch_per_epoch
- return (self.start_lr + scale * (step + 1) for step in range(self.batch_per_epoch))
-
- @property
- def num_it(self):
- return self.batch_per_epoch
-
- def on_epoch_begin(self):
- if self.epoch == 1: # first epoch
- self.opt = self.trainer.optimizer # pytorch optimizer
- self.opt.param_groups[0]["lr"] = self.start_lr
- # save model
- torch.save(self.model.state_dict(), 'tmp')
- self.find = True
-
- def on_backward_begin(self, loss):
- if self.find:
- if torch.isnan(loss) or self.stop is True:
- self.stop = True
- return
- loss_val = loss.detach().mean().item()
- self.loss_history.append(loss_val)
- self.smooth_value.add_value(loss_val)
- if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss:
- self.best_loss = self.smooth_value.smooth
- self.best_lr = self.opt.param_groups[0]["lr"]
-
- def on_batch_end(self, *args):
- if self.find:
- lr = next(self.lr_gen, None)
- if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss:
- self.stop = True
- return
- self.opt.param_groups[0]["lr"] = lr
- # self.loader.load_pytorch(self.trainer.model, "tmp")
-
- def on_epoch_end(self):
- if self.epoch == 1: # first epoch
- self.opt.param_groups[0]["lr"] = self.best_lr
- self.find = False
- # reset model
- states = torch.load('tmp')
- self.model.load_state_dict(states)
- os.remove('tmp')
- self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr))
-
-
-class TensorboardCallback(Callback):
- r"""
- 接受以下一个或多个字符串作为参数:
- - "model"
- - "loss"
- - "metric"
-
- .. warning::
- fastNLP 已停止对此功能的维护,请等待 fastNLP 兼容 PyTorch1.1 的下一个版本。
- 或者使用和 fastNLP 高度配合的 fitlog(参见 :doc:`/tutorials/extend_3_fitlog` )。
-
- """
-
- def __init__(self, *options):
- super(TensorboardCallback, self).__init__()
- args = {"model", "loss", "metric"}
- for opt in options:
- if opt not in args:
- raise ValueError("Unrecognized argument {}. Expect one of {}".format(opt, args))
- self.options = options
- self._summary_writer = None
- self.graph_added = False
-
- def on_train_begin(self):
- save_dir = self.trainer.save_path
- if save_dir is None:
- path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time))
- else:
- path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
- if tensorboardX_flag:
- self._summary_writer = SummaryWriter(path)
- else:
- self._summary_writer = None
-
- def on_batch_begin(self, batch_x, batch_y, indices):
- if "model" in self.options and self.graph_added is False:
- # tesorboardX 这里有大bug,暂时没法画模型图
- # from fastNLP.core.utils import _build_args
- # inputs = _build_args(self.trainer.model, **batch_x)
- # args = tuple([value for value in inputs.values()])
- # args = args[0] if len(args) == 1 else args
- # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
- self.graph_added = True
-
- def on_backward_begin(self, loss):
- if "loss" in self.options and self._summary_writer:
- self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)
-
- if "model" in self.options and self._summary_writer:
- for name, param in self.trainer.model.named_parameters():
- if param.requires_grad:
- self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.trainer.step)
- # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.trainer.step)
- self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
- global_step=self.trainer.step)
-
- def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
- if "metric" in self.options and self._summary_writer:
- for name, metric in eval_result.items():
- for metric_key, metric_val in metric.items():
- self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
- global_step=self.trainer.step)
-
- def on_train_end(self):
- if self._summary_writer:
- self._summary_writer.close()
- del self._summary_writer
-
- def on_exception(self, exception):
- if hasattr(self, "_summary_writer"):
- self._summary_writer.close()
- del self._summary_writer
-
-
-class CheckPointCallback(Callback):
- def __init__(self, save_path, delete_when_train_finish=True, recovery_fitlog=True):
- r"""
- 用于在每个epoch结束的时候保存一下当前的Trainer状态,可以用于恢复之前的运行。使用最近的一个epoch继续训练
- 一段示例代码
- Example1::
-
- >>> callback = CheckPointCallback('chkp.pt')
- >>> trainer = Trainer(xxx, callback=callback)
- >>> trainer.train() # 如果训练过程没结束就fail,请直接再次运行即可(请务必保证与上次使用了完全相同的数据与超参数)
-
- Example2::
-
- >>> fitlog.set_log_dir('xxx')
- >>> callback = CheckPointCallback('chkp.pt') # 一定要在set_log_dir下一行就接着CheckPointCallback
- >>> trainer = Trainer(xxx, callback=callback)
- >>> trainer.train() # 如果训练过程没结束就fail,请直接再次运行即可(请务必保证与上次使用了完全相同的数据与超参数)
-
- :param str save_path: 将状态保存到哪个位置。需要指定一个具体的路径,比如'checkpoints/chtp.pt'。如果检查到该文件存在,将在
- Trainer开始训练的时候自动从这个Checkpoint处开始运行。
- :param bool delete_when_train_finish: 如果Train正常运行完毕,是否自动删除。删除该文件可以使得路径自动复用。
- :param bool recovery_fitlog: 是否恢复fitlog为对应的log,如果为True请将本Callback放在fitlog.set_log_dir后面一行初始化。
- 如果为False,将新建一个log folder否则继续使用之前的。
- """
- super().__init__()
- self.save_path = os.path.abspath(os.path.expanduser(save_path))
- self.delete_when_train_finish = delete_when_train_finish
- self.recover_fitlog = recovery_fitlog
- try:
- import fitlog
- except:
- self.recover_fitlog = False
- if os.path.exists(os.path.expanduser(self.save_path)):
- logger.info("The train will start from the checkpoint saved in {}.".format(self.save_path))
- if self.recover_fitlog:
- states = torch.load(self.save_path)
- if 'fitlog_log_dir' in states:
- try:
- import fitlog
- log_dir = states['fitlog_log_dir']
- if 'fitlog_save_log_dir' in states:
- log_dir = states['fitlog_save_log_dir']
- fitlog.set_log_dir(log_dir, new_log=True)
- except:
- logger.error("Fail to recovery the fitlog states.")
-
- def on_train_begin(self):
- r"""
- 当train开始时,且需要恢复上次训练时,会做以下的操作
- (1) 重新加载model权重
- (2) 重新加载optimizer的状态
- (3) 加载当前epoch数
- (4) 加载当前最佳evaluate的性能
- (5) (optional) 自动将fitlog设置到上次log出继续
-
- :return:
- """
- if os.path.exists(os.path.expanduser(self.save_path)):
- states = torch.load(self.save_path)
- model = self.model
- if _model_contains_inner_module(model):
- model = model.module
- model.load_state_dict(states['model'])
- self.optimizer.load_state_dict(states['optimizer'])
- if 'grad_scaler' in states:
- self.grad_scaler.load_state_dict(states['grad_scaler'])
- self.trainer.epoch = states['epoch'] + 1 # 因为是结束储存的,所以需要从下一个epoch开始
- self.trainer.step = states['step']
- if 'best_dev_epoch' in states:
- self.trainer.best_dev_perf = states['best_dev_perf']
- self.trainer.best_dev_epoch = states['best_dev_epoch']
- self.trainer.best_dev_step = states['best_dev_step']
- self.trainer.best_metric_indicator = states['best_metric_indicator']
- logger.info("Load checkpoint from {}".format(os.path.expanduser(self.save_path)))
-
- def on_epoch_end(self):
- r"""
- 保存状态,使得结果可以被恢复
-
- :param self:
- :return:
- """
- states = {}
- model = self.model
- if _model_contains_inner_module(model):
- model = model.module
- states['model'] = {name:param.cpu() for name, param in model.state_dict().items()}
- states['optimizer'] = self.optimizer.state_dict()
- states['grad_scaler'] = self.grad_scaler.state_dict()
- states['epoch'] = self.epoch
- states['step'] = self.step
- if self.trainer.best_dev_epoch is not None:
- states['best_dev_epoch'] = self.trainer.best_dev_epoch
- states['best_dev_perf'] = self.trainer.best_dev_perf
- states['best_dev_step'] = self.trainer.best_dev_step
- states['best_metric_indicator'] = self.trainer.best_metric_indicator
- if self.recover_fitlog:
- try:
- import fitlog
- if fitlog._logger._log_dir is not None:
- states['fitlog_log_dir'] = fitlog._logger._log_dir
- if fitlog._logger._save_log_dir is not None:
- states['fitlog_save_log_dir'] = fitlog._logger._save_log_dir
- except:
- pass
- torch.save(states, self.save_path)
- logger.debug("Checkpoint:{} has been saved in epoch:{}.".format(self.save_path, self.epoch))
-
- def on_train_end(self):
- # 训练结束,根据情况删除保存的内容
- if self.delete_when_train_finish:
- if os.path.exists(self.save_path):
- os.remove(self.save_path)
- logger.debug("Checkpoint:{} has been removed.".format(self.save_path))
-
-
-class WarmupCallback(Callback):
- r"""
- learning rate按照一定的速率从0上升到设置的learning rate。
- """
- def __init__(self, warmup=0.1, schedule='constant'):
- r"""
-
- :param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float,
- 如0.1, 则前10%的step是按照schedule策略调整learning rate。
- :param str schedule: 以哪种方式调整。
- linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后warmup的step下降到0;
- constant前warmup的step上升到指定learning rate,后面的step保持learning rate.
- """
- super().__init__()
- self.warmup = max(warmup, 0.)
-
- self.initial_lrs = [] # 存放param_group的learning rate
- if schedule == 'constant':
- self.get_lr = self._get_constant_lr
- elif schedule == 'linear':
- self.get_lr = self._get_linear_lr
- else:
- raise RuntimeError("Only support 'linear', 'constant'.")
-
- def _get_constant_lr(self, progress):
- if progress1:
- self.warmup = self.warmup/self.t_steps
- self.t_steps = max(2, self.t_steps) # 不能小于2
- # 获取param_group的初始learning rate
- for group in self.optimizer.param_groups:
- self.initial_lrs.append(group['lr'])
-
- def on_backward_end(self):
- if self.step%self.update_every==0:
- progress = (self.step/self.update_every)/self.t_steps
- for lr, group in zip(self.initial_lrs, self.optimizer.param_groups):
- group['lr'] = lr * self.get_lr(progress)
-
-
-class SaveModelCallback(Callback):
- r"""
- 由于Trainer在训练过程中只会保存最佳的模型, 该callback可实现多种方式的结果存储。
- 会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型::
-
- -save_dir
- -2019-07-03-15-06-36
- -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_performance是性能
- -epoch:1_step:40_{metric_key}:{evaluate_performance}.pt
- -2019-07-03-15-10-00
- -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能
- """
- def __init__(self, save_dir, top=3, only_param=False, save_on_exception=False):
- r"""
-
- :param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型。如果save_dir不存在将自动创建
- :param int top: 保存dev表现top多少模型。-1为保存所有模型。
- :param bool only_param: 是否只保存模型的权重。
- :param save_on_exception: 发生exception时,是否保存一份发生exception的模型。模型名称为epoch:x_step:x_Exception:{exception_name}.
- """
- super().__init__()
-
- os.makedirs(save_dir, exist_ok=True)
- self.save_dir = save_dir
- if top < 0:
- self.top = sys.maxsize
- else:
- self.top = top
- self._ordered_save_models = [] # List[Tuple], Tuple[0]是metric, Tuple[1]是path。metric是依次变好的,所以从头删
-
- self.only_param = only_param
- self.save_on_exception = save_on_exception
-
- def on_train_begin(self):
- self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)
-
- def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
- metric_value = list(eval_result.values())[0][metric_key]
- self._save_this_model(metric_value)
-
- def _insert_into_ordered_save_models(self, pair):
- # pair:(metric_value, model_name)
- # 返回save的模型pair与删除的模型pair. pair中第一个元素是metric的值,第二个元素是模型的名称
- index = -1
- for _pair in self._ordered_save_models:
- if _pair[0]>=pair[0] and self.trainer.increase_better:
- break
- if not self.trainer.increase_better and _pair[0]<=pair[0]:
- break
- index += 1
- save_pair = None
- if len(self._ordered_save_models)=self.top and index!=-1):
- save_pair = pair
- self._ordered_save_models.insert(index+1, pair)
- delete_pair = None
- if len(self._ordered_save_models)>self.top:
- delete_pair = self._ordered_save_models.pop(0)
- return save_pair, delete_pair
-
- def _save_this_model(self, metric_value):
- name = "epoch-{}_step-{}_{}-{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
- save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
- if save_pair:
- try:
- _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
- except Exception as e:
- logger.error(f"The following exception:{e} happens when save model to {self.save_dir}.")
- if delete_pair:
- try:
- delete_model_path = os.path.join(self.save_dir, delete_pair[1])
- if os.path.exists(delete_model_path):
- os.remove(delete_model_path)
- except Exception as e:
- logger.error(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")
-
- def on_exception(self, exception):
- if self.save_on_exception:
- name = "epoch-{}_step-{}_Exception-{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
- _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
-
-
-class CallbackException(BaseException):
- r"""
- 当需要通过callback跳出训练的时候可以通过抛出CallbackException并在on_exception中捕获这个值。
- """
-
- def __init__(self, msg):
- r"""
-
- :param str msg: Exception的信息。
- """
- super(CallbackException, self).__init__(msg)
-
-
-class EarlyStopError(CallbackException):
- r"""
- 用于EarlyStop时从Trainer训练循环中跳出。
-
- """
-
- def __init__(self, msg):
- super(EarlyStopError, self).__init__(msg)
-
-
-class EchoCallback(Callback):
- r"""
- 用于测试分布式训练
-
- """
- def __init__(self, name, out=sys.stdout):
- super(EchoCallback, self).__init__()
- self.name = name
- self.out = out # deprecated
-
- def __getattribute__(self, item):
- if item.startswith('on_'):
- logger.info('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()))
- return super(EchoCallback, self).__getattribute__(item)
-
-
-class _TesterCallback(Callback):
- def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None, sampler=None,
- use_tqdm=True):
- super(_TesterCallback, self).__init__()
- self.tester = Tester(data, model,
- metrics=metrics, batch_size=batch_size,
- num_workers=num_workers, verbose=0, sampler=sampler, use_tqdm=use_tqdm)
- if metric_key is not None:
- self.metric_key, self.increase_better = self._parse_metric_key(metric_key)
- else:
- self.metric_key = None
- self.increase_better = True
- self.score = None
-
- def on_valid_begin(self):
- cur_score = self.tester.test()
- eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
- self.epoch, self.n_epochs, self.step, self.n_steps,
- self.tester._format_eval_results(cur_score))
- self.logger.info(eval_str)
- is_better = self.compare_better(cur_score)
- if is_better:
- self.score = cur_score
- return cur_score, is_better
-
- @staticmethod
- def _get_score(metric_dict, key):
- for metric in metric_dict.values():
- if key in metric:
- return metric[key]
- return None
-
- @staticmethod
- def _parse_metric_key(metric_key):
- # parse metric_key
- # increase_better is True. It means the exp result gets better if the indicator increases.
- # It is true by default.
- increase_better = False if metric_key[0] == "-" else True
- metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
- return metric_key, increase_better
-
- def compare_better(self, a):
- if self.score is None:
- return True
- if self.metric_key is None:
- metric_key = list(list(self.score.values())[0].keys())[0]
- self.metric_key, self.increase_better = self._parse_metric_key(metric_key)
- k = self.metric_key
- score = self._get_score(self.score, k)
- new_score = self._get_score(a, k)
- if score is None or new_score is None:
- return False
- if self.increase_better:
- return score <= new_score
- else:
- return score >= new_score
diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py
new file mode 100644
index 00000000..ac34d5ee
--- /dev/null
+++ b/fastNLP/core/callbacks/__init__.py
@@ -0,0 +1,46 @@
+__all__ = [
+ 'Callback',
+ 'Event',
+ 'Filter',
+ 'CheckpointCallback',
+ 'choose_progress_callback',
+
+ 'ProgressCallback',
+ 'RichCallback',
+ 'TqdmCallback',
+ 'RawTextCallback',
+
+ "LRSchedCallback",
+ 'LoadBestModelCallback',
+ "EarlyStopCallback",
+
+ 'MoreEvaluateCallback',
+
+ "TorchWarmupCallback",
+ "TorchGradClipCallback",
+
+ "ResultsMonitor",
+ 'HasMonitorCallback',
+
+ "FitlogCallback",
+
+ "TimerCallback",
+
+ "TopkSaver"
+]
+
+
+from .callback import Callback
+from .callback_event import Event, Filter
+from .callback_manager import CallbackManager
+from .checkpoint_callback import CheckpointCallback
+from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback, RawTextCallback
+from .lr_scheduler_callback import LRSchedCallback
+from .load_best_model_callback import LoadBestModelCallback
+from .early_stop_callback import EarlyStopCallback
+from .torch_callbacks import *
+from .more_evaluate_callback import MoreEvaluateCallback
+from .has_monitor_callback import ResultsMonitor, HasMonitorCallback
+from .fitlog_callback import FitlogCallback
+from .timer_callback import TimerCallback
+from .topk_saver import TopkSaver
diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py
new file mode 100644
index 00000000..a4275f3e
--- /dev/null
+++ b/fastNLP/core/callbacks/callback.py
@@ -0,0 +1,297 @@
+
+__all__ = [
+ 'Callback',
+]
+
+from typing import Callable, Dict, Optional
+
+from .callback_event import Event, Filter
+
+
+class Callback:
+ r"""
+ 实际使用的 callback 类,不管是 **fastNLP** 默认提供的一些 callback 实例,还是用户自己定制的 callback 类,都应该继承该基类;
+ callback 调用时机顺序大概如下::
+
+ Trainer.__init__():
+ on_after_trainer_initialized(trainer, driver)
+ Trainer.run():
+ if num_eval_sanity_batch>0:
+ on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch
+ on_sanity_check_end(trainer, sanity_check_res)
+ try:
+ on_train_begin(trainer)
+ while cur_epoch_idx < n_epochs:
+ on_train_epoch_begin(trainer)
+ while batch_idx_in_epoch<=num_batches_per_epoch:
+ on_fetch_data_begin(trainer)
+ batch = next(dataloader)
+ on_fetch_data_end(trainer)
+ on_train_batch_begin(trainer, batch, indices)
+ on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。
+ on_after_backward(trainer)
+ on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
+ on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
+ on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
+ on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
+ on_train_batch_end(trainer)
+ on_train_epoch_end(trainer)
+ except BaseException:
+ self.on_exception(trainer, exception)
+ finally:
+ on_train_end(trainer)
+
+ 其它 callback 例如 **on_evaluate_begin(trainer)** / **on_evaluate_end(trainer, results)** / **on_save_model(trainer)** /
+ **on_load_model(trainer)** / **on_save_checkpoint(trainer)** / **on_load_checkpoint(trainer)** 将根据需要在 :meth:`Trainer.run <fastNLP.core.controllers.Trainer.run>`
+ 中特定的时间调用。
+ """
+
+ def on_after_trainer_initialized(self, trainer, driver):
+ r"""
+ 在 ``Trainer`` 初始化后会被触发;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param driver: :class:`~fastNLP.core.controllers.Trainer` 中的 ``driver`` 实例;
+ """
+ pass
+
+ def on_sanity_check_begin(self, trainer):
+ r"""
+ 在 '预跑'检测 开始前会被触发;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_sanity_check_end(self, trainer, sanity_check_res):
+ r"""
+ 在 '预跑'检测 开始后会被触发;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param sanity_check_res: 预跑得到的评测结果,关于对于 **预跑** 的解释,请见 :meth:`~fastNLP.core.controllers.trainer.Trainer.run`;
+ """
+ pass
+
+ def on_train_begin(self, trainer):
+ r"""
+ 在训练开始前会被触发;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_train_end(self, trainer):
+ r"""
+ 在训练完成后会被触发;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_train_epoch_begin(self, trainer):
+ r"""
+ 在训练过程中的每一个 epoch 开始前会被触发;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_train_epoch_end(self, trainer):
+ r"""
+ 在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_fetch_data_begin(self, trainer):
+ r"""
+ 在训练过程中准备取出下一个 batch 的数据时触发
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_fetch_data_end(self, trainer):
+ r"""
+ 在训练过程中拿到当前的 batch 数据后会被触发;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_train_batch_begin(self, trainer, batch, indices):
+ r"""
+ 在取得数据,执行完 ``input_mapping`` (如果 :class:`~fastNLP.core.controllers.Trainer` 传有该参数),并且移动 ``batch`` 中的张量到了指定设备之后会被触发。
+ 其中 ``batch`` 中的数据格式要么是 ``Dataloader`` 返回的每个 ``batch`` 的格式;要么是 ``input_mapping`` 之后的内容。
+ 如果 ``batch`` 是 ``dict`` 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入模型的中的 ``batch`` 数据。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param batch: batch 的数据,已经经过 ``input_mapping`` (如果有) 以及移动到指定设备 。
+ :param list[int] indices: 当前的 ``batch`` 是数据集中的哪些数据。仅在 ``DataLoader`` 支持得到当前 ``batch index`` 的时候有值,
+ 其它时候为 ``None`` 。
+ """
+ pass
+
+ def on_train_batch_end(self, trainer):
+ r"""
+ 完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch 与
+ global_forward_batches 累计加1操作之后会被触发。其中梯度更新、梯度置零操作会考虑 **accumulation_steps** ,所以不一定在当前 batch 会
+ 执行。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_exception(self, trainer, exception):
+ r"""
+ 在训练过程遇到异常时调用。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param exception: 遭遇的异常;
+ """
+ pass
+
+ def on_save_model(self, trainer):
+ r"""
+ 当调用 :meth:`Trainer.save_model() <fastNLP.core.controllers.Trainer.save_model>` 时调用,此刻模型还未保存。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_load_model(self, trainer):
+ r"""
+ 当调用 :meth:`Trainer.load_model() <fastNLP.core.controllers.Trainer.load_model>` 加载模型时调用,此刻模型还未加载。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_save_checkpoint(self, trainer) -> Dict:
+ r"""
+ 当 Trainer 将要保存 checkpoint 的时候触发 (即调用 :meth:`Trainer.save_checkpoint() <fastNLP.core.controllers.Trainer.save_checkpoint>`
+ 函数时),该函数用于保存当前 callback 在恢复时需要的相关数据。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_load_checkpoint(self, trainer, states: Optional[Dict]):
+ r"""
+ 当 Trainer 要恢复 checkpoint 的时候触发(即调用 :meth:`Trainer.load_checkpoint() <fastNLP.core.controllers.Trainer.load_checkpoint>`
+ 函数时, 此刻 Trainer 与 Driver 已经加载好自身的状态), 参数 states 为 Callback 在调用 :meth:`on_save_checkpoint` 的返回值。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param states:
+ """
+ pass
+
+ def on_before_backward(self, trainer, outputs):
+ r"""
+ 在 backward 前执行。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param outputs: ``model`` 的返回内容。如果有 ``output_mapping``,则 ``outputs`` 中的内容为已经执行了 ``output_mapping`` 后的结果。
+ """
+ pass
+
+ def on_after_backward(self, trainer):
+ r"""
+ 在 ``backward`` 后执行。在多卡场景下,由于 ``accumulation_steps`` 的影响,仅在需要真正 ``update`` 参数那次梯度回传才会触发梯度同步,
+ 因此在多卡且使用 ``accumulation_steps`` 时,可能存在某些 step 各卡上梯度不一致的问题。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_before_optimizers_step(self, trainer, optimizers):
+ r"""
+ 在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。
+ """
+ pass
+
+ def on_after_optimizers_step(self, trainer, optimizers):
+ r"""
+ 在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。
+ """
+ pass
+
+ def on_before_zero_grad(self, trainer, optimizers):
+ r"""
+ 在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。
+ """
+ pass
+
+ def on_after_zero_grad(self, trainer, optimizers):
+ r"""
+ 在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 ``accumulation_steps`` 的影响。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param optimizers: 优化器,内容为在 :class:`~fastNLP.core.controllers.Trainer` 初始化时传入的值。
+ """
+ pass
+
+ def on_evaluate_begin(self, trainer):
+ r"""
+ 在将要进行 ``evaluate`` 时调用。如果是设置的以 step 数量或自定义地决定 evaluate 的频率,该接口是在 :meth:`on_train_batch_end` 之后
+ 进行调用。如果是以 epoch 数量决定调用时机,该接口是在 :meth:`on_train_epoch_end` 之后调用。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ """
+ pass
+
+ def on_evaluate_end(self, trainer, results):
+ r"""
+ 结束 evaluate 时调用,并把 evaluate 的结果传入。
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param results: :class:`~fastNLP.core.controllers.Trainer` 内置的 ``Evaluator`` 评测的结果,通常是个 ``dict``;
+ """
+ pass
+
+ @property
+ def callback_name(self):
+ r"""
+ ``callback`` 的名称,我们会使用该名称从 ``checkpoint`` 中读取的相应的 ``state`` 并传递给 :meth:`on_load_checkpoint` 函数。
+
+ :return: 用于区分该 ``callback`` 实例的名称;
+ """
+ return self.__class__.__name__
+
+ @property
+ def need_reproducible_sampler(self) -> bool:
+ r"""
+ 当前 callback 是否需要能够复现的 sampler 。一般用于 checkpoint 类的 callback 。
+ """
+ return False
+
+
+class _CallbackWrapper(Callback):
+ """
+ 对于用户使用函数修饰器加入的 callback 函数,使用该 _CallbackWrapper 类为其进行定制,这一个类只保留用户的
+ 这一个 callback 函数;
+ """
+ def __init__(self, event: Event, fn: Callable):
+ r"""
+ :param event: 具体的 callback 时机,例如 'on_train_begin' 等;
+ :param fn: 用户定制的 callback 函数;
+ """
+
+ self.fn = fn
+ if isinstance(event, Event):
+ _filter = Filter(event.every, event.once, event.filter_fn)
+ setattr(self, event.value, _filter(fn))
+
+ @property
+ def callback_name(self):
+ return self.fn.__name__
diff --git a/fastNLP/core/callbacks/callback_event.py b/fastNLP/core/callbacks/callback_event.py
new file mode 100644
index 00000000..f632cf3c
--- /dev/null
+++ b/fastNLP/core/callbacks/callback_event.py
@@ -0,0 +1,500 @@
+from typing import Optional, Callable, Dict
+from functools import wraps
+
+
+__all__ = [
+ 'Event',
+ 'Filter'
+]
+
+
+def check_legality(fn):
+ @wraps(fn)
+ def wrap(every=None, once=None, filter_fn=None):
+ if (every is None) and (once is None) and (filter_fn is None):
+ every = 1
+
+ if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)):
+ raise ValueError("These three values should be only set one.")
+
+ if (filter_fn is not None) and not callable(filter_fn):
+ raise TypeError("Argument filter_fn should be a callable")
+
+ if (every is not None) and not (isinstance(every, int) and every > 0):
+ raise ValueError("Argument every should be integer and greater than zero")
+
+ if (once is not None) and not (isinstance(once, int) and once > 0):
+ raise ValueError("Argument once should be integer and positive")
+ return fn(every=every, once=once, filter_fn=filter_fn)
+ return wrap
+
+
+class Event:
+ """
+ 与 :meth:`Trainer.on` 函数配合使用,达到控制 callback 函数运行时机的目的。
+
+ :param value: Trainer 的 callback 时机;
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 是否仅运行一次;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ """
+ every: Optional[int]
+ once: Optional[bool]
+
+ def __init__(self, value: str, every: Optional[int] = None, once: Optional[bool] = None,
+ filter_fn: Optional[Callable] = None):
+ self.every = every
+ self.once = once
+ self.filter_fn = filter_fn
+ self.value = value
+
+ def __str__(self):
+ return "<event={0}, every={1}, once={2}, filter fn is:{3}>".format(self.value, self.every, self.once,
+ self.filter_fn)
+ @staticmethod
+ def on_after_trainer_initialized(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_after_trainer_initialized` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_sanity_check_begin(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_sanity_check_begin` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_sanity_check_end(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_sanity_check_end` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_train_begin(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_train_begin` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_train_end(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_train_end` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_train_epoch_begin(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_train_epoch_begin` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_train_epoch_end(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_train_epoch_end` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_fetch_data_begin(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_fetch_data_begin` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_fetch_data_end(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_fetch_data_end` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_train_batch_begin(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_train_batch_begin` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_train_batch_end(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_train_batch_end` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_exception(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_exception` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_save_model(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_save_model` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_load_model(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_load_model` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_save_checkpoint(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_save_checkpoint` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_load_checkpoint(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_load_checkpoint` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_before_backward(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_before_backward` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_after_backward(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_after_backward` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_before_optimizers_step(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_before_optimizers_step` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_after_optimizers_step(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_after_optimizers_step` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_before_zero_grad(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_before_zero_grad` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_after_zero_grad(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_after_zero_grad` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_evaluate_begin(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_evaluate_begin` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn)
+
+ @staticmethod
+ def on_evaluate_end(every=None, once=None, filter_fn=None):
+ """
+ 当 Trainer 运行到 :func:`on_evaluate_end` 时触发;
+
+ 以下三个参数互斥,只能设置其中一个。默认为行为等同于 ``every=1`` 。
+
+ :param every: 每触发多少次才真正运行一次;
+ :param once: 在第一次运行后时候再次执行;
+ :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和
+ `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer;
+ :return:
+ """
+ return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn)
+
+
+class Filter:
+ r"""
+ 可以控制一个函数实际的运行频率的函数修饰器。
+
+ :param every: 表示一个函数隔多少次运行一次;
+ :param once: 表示一个函数是否只运行一次;
+ :param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和
+ `self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态;
+ """
+ def __init__(self, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None):
+ # check legality
+ check_legality(lambda *args,**kwargs:...)(every, once, filter_fn)
+ if (every is None) and (once is None) and (filter_fn is None):
+ every = 1
+ # 设置变量,包括全局变量;
+ self.num_called = 0
+ self.num_executed = 0
+
+ if every is not None:
+ self._every = every
+ self._filter = self.every_filter
+ elif once is not None:
+ self._once = once
+ self._filter = self.once_filter
+ else:
+ self._filter = filter_fn
+
+ def __call__(self, fn: Callable):
+
+ @wraps(fn)
+ def wrapper(*args, **kwargs) -> Callable:
+ self.num_called += 1
+
+ # 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer;
+ trainer = args[0]
+ if self._filter(self, trainer):
+ self.num_executed += 1
+ return fn(*args, **kwargs)
+
+ wrapper.__fastNLP_filter__ = self
+ return wrapper
+
+ def every_filter(self, *args):
+ return self.num_called % self._every == 0
+
+ def once_filter(self, *args):
+ return self.num_called == self._once
+
+ def state_dict(self) -> Dict:
+ r"""
+ 通过该函数来保存该 `Filter` 的状态;
+ """
+ return {"num_called": self.num_called, "num_executed": self.num_executed}
+
+ def load_state_dict(self, state: Dict):
+ r"""
+ 通过该函数来加载 `Filter` 的状态;
+
+ :param state: 通过 `Filter.state_dict` 函数保存的状态元组;
+ """
+ self.num_called = state["num_called"]
+ self.num_executed = state["num_executed"]
+
+
+
+
+
+
+
diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py
new file mode 100644
index 00000000..bf1de884
--- /dev/null
+++ b/fastNLP/core/callbacks/callback_manager.py
@@ -0,0 +1,321 @@
+import inspect
+from typing import List, Optional, Dict, Sequence
+from collections import defaultdict
+
+from .callback_event import Event
+from .callback import Callback
+from fastNLP.core.log import logger
+from .progress_callback import ProgressCallback, choose_progress_callback
+from ..utils.exceptions import EarlyStopException
+from ..utils.utils import _get_fun_msg
+
+
+def _transfer(func):
+ r"""
+ 装饰器,将对CallbackManager的调用转发到各个Callback子类.
+ 需要注意这里的 wrapper 内的函数不会运行 `func` 本身,因此如果有什么需要直接在 callback 函数内运行的代码,请放在 TrainerCallback 内;
+ """
+
+ def wrapper(manager, *arg, **kwargs):
+ manager.callback_counter[func.__name__] += 1 # 给实际被调用的 callback_fn 的计数加 1;
+ for callback_fn in manager.callback_fns[func.__name__]:
+ try:
+ callback_fn(*arg, **kwargs)
+ except (EarlyStopException, KeyboardInterrupt) as e:
+ raise e
+ except BaseException as e:
+ logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.")
+ raise e
+ return wrapper
+
+
+def prepare_callbacks(callbacks, progress_bar: str):
+ """
+ :param callbacks: 对用户传入的类 ``callback`` 进行检查,查看是否是否继承了我们的 ``Callback`` 类;
+ :param progress_bar: 选择怎样的 ``progress_bar`` 给 ``Trainer`` 使用;
+ :return:
+ """
+ _callbacks = []
+ if callbacks is not None:
+ if isinstance(callbacks, Callback):
+ callbacks = [callbacks]
+ if not isinstance(callbacks, Sequence):
+ raise ValueError("Parameter `callbacks` should be type 'List' or 'Tuple'.")
+ callbacks = list(callbacks)
+ for _callback in callbacks:
+ if not isinstance(_callback, Callback):
+ raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
+ _callbacks += callbacks
+
+ has_no_progress = True
+ for _callback in _callbacks:
+ if isinstance(_callback, ProgressCallback):
+ has_no_progress = False
+ if has_no_progress and progress_bar is not None:
+ callback = choose_progress_callback(progress_bar)
+ if callback is not None:
+ _callbacks = [callback] + _callbacks # 放在最前面,方便分割不同 epoch
+ has_no_progress = False
+ elif has_no_progress is False and progress_bar not in ('auto', None):
+ logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.")
+
+ if has_no_progress:
+ logger.rank_zero_warning("No progress bar is provided, there will be no progress output during training.")
+
+ return _callbacks
+
+
+class CallbackManager:
+ r"""
+ 用来管理训练过程中的所有的 callback 实例;
+ """
+ all_callbacks: List[Callback]
+ class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback;
+ callback_fns: dict
+
+ def __init__(self, callbacks: Optional[List[Callback]]):
+ r"""
+ 注意 callback 的调用顺序为:
+
+ 1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数;
+ 2. 通过 `Trainer` 的参数 `callbacks` 添加的 callback 类;
+ 3. 通过 `Trainer.add_callback_fn` 添加的 callback 函数;
+
+ :param callbacks: 初始化时可以传入的一系列 :class:`~fastNLP.Callback` 类,通常为用户在初始化 ``Trainer`` 时直接传入的 callback 列表;
+ """
+ self._need_reproducible_sampler = False
+
+ self.callback_fns = defaultdict(list)
+ # 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法
+ # 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可;
+ self.callback_counter = defaultdict(lambda: 0)
+ if len(callbacks):
+ # 这一对象是为了保存原始的类 callback 对象来帮助用户进行 debug,理论上在正常的使用中你并不会需要它;
+ self.class_callbacks = callbacks
+ else:
+ self.class_callbacks: Optional[List[Callback]] = []
+
+ # 预跑需要拿到每一个被 `Filter` 修饰的函数的 `Filter` 实例,从而在预跑结束后重置它们的内部状态;
+ self._callback_filters = [] # [(callback_name, fn_name, filter 实例), ]
+
+ # 保留所有 callback 的引用,用于断点重训;包括全部的三种callback:函数修饰器 callback;类 callback;纯函数 callback;
+ # 因为所有的 callback 都是通过函数 `self.add_one_callback` 添加,因此我们选择在其下进行添加;
+ # 一个比较重要的概念在于在训练过程运行的时候,两个 callback 的 callback_name 可以是一样的,并且理论上不会造成任何影响;但是当
+ # `on_load_checkpoint` 时,我们需要处理两个 callback_name 一样这种情况了;
+ # 因此这里的 `all_callbacks` 为了避免正常训练过程的运行,只能是一个 List,而不能是一个 dict,`_callback_filters` 也是一样;
+ self.all_callbacks = []
+
+ def initialize_class_callbacks(self):
+ r"""
+ 在实际的运行过程中,我们会将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是
+ 一个个 callback 时机,也就是 `Event` 的类别;
+ 如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中;
+ """
+ for each_callback in self.class_callbacks:
+ self._need_reproducible_sampler |= each_callback.need_reproducible_sampler
+ self.dissect_one_callback(each_callback)
+
+ def dissect_one_callback(self, callback: Callback):
+ r"""
+ 将具体的一个 callback 实例的所有 callback 函数拆分后按时机插入到字典中;
+
+ :param callback: 一个具体的 callback 实例;
+ """
+ self.all_callbacks.append(callback)
+ for name, member in Event.__dict__.items():
+ if isinstance(member, staticmethod):
+ _fn = getattr(callback, name)
+ if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, name)):
+ self.callback_fns[name].append(_fn)
+ self.extract_callback_filter_state(callback.callback_name, _fn)
+
+ def extract_callback_filter_state(self, callback_name, callback_fn):
+ r"""
+ 将一个具体的 callback 函数的 filter 的状态抽取出来;
+ """
+ if hasattr(callback_fn, "__fastNLP_filter__"):
+ # 注意我们的 `Filter` 使用了 `@wraps` 来保证被修饰的函数的 `__name__` 属性仍旧是其真实的名字;
+ self._callback_filters.append((callback_name, callback_fn.__name__, callback_fn.__fastNLP_filter__))
+
+ def on_save_checkpoint(self, trainer) -> Dict:
+ r"""
+ 用于断点重训的 callback 的保存函数;
+ 该函数主要涉及两个方面:
+
+ 1. callback 的状态的保存;我们会调用每一个 callback 的 :func:`on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着
+ 断点重训应当保存的状态;
+ 2. 每一个具体的 callback 函数的 filter 的状态;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :return: 一个包含上述内容的字典,格式如下:
+ .. code-block::
+
+ {
+ "callback_name_1": {
+ "states": {...},
+ "filter_states": {"on_train_begin": filter1.state_dict(), ...}
+ }
+ }
+ """
+
+ states = {}
+ # 1. 每一个 callback 的状态;
+ # 如果有两个 callback 的 name 相同,那么我们只会保存第一个;
+ _duplicated_callbacks = []
+ for each_callback in self.all_callbacks:
+ if each_callback.callback_name in states:
+ _duplicated_callbacks.append(each_callback.callback_name)
+ # 对于 callback_name 有重复的 callback,我们仍旧会调用其 `on_save_checkpoint` 函数,就如同调用其它 callback 函数
+ # 一样,但是其结果并不会存储在 states 中返回;
+ each_callback.on_save_checkpoint(trainer)
+ else:
+ states[each_callback.callback_name] = {}
+ states[each_callback.callback_name]["states"] = each_callback.on_save_checkpoint(trainer)
+
+ if len(_duplicated_callbacks) > 0:
+ logger.warning(f"Notice these callback_name: {_duplicated_callbacks} are duplicated, "
+ f"fastNLP will only save the first callback's state.")
+
+ # 2. 每一个具体的 callback 函数的 filter 的状态;
+ _record_duplicated_callback_names = set()
+ for each_callback_filters in self._callback_filters:
+ if each_callback_filters[0] not in _record_duplicated_callback_names:
+ _record_duplicated_callback_names.add(each_callback_filters[0])
+ if 'filter_states' not in states[each_callback_filters[0]]:
+ states[each_callback_filters[0]]["filter_states"] = {}
+ states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict()
+
+ # 3. 保存 callback_counter;
+ # callback_counter 不应当被保存,因为其在断点重训时会由新的 callback_manager 重新初始化;
+ # 对于断点重训,我们不会保存 Trainer 的所有参数,例如 batch_step_fn;如果在断点重训时重新初始化 Trainer 发现 batch_step_fn
+ # 不为 None,那么 Trainer 就会调用实际的 check_batch_step_fn 函数,从而需要 callback_counter 为全新的状态;
+
+ return states
+
+ def on_load_checkpoint(self, trainer, states: Dict):
+ r"""
+ 用于断点重训的加载函数,对应于断点重训的保存函数;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param states: 同 :func:`on_save_checkpoint` 函数的返回值;
+ """
+
+ # 1. 先恢复每一个具体的 callback 函数的 filter 的状态;
+ # self._callback_filters 是当前的 Trainer 的 callback 的 filter 状态,是我们要去维护的对象;
+ _already_loaded_callback_names = set()
+ _duplicated_callback_names = set()
+ for each_callback_filters in self._callback_filters:
+ if each_callback_filters[0] in states:
+ if each_callback_filters[0] not in _already_loaded_callback_names:
+ _already_loaded_callback_names.add(each_callback_filters[0])
+ if 'filter_states' in states[each_callback_filters[0]] and \
+ each_callback_filters[1] in states[each_callback_filters[0]]['filter_states']:
+ each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]['filter_states'][each_callback_filters[1]])
+ else:
+ _duplicated_callback_names.add(each_callback_filters[0])
+
+ if len(_duplicated_callback_names) > 0:
+ logger.rank_zero_warning(f"Notice these callback_name: {_duplicated_callback_names} are duplicated, "
+ f"fastNLP will only load the first callback's state.")
+
+ # 2. 再恢复每一个 callback 的单独的状态;
+ # 每一个我们自己提供的类 callback,都需要重写其特定的 `callback_name` 方法,保证如果两个 callback 的 callback_name 一样,
+ # 那么它们就应该是同一个对象;
+ _already_loaded_callback_names = set()
+ for each_callback in self.all_callbacks:
+ if each_callback.callback_name in states and each_callback.callback_name not in _already_loaded_callback_names:
+ _already_loaded_callback_names.add(each_callback.callback_name)
+ # 这里要注意,我们已经确保每一个 callback 的 `on_load_checkpoint` 函数拿到的就是其自己的状态;
+ each_callback.on_load_checkpoint(trainer, states[each_callback.callback_name]["states"])
+
+ @property
+ def has_trainer_checkpoint(self) -> bool:
+ return self._need_reproducible_sampler
+
+ @_transfer
+ def on_after_trainer_initialized(self, trainer):
+ pass
+
+ @_transfer
+ def on_sanity_check_begin(self, trainer):
+ pass
+
+ @_transfer
+ def on_sanity_check_end(self, trainer):
+ pass
+
+ @_transfer
+ def on_train_begin(self, trainer):
+ pass
+
+ @_transfer
+ def on_train_end(self, trainer):
+ pass
+
+ @_transfer
+ def on_train_epoch_begin(self, trainer):
+ pass
+
+ @_transfer
+ def on_train_epoch_end(self, trainer):
+ pass
+
+ @_transfer
+ def on_fetch_data_begin(self, trainer):
+ pass
+
+ @_transfer
+ def on_fetch_data_end(self, trainer):
+ pass
+
+ @_transfer
+ def on_train_batch_begin(self, trainer, batch, indices=None):
+ pass
+
+ @_transfer
+ def on_train_batch_end(self, trainer):
+ pass
+
+ @_transfer
+ def on_exception(self, trainer, exception):
+ pass
+
+ @_transfer
+ def on_save_model(self, trainer):
+ pass
+
+ @_transfer
+ def on_load_model(self, trainer):
+ pass
+
+ @_transfer
+ def on_before_backward(self, trainer, outputs):
+ pass
+
+ @_transfer
+ def on_after_backward(self, trainer):
+ pass
+
+ @_transfer
+ def on_before_optimizers_step(self, trainer, optimizers):
+ pass
+
+ @_transfer
+ def on_after_optimizers_step(self, trainer, optimizers):
+ pass
+
+ @_transfer
+ def on_before_zero_grad(self, trainer, optimizers):
+ pass
+
+ @_transfer
+ def on_after_zero_grad(self, trainer, optimizers):
+ pass
+
+ @_transfer
+ def on_evaluate_begin(self, trainer):
+ pass
+
+ @_transfer
+ def on_evaluate_end(self, trainer, results):
+ pass
diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py
new file mode 100644
index 00000000..44bd9c03
--- /dev/null
+++ b/fastNLP/core/callbacks/checkpoint_callback.py
@@ -0,0 +1,145 @@
+__all__ = [
+ 'CheckpointCallback'
+]
+
+from typing import Union, Optional, Callable, Dict, Sequence
+from pathlib import Path
+import sys
+
+from fastNLP.core.log import logger
+from .topk_saver import TopkSaver
+from .callback import Callback
+from ..utils.exceptions import EarlyStopException
+
+
+class CheckpointCallback(Callback):
+ """
+ 保存 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下::
+
+ - folder/
+ - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
+ - {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型
+ - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型
+ - {save_object}-last/ # 最后一个 epoch 的保存
+ - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
+ - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名
+
+ ``model_save_fn`` 为 ``None`` ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。若 ``model_save_fn`` 不为 ``None``,
+ 则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model
+ 的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。
+
+ :param monitor: 监控的 metric 值。
+
+ * 为 ``None``
+ 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。
+ * 为 ``str``
+ 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将
+ 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。
+ * 为 :class:`Callable`
+ 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关
+ 的 ``monitor`` 值请返回 ``None`` 。
+
+ :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
+ 时间戳文件夹中。如果为 None ,默认使用当前文件夹。
+ :param every_n_epochs: 多少个 epoch 保存一次。
+ :param every_n_batches: 多少个 batch 保存一次。
+ :param last: 如果为 ``True`` ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。如果为 ``False`` 则不会保存 ``{save_object}-last`` 文件
+ :param topk: 保存 monitor 结果中的 ``topk`` 个。
+ :param on_exceptions: 在出异常信息时,是否保存。传入需要捕获的异常的类。默认将捕获 :class:`~fastNLP.core.callbacks.EarlyStopException` 。
+ :param larger_better: monitor 的值是否是越大越好。
+ :param only_state_dict: 保存模型时是否只保存 state_dict 。当 ``model_save_fn`` 不为 ``None`` 时,该参数无效。
+ :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
+ 如果传入了 ``model_save_fn`` 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
+ :param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果
+ 保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.core.controllers.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断
+ 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。
+ :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 ``True`` ,在保存 topk 模型的 folder 中还将额外保存一个
+ ``fastnlp_evaluate_results.json`` 文件,记录当前的 results。仅在设置了 ``topk`` 的场景下有用,默认为 ``True`` 。
+ :param kwargs:
+ """
+ def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None,
+ every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0,
+ on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = (EarlyStopException,),
+ monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True,
+ only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model',
+ save_evaluate_results=True, **kwargs):
+ super().__init__()
+ if every_n_epochs is not None:
+ if not isinstance(every_n_epochs, int) or every_n_epochs < 1:
+ raise ValueError("Parameter `every_n_epochs` should be an int and greater than or equal to 1.")
+ else:
+ every_n_epochs = sys.maxsize # 使得没有数字可以整除
+
+ if every_n_batches is not None:
+ if not isinstance(every_n_batches, int) or every_n_batches < 1:
+ raise ValueError("Parameter `every_n_batches` should be an int and greater than or equal to 1.")
+ else:
+ every_n_batches = sys.maxsize # 使得没有数字可以整除
+
+ if on_exceptions is not None:
+ if not isinstance(on_exceptions, Sequence):
+ on_exceptions = [on_exceptions]
+
+ for exception in on_exceptions:
+ if not issubclass(exception, BaseException):
+ raise TypeError("Each exception in parameter `on_exceptions` can only be "
+ "`BaseException` type.")
+ else:
+ on_exceptions = []
+
+ self.topk_saver = TopkSaver(topk=topk, monitor=monitor, larger_better=larger_better, folder=folder,
+ save_object=save_object, only_state_dict=only_state_dict, model_save_fn=model_save_fn,
+ save_evaluate_results=save_evaluate_results, **kwargs)
+ self.topk_saver.log_name = self.__class__.__name__
+
+ self.topk = topk
+ self.save_object = save_object
+
+ self.every_n_epochs = every_n_epochs
+ self.every_n_batches = every_n_batches
+ self.last = last
+ self.exceptions = on_exceptions
+
+ @property
+ def need_reproducible_sampler(self) -> bool:
+ return self.save_object == 'trainer'
+
+ def on_after_trainer_initialized(self, trainer, driver):
+ if self.topk_saver.topk_queue: # 需要设置 monitor
+ if self.topk_saver.monitor is None:
+ self.topk_saver.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
+ if self.topk_saver.topk_queue and trainer.evaluator is None:
+ logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.")
+
+ def on_evaluate_end(self, trainer, results):
+ # 如果发生了保存,则返回的 folder 不为 None
+ folder = self.topk_saver.save_topk(trainer, results)
+
+ def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
+ if trainer.cur_epoch_idx % self.every_n_epochs == 0:
+ folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}'
+ self.topk_saver.save(trainer, folder_name=folder_name)
+ if self.last:
+ folder_name = f'{self.save_object}-last'
+ self.topk_saver.save(trainer, folder_name=folder_name)
+
+ def on_train_batch_end(self, trainer):
+ if trainer.global_forward_batches % self.every_n_batches == 0:
+ folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}'
+ self.topk_saver.save(trainer, folder_name=folder_name)
+
+ def on_exception(self, trainer, exception: BaseException):
+ if exception.__class__ in self.exceptions:
+ folder_name = f'{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \
+ f'exception_{exception.__class__.__name__}'
+ self.topk_saver.save(trainer, folder_name=folder_name)
+
+ def on_save_checkpoint(self, trainer) -> Dict:
+ states = {}
+ states['topk_saver'] = self.topk_saver.state_dict()
+ return states
+
+ def on_load_checkpoint(self, trainer, states: Optional[Dict]):
+ topk_saver_states = states['topk_saver']
+ self.topk_saver.load_state_dict(topk_saver_states)
+
diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py
new file mode 100644
index 00000000..8e542d56
--- /dev/null
+++ b/fastNLP/core/callbacks/early_stop_callback.py
@@ -0,0 +1,74 @@
+__all__ = [
+ 'EarlyStopCallback'
+]
+
+from typing import Dict, Union, Callable
+
+from .has_monitor_callback import HasMonitorCallback
+from fastNLP.core.utils.exceptions import EarlyStopException
+
+
+class EarlyStopCallback(HasMonitorCallback):
+ """
+ 用于 early stop 的 callback 。当监控的结果连续多少次没有变好便 raise 一个 :class:`EarlyStopException` 。
+
+ :param monitor: 监控的 metric 值。
+
+ * 为 ``None``
+ 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。
+ * 为 ``str``
+ 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将
+ 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。
+ * 为 :class:`Callable`
+ 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关
+ 的 ``monitor`` 值请返回 ``None`` 。
+ :param larger_better: monitor 的值是否是越大越好。
+ :param patience: 多少次 evaluate 没有提升就停止。
+ """
+ def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10):
+ super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)
+ self.wait = 0
+ self.patience = patience
+
+ def on_evaluate_end(self, trainer, results):
+ monitor_value = self.get_monitor_value(results)
+ if monitor_value is None:
+ return
+ if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+ self.wait = 0
+ else:
+ self.wait += 1
+
+ def on_fetch_data_begin(self, trainer):
+ # 当是 step evaluate 的时候,下一步执行的就是这个, 所以在这里检查。
+ if self.wait >= self.patience:
+ raise EarlyStopException(f"After {self.wait} validations, no improvement for "
+ f"metric `{self._real_monitor}`(best value: {self.monitor_value})")
+
+ def on_train_epoch_begin(self, trainer):
+ # 当是 epoch evaluate 的时候,下一步执行的就是这个, 所以在这里检查。
+ if self.wait >= self.patience:
+ raise EarlyStopException(f"After {self.wait} validations, no improvement for "
+ f"metric `{self._real_monitor}`(best value: {self.monitor_value})")
+
+ def on_save_checkpoint(self, trainer) -> Dict:
+ states = {
+ 'patience': self.patience,
+ 'wait': self.wait,
+ 'monitor_value': self.monitor_value
+ }
+ if not callable(self._real_monitor):
+ states['_real_monitor'] = self._real_monitor
+ return states
+
+ def on_load_checkpoint(self, trainer, states):
+ self.patience = states['patience']
+ self.wait = states['wait']
+ self.monitor_value = float(states['monitor_value'])
+ if '_real_monitor' in states:
+ self._real_monitor = states['_real_monitor']
+
+ @property
+ def callback_name(self):
+ return f'EarlyStopCallback#monitor-{self.monitor_name}#patience-{self.patience}'
+
diff --git a/fastNLP/core/callbacks/fitlog_callback.py b/fastNLP/core/callbacks/fitlog_callback.py
new file mode 100644
index 00000000..44430b67
--- /dev/null
+++ b/fastNLP/core/callbacks/fitlog_callback.py
@@ -0,0 +1,83 @@
+__all__ = [
+ 'FitlogCallback'
+]
+import os
+
+from .has_monitor_callback import HasMonitorCallback
+from ...envs import _module_available
+from ...envs import get_global_rank
+from ..log import logger
+if _module_available('fitlog'):
+ import fitlog
+
+
+class FitlogCallback(HasMonitorCallback):
+ """
+ 自动记录 ``evaluation`` 结果到 ``fitlog`` 中。会自动记录每一次 ``evaluate`` 后的结果;同时会根据
+ ``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。同时还会在 ``fitlog`` 的
+ ``other`` 列中记录一个 ``launch_time`` ,可以通过这个数值找到当前这个脚本的在 save_folder (如果有使用其它需要保存模型的
+ ``Callback`` ,例如 :class:`~fastNLP.core.callbacks.CheckpointCallback` )下的文件夹名称。
+
+ :param monitor: 监控的 metric 值。
+
+ * 为 ``None``
+ 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。
+ * 为 ``str``
+ 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将
+ 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。
+ * 为 :class:`Callable`
+ 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关
+ 的 ``monitor`` 值请返回 ``None`` 。
+
+ :param larger_better: 是否是越大越好。
+ :param log_exception: 是否记录 ``exception`` 。
+ :param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。
+ """
+ def __init__(self, monitor=None, larger_better: bool = True, log_exception:bool=True, log_loss_every:int=0):
+ assert _module_available('fitlog'), "fitlog is not installed."
+
+ super().__init__(monitor=monitor, larger_better=larger_better)
+ self.log_exception = log_exception
+ self.log_loss_every = log_loss_every
+ self.avg_loss = 0
+ self.catch_exception = False
+
+ def on_after_trainer_initialized(self, trainer, driver):
+ if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog
+ fitlog.debug()
+ super().on_after_trainer_initialized(trainer, driver)
+ fitlog.add_other(name='launch_time', value=os.environ['FASTNLP_LAUNCH_TIME'])
+
+ def on_sanity_check_end(self, trainer, sanity_check_res):
+ super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res)
+ if self.monitor is None:
+ logger.rank_zero_warning(f"No monitor set for {self.log_name}. Therefore, no best metric will "
+ f"be logged.")
+
+ def on_evaluate_end(self, trainer, results):
+ results = self.itemize_results(results)
+ fitlog.add_metric(results, step=trainer.global_forward_batches, epoch=trainer.cur_epoch_idx)
+ if self.is_better_results(results, keep_if_better=True):
+ results['step'] = trainer.global_forward_batches
+ results['epoch'] = trainer.cur_epoch_idx
+ fitlog.add_best_metric(results)
+
+ def on_before_backward(self, trainer, outputs):
+ if self.log_loss_every > 0:
+ loss = trainer.extract_loss_from_outputs(outputs)
+ self.avg_loss += loss.item()
+ if trainer.global_forward_batches % self.log_loss_every == 0:
+ fitlog.add_loss(self.avg_loss / self.log_loss_every * trainer.accumulation_steps, name='loss',
+ step=trainer.global_forward_batches,
+ epoch=trainer.cur_epoch_idx)
+ self.avg_loss = 0
+
+ def on_train_end(self, trainer):
+ if not self.catch_exception:
+ fitlog.finish()
+
+ def on_exception(self, trainer, exception):
+ self.catch_exception = True
+ fitlog.finish(status=1)
+ if self.log_exception:
+ fitlog.add_other(repr(exception), name='except_info')
diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py
new file mode 100644
index 00000000..bb865bf8
--- /dev/null
+++ b/fastNLP/core/callbacks/has_monitor_callback.py
@@ -0,0 +1,255 @@
+__all__ = [
+ 'HasMonitorCallback',
+ 'ExecuteOnceBetterMonitor',
+ 'ResultsMonitor'
+]
+
+from typing import Dict, Union, Any
+from abc import ABC
+import functools
+
+from fastNLP.core.utils import apply_to_collection
+from fastNLP.core.callbacks import Callback
+from fastNLP.core.callbacks.utils import _get_monitor_value
+from fastNLP.core.log import logger
+from fastNLP.core.utils.utils import _check_valid_parameters_number
+
+
+class CanItemDataType(ABC):
+ @classmethod
+ def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
+ if cls is CanItemDataType:
+ item = getattr(subclass, 'item', None)
+ return callable(item)
+ return NotImplemented
+
+
+class ResultsMonitor:
+ """
+ 可用于监控某个数值,并通过 :meth:`is_better_results` 等接口检测结果是否变得更好。
+
+ :param monitor: 监控的 metric 值:
+
+ * 为 ``None``
+ 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置);
+ * 为 ``str``
+ 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将
+ 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ;
+ * 为 :class:`Callable`
+ 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关
+ 的 ``monitor`` 值请返回 ``None`` ;
+ :param larger_better: monitor 是否为越大越好;
+ """
+ def __init__(self, monitor:Union[Callback, str], larger_better:bool=True):
+ self.set_monitor(monitor, larger_better)
+ self._log_name = self.__class__.__name__
+
+ def set_monitor(self, monitor, larger_better):
+ if callable(monitor): # 检查是否能够接受一个参数
+ _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
+ self.monitor = monitor
+ else:
+ self.monitor = str(monitor) if monitor is not None else None
+ if self.monitor is not None:
+ self.larger_better = bool(larger_better)
+ if larger_better:
+ self.monitor_value = float('-inf')
+ else:
+ self.monitor_value = float('inf')
+ self._real_monitor = self.monitor
+
+ def itemize_results(self, results):
+ """
+ 执行结果中所有对象的 :meth:`item` 方法(如果没有则忽略),使得 Tensor 类型的数据转为 python 内置类型。
+
+ :param results:
+ :return:
+ """
+ return apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
+
+ def get_monitor_value(self, results:Dict)->Union[float, None]:
+ """
+ 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用 **最长公共字符串算法** 匹配的方式寻找。
+
+ :param results: 评测结果;
+ :return: monitor 的值;如果为 ``None`` ,表明此次没有找到合适的monitor;
+ """
+ if len(results) == 0 or self.monitor is None:
+ return None
+ # 保证所有的 tensor 都被转换为了 python 特定的类型
+ results = self.itemize_results(results)
+ use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
+ real_monitor=self._real_monitor,
+ res=results)
+ if monitor_value is None:
+ return monitor_value
+ # 第一次运行
+ if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
+ logger.rank_zero_warning(f"We can not find monitor:`{self.monitor}` for `{self.log_name}` in the "
+ f"evaluation result (with keys as "
+ f"{list(results.keys())}), we use the `{use_monitor}` as the monitor.", once=True)
+ # 检测到此次和上次不同。
+ elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
+ logger.rank_zero_warning(f"Change of monitor detected for `{self.log_name}`. "
+ f"The expected monitor is:`{self.monitor}`, last used monitor is:"
+ f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
+ f"customized monitor function when the evaluation results are varying between validation.")
+
+ self._real_monitor = use_monitor
+ return monitor_value
+
+ def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
+ """
+ 检测 ``monitor_value`` 是否是更好的
+
+ :param monitor_value: 待检查的 ``monitor_value`` 。如果为 ``None`` ,返回 False;
+ :param keep_if_better: 如果传入的 ``monitor_value`` 值更好,则将其保存下来;
+ :return:
+ """
+ if monitor_value is None:
+ return False
+ better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
+ if keep_if_better and better:
+ self.monitor_value = monitor_value
+ return better
+
+ def is_better_results(self, results, keep_if_better=True):
+ """
+ 检测给定的 ``results`` 是否比上一次更好,如果本次 results 中没有找到相关的 monitor 返回 ``False``。
+
+ :param results: evaluation 结果;
+ :param keep_if_better: 当返回为 ``True`` 时,是否保存到 ``self.monitor_value`` 中;
+ :return:
+ """
+ monitor_value = self.get_monitor_value(results)
+ if monitor_value is None:
+ return False
+ return self.is_better_monitor_value(monitor_value, keep_if_better=keep_if_better)
+
+ def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
+ """
+ 传入的两个值中,是否 ``monitor_value1`` 的结果更好。
+
+ :param monitor_value1:
+ :param monitor_value2:
+ :return:
+ """
+ if monitor_value1 is None and monitor_value2 is None:
+ return True
+ if monitor_value1 is None:
+ return False
+ if monitor_value2 is None:
+ return True
+ better = False
+ if (self.larger_better and monitor_value1 > monitor_value2) or \
+ (not self.larger_better and monitor_value1 < monitor_value2):
+ better = True
+ return better
+
+ @property
+ def monitor_name(self):
+ """
+ 返回 monitor 的名字,如果 monitor 是个 Callable 的函数,则返回该函数的名称。
+
+ :return:
+ """
+ if callable(self.monitor):
+ try:
+ monitor = self.monitor
+ while isinstance(monitor, functools.partial):
+ monitor = monitor.func
+ monitor_name = monitor.__qualname__
+ except:
+ monitor_name = self.monitor.__name__
+ elif self.monitor is None:
+ return None
+ else:
+ # 这里只能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了
+ monitor_name = str(self.monitor)
+ return monitor_name
+
+ @property
+ def log_name(self) -> str:
+ """
+ 内部用于打印当前类别信息使用
+
+ :return:
+ """
+ return self._log_name
+
+ @log_name.setter
+ def log_name(self, value):
+ self._log_name = value
+
+
+class HasMonitorCallback(ResultsMonitor, Callback):
+ """
+ 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该类,该类里面实现了
+ (1)判断 monitor 合法性;(2)在需要时, 根据 trainer 的 monitor 设置自己的 monitor 名称。
+
+ :param monitor: 监控的 metric 值:
+
+ * 为 ``None``
+ 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置);
+ * 为 ``str``
+ 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将
+ 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ;
+ * 为 :class:`Callable`
+ 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关
+ 的 ``monitor`` 值请返回 ``None`` ;
+ :param larger_better: monitor 是否为越大越好;
+ :param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 ``True`` ,且没检测到设置 monitor 会报错;
+ """
+ def __init__(self, monitor, larger_better, must_have_monitor=False):
+ super().__init__(monitor, larger_better)
+ self.must_have_monitor = must_have_monitor
+
+ def on_after_trainer_initialized(self, trainer, driver):
+ """
+ 如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。
+ 同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。
+
+ :param trainer:
+ :param driver:
+ :return:
+ """
+ if self.monitor is None and trainer.monitor is not None:
+ self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
+ if self.must_have_monitor and self.monitor is None:
+ raise RuntimeError(f"No `monitor` is set for {self.log_name}. "
+ f"You can set it in the initialization or through Trainer.")
+ if self.must_have_monitor and self.monitor is not None and trainer.evaluator is None:
+ raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.log_name}"
+ f" need to watch the monitor:`{self.monitor_name}`.")
+
+ def on_sanity_check_end(self, trainer, sanity_check_res):
+ # 主要核对一下 monitor 是否存在。
+ if self.monitor is not None:
+ self.get_monitor_value(results=sanity_check_res)
+
+
+class ExecuteOnceBetterMonitor(HasMonitorCallback):
+ """
+ 当监控的 ``monitor`` 结果更好的时候,调用 ``execute_fn`` 函数。
+
+ :param monitor: 监控的 metric 值:
+
+ * 为 ``None``
+ 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 ``monitor`` 值(如果有设置);
+ * 为 ``str``
+ 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将
+ 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` ;
+ * 为 :class:`Callable`
+ 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关
+ 的 ``monitor`` 值请返回 ``None`` ;
+ :param larger_better: monitor 是否是越大越好;
+ :param execute_fn: 一个可执行的函数,不接受任何参数,没有返回值。在 monitor 取得更好结果的时候会调用;
+ """
+ def __init__(self, monitor, larger_better, execute_fn):
+ super().__init__(monitor, larger_better, must_have_monitor=True)
+ _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
+ self.execute_fn = execute_fn
+
+ def on_evaluate_end(self, trainer, results):
+ if self.is_better_results(results):
+ self.execute_fn()
\ No newline at end of file
diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py
new file mode 100644
index 00000000..2bd41b5a
--- /dev/null
+++ b/fastNLP/core/callbacks/load_best_model_callback.py
@@ -0,0 +1,141 @@
+__all__ = [
+ 'LoadBestModelCallback'
+]
+
+import os
+from typing import Optional, Callable, Union
+from .has_monitor_callback import HasMonitorCallback
+from io import BytesIO
+import shutil
+
+from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH
+from fastNLP.core.log import logger
+from fastNLP.envs import all_rank_call_context
+from fastNLP.core.utils.exceptions import EarlyStopException
+
+
+class LoadBestModelCallback(HasMonitorCallback):
+    """
+    Save the model whose monitored metric is the best seen so far, and reload it
+    when training ends. By default the saved weights are deleted after training.
+    The best model can only be loaded when training ends normally.
+
+    :param monitor: the metric value to monitor:
+
+        * ``None``
+          try to use the ``monitor`` configured on the
+          :class:`~fastNLP.core.controllers.Trainer` (if any);
+        * ``str``
+          look the name up directly in the ``evaluation`` results; when no exact
+          match exists, the longest-common-substring algorithm picks the best
+          matching key as the ``monitor``;
+        * :class:`Callable`
+          a function taking the ``evaluation`` results (a dict) and returning a
+          ``float`` as the monitor value; return ``None`` when no relevant value
+          is present in the current results;
+    :param larger_better: whether a larger metric value is better;
+    :param save_folder: folder to save into. If ``None``, the weights are kept in
+        memory only; otherwise one copy of the weights is written to disk. For
+        multi-node training the path must be reachable from every node. Must not
+        be ``None`` when ``model_save_fn`` is provided;
+    :param only_state_dict: whether to save only the model parameters; has no
+        effect when ``model_save_fn`` is provided;
+    :param model_save_fn: custom model-saving function; must be provided together
+        with ``model_load_fn``. It receives an already-created folder and returns
+        nothing; perform the saving inside the function;
+    :param model_load_fn: custom model-loading function; must be provided together
+        with ``model_save_fn``. It receives an already-created folder and returns
+        nothing; perform the loading inside the function;
+    :param delete_after_train: whether to delete the saved weights once training
+        has finished;
+    """
+    def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True,
+                 save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None,
+                 model_load_fn:Optional[Callable] = None,
+                 delete_after_train:bool = True):
+        super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)
+        if model_load_fn is not None:
+            assert callable(model_load_fn), "`model_load_fn` must be a callable object."
+            assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time."
+        if model_save_fn is not None:
+            assert callable(model_save_fn), "`model_save_fn` must be a callable object."
+            assert model_load_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time."
+
+        if model_save_fn is not None:
+            assert save_folder is not None, "When passing `model_save_fn`, `save_folder` must be provided."
+
+        if save_folder:
+            if os.path.exists(save_folder):
+                assert os.path.isdir(save_folder), f"`save_folder={save_folder}` must be a directory."
+
+        self.save_folder = save_folder
+        self.only_state_dict = only_state_dict
+        self.model_save_fn = model_save_fn
+        self.model_load_fn = model_load_fn
+        # BUGFIX: attribute was previously misspelled `delete_after_after`.
+        self.delete_after_train = delete_after_train
+
+    def prepare_save_folder(self, trainer):
+        # Lazily create the save folder (or an in-memory buffer) on first use.
+        if not hasattr(self, 'real_save_folder'):
+            if self.save_folder is not None:
+                if not os.path.exists(self.save_folder):
+                    os.makedirs(self.save_folder, exist_ok=True)
+                self.save_folder = os.path.join(self.save_folder, os.environ.get(FASTNLP_LAUNCH_TIME))
+                self.real_save_folder = os.path.join(self.save_folder, 'best_so_far')
+                if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
+                    os.makedirs(self.real_save_folder, exist_ok=True)
+                if self.save_folder is not None and trainer.driver.is_distributed() and int(
+                        os.environ.get(FASTNLP_BACKEND_LAUNCH, 0)) == 1:
+                    trainer.driver.barrier()
+                    try:
+                        # Make every rank agree on the exact same folder path.
+                        self.real_save_folder = trainer.driver.broadcast_object(self.real_save_folder, src=0, group=None)
+                        logger.debug(
+                            f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.")
+                    except NotImplementedError:
+                        raise RuntimeError(
+                            f"Currently {trainer.driver.__class__.__name__} does not support using `save_folder` to "
+                            f"save best model when launch using module.")
+            else:  # keep the weights in an in-memory BytesIO buffer instead
+                self.real_save_folder = None
+                self.buffer = BytesIO()
+
+    def on_after_trainer_initialized(self, trainer, driver):
+        super().on_after_trainer_initialized(trainer, driver)
+        self.encounter_exception = False
+
+    def on_evaluate_end(self, trainer, results):
+        if self.is_better_results(results, keep_if_better=True):
+            self.prepare_save_folder(trainer)
+            if self.real_save_folder:
+                trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
+                                   model_save_fn=self.model_save_fn)
+            else:
+                self.buffer.seek(0)
+                with all_rank_call_context():
+                    trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)
+
+    def on_train_end(self, trainer):
+        if abs(self.monitor_value) != float('inf'):  # still inf means evaluation never produced a monitor value
+            # If distributed and an exception occurred, skip loading to avoid barrier deadlocks.
+            if not (trainer.driver.is_distributed() and self.encounter_exception):
+                if self.real_save_folder:
+                    logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
+                    trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
+                                       model_load_fn=self.model_load_fn)
+                else:
+                    logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
+                    self.buffer.seek(0)
+                    trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
+            if self.delete_after_train:
+                if not self.encounter_exception:  # guard against deadlock
+                    trainer.driver.barrier()
+                self._delete_folder()
+                if not self.encounter_exception:
+                    trainer.driver.barrier()
+
+    def on_exception(self, trainer, exception):
+        if not isinstance(exception, EarlyStopException):
+            self.encounter_exception = True
+
+    def _delete_folder(self):
+        if getattr(self, 'real_save_folder', None):
+            logger.info(f"Deleting {self.real_save_folder}...")
+            shutil.rmtree(self.real_save_folder, ignore_errors=True)
+            try:
+                # Only removed if the launch-time folder is empty.
+                os.rmdir(self.save_folder)
+                logger.debug(f"Since {self.save_folder} is an empty folder, it has been removed.")
+            except:
+                pass
+        elif hasattr(self, 'buffer'):
+            self.buffer.close()
+            del self.buffer
\ No newline at end of file
diff --git a/fastNLP/core/callbacks/lr_scheduler_callback.py b/fastNLP/core/callbacks/lr_scheduler_callback.py
new file mode 100644
index 00000000..3d3f4a0f
--- /dev/null
+++ b/fastNLP/core/callbacks/lr_scheduler_callback.py
@@ -0,0 +1,28 @@
+from .callback import Callback
+
+__all__ = [
+ 'LRSchedCallback'
+]
+
+
+class LRSchedCallback(Callback):
+    """
+    Call the scheduler's :meth:`step` at the time indicated by ``step_on``.
+
+    :param scheduler: an object implementing a callable :meth:`step` method;
+    :param step_on: one of ``['batch', 'epoch']``; with ``batch`` the scheduler
+        steps right after each optimizer step, with ``epoch`` it steps once at
+        the end of every epoch;
+    """
+    def __init__(self, scheduler, step_on:str='batch'):
+        assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \
+                                                                        "step function."
+        self.scheduler = scheduler
+        # 0 -> step per batch, 1 -> step per epoch.
+        # NOTE(review): any value other than 'batch' silently falls back to
+        # per-epoch stepping; `step_on` is not validated — confirm intended.
+        self.step_on = 0 if step_on == 'batch' else 1
+
+    def on_after_optimizers_step(self, trainer, optimizers):
+        if self.step_on == 0:
+            self.scheduler.step()
+
+    def on_train_epoch_end(self, trainer):
+        if self.step_on == 1:
+            self.scheduler.step()
\ No newline at end of file
diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py
new file mode 100644
index 00000000..26de0a04
--- /dev/null
+++ b/fastNLP/core/callbacks/more_evaluate_callback.py
@@ -0,0 +1,200 @@
+__all__ = [
+ 'MoreEvaluateCallback'
+]
+
+import os
+from typing import Union, Callable, Optional, Dict
+
+from fastNLP.core.log import logger
+from .has_monitor_callback import HasMonitorCallback
+from .topk_saver import TopkSaver
+
+
+class MoreEvaluateCallback(HasMonitorCallback):
+    """
+    Run an extra, independently configured evaluation during training. Useful
+    when evaluation needs a different ``evaluate_fn`` than the Trainer's built-in
+    Evaluator (e.g. cheap loss-based evaluation most of the time, full generation
+    plus BLEU occasionally). Evaluation is triggered either by ``evaluate_every``
+    or by an improvement of ``watch_monitor`` in the Trainer's own evaluate
+    results. Pass ``topk`` and ``topk_monitor`` to also save models/checkpoints
+    based on this callback's evaluation results.
+
+    When topk saving is enabled, files are organized as::
+
+        - folder/
+            - YYYY-mm-dd-HH_MM_SS_fffff/  # created from the launch time of the current script
+                - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/  # name used when topk condition is met
+
+    :param dataloaders: the data to evaluate on.
+    :param metrics: the metrics to use.
+    :param evaluate_every: how often this callback's Evaluator runs; may be a
+        negative int, a positive int, or a function:
+
+        1. negative: evaluate every ``|n|`` epochs;
+        2. positive: evaluate every ``n`` batches;
+        3. a function taking the trainer and returning a bool, called at the end
+           of every batch; ``True`` means evaluate now, e.g.:
+
+           >>> def my_evaluate_every(trainer) -> bool:
+           ...     if (trainer.global_forward_batches+1) % 1000 == 0:
+           ...         return True
+           ...     else:
+           ...         return False
+
+    :param watch_monitor: when not ``None``, ``evaluate_every`` is ignored and
+        this callback evaluates whenever the watched monitor of the Trainer's own
+        evaluate results improves. Either:
+
+        1. a ``str`` metric name — if no exact key matches in the evaluation
+           results, the longest-common-substring algorithm picks the best match;
+        2. a function taking the evaluation results (dict) and returning a float
+           as the monitor value; return ``None`` when not available.
+    :param watch_monitor_larger_better: whether a larger ``watch_monitor`` is better.
+    :param evaluate_fn: which model method the Evaluator calls for the forward
+        pass: ``None`` means :meth:`model.evaluate_step` falling back to
+        :meth:`model.forward`; a ``str`` names a method looked up on the model
+        (an error is raised when missing).
+    :param num_eval_sanity_batch: number of sanity-check batches run right after
+        the Evaluator is created.
+    :param topk: ``0`` (default) = never save, ``-1`` = save after every
+        evaluate, positive = keep the best ``topk``.
+    :param topk_monitor: the monitor (within this callback's own evaluate
+        results) used for topk saving.
+    :param topk_larger_better: whether a larger ``topk_monitor`` is better.
+    :param folder: base folder for saving; a launch-time subfolder is created
+        inside it, so different runs save to different folders. ``None`` means
+        the current folder.
+    :param only_state_dict: save only the state dict; ignored when
+        ``model_save_fn`` is provided.
+    :param save_object: one of ``['trainer', 'model']``; a saved trainer can be
+        restored with :meth:`Trainer.load_checkpoint` to resume training, a
+        saved model with :meth:`Trainer.load_model`.
+    :param model_save_fn: custom save function receiving a folder and returning
+        nothing; when given, fastNLP performs no model saving itself. In
+        multi-card settings it only runs on rank 0.
+    :param save_evaluate_results: also write a ``fastnlp_evaluate_results.json``
+        with the current results into each topk folder; only effective when
+        ``topk`` is set. Defaults to True.
+    :param save_kwargs: dict of extra saving-related arguments.
+    :param kwargs: extra :class:`~fastNLP.core.controllers.Evaluator` init
+        arguments; values not given here are taken from the
+        :class:`~fastNLP.core.controllers.Trainer`.
+    """
+    def __init__(self, dataloaders, metrics:Dict, evaluate_every:Optional[Union[int, Callable]]=-1,
+                 watch_monitor:Union[str, Callable]=None, watch_monitor_larger_better:bool=True,
+                 evaluate_fn=None, num_eval_sanity_batch=2,
+                 topk=0, topk_monitor=None, topk_larger_better=True,
+                 folder=None, only_state_dict=True, save_object='model', model_save_fn=None,
+                 save_evaluate_results=True, save_kwargs=None,
+                 **kwargs):
+        super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better,
+                                                   must_have_monitor=False)
+        if watch_monitor is not None and evaluate_every == -1:  # drop the default evaluate_every
+            evaluate_every = None
+        if watch_monitor is None and evaluate_every is None:
+            raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
+        if watch_monitor is not None and evaluate_every is not None:
+            raise RuntimeError(f"`evaluate_every`({evaluate_every}) and `watch_monitor`({watch_monitor}) "
+                               f"cannot be set at the same time.")
+
+        if topk_monitor is not None and topk == 0:
+            raise RuntimeError("`topk_monitor` is set, but `topk` is 0.")
+        if topk != 0 and topk_monitor is None:
+            raise RuntimeError("`topk` is set, but `topk_monitor` is None.")
+        assert save_object in ['trainer', 'model']
+
+        self.dataloaders = dataloaders
+        self.metrics = metrics
+        self.evaluate_every = evaluate_every
+        self.evaluate_fn = evaluate_fn
+        self.num_eval_sanity_batch = num_eval_sanity_batch
+        if save_kwargs is None:
+            save_kwargs = {}
+        self.topk_saver = TopkSaver(topk=topk, monitor=topk_monitor, larger_better=topk_larger_better,
+                                    folder=folder, only_state_dict=only_state_dict,
+                                    model_save_fn=model_save_fn, save_evaluate_results=save_evaluate_results,
+                                    save_object=save_object, **save_kwargs)
+        self.kwargs = kwargs
+
+    @property
+    def need_reproducible_sampler(self) -> bool:
+        # Saving a full trainer checkpoint requires a reproducible sampler.
+        return self.topk_saver.save_object == 'trainer'
+
+    def on_after_trainer_initialized(self, trainer, driver):
+        # Watching a monitor requires the Trainer to have its own evaluator.
+        if self.monitor is not None:
+            assert trainer.evaluator is not None, f"You set `watch_monitor={self.monitor}`, but no " \
+                                                  f"evaluate_dataloaders is provided in Trainer."
+
+        # Build our own Evaluator; intentionally avoids calling super(), which would assign the monitor.
+        kwargs = {
+            'model': self.kwargs.get('model', trainer.model),
+            'dataloaders': self.dataloaders,
+            'metrics': self.metrics,
+            'driver': self.kwargs.get('driver', trainer.driver),
+            'device': self.kwargs.get('device', trainer.device),
+            'evaluate_batch_step_fn': self.kwargs.get('evaluate_batch_step_fn', trainer.evaluate_batch_step_fn),
+            'evaluate_fn': self.evaluate_fn,
+            'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping),
+            'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping),
+            'fp16': self.kwargs.get('fp16', trainer.fp16),
+            'use_dist_sampler': self.kwargs.get('use_dist_sampler',
+                                                trainer.kwargs.get('eval_use_dist_sampler', None)),
+            'progress_bar': self.kwargs.get('progress_bar', trainer.kwargs.get('progress_bar', 'auto')),
+            'verbose': self.kwargs.get('verbose', 1)
+        }
+
+        for key, value in self.kwargs.items():
+            if key not in kwargs:
+                kwargs[key] = value
+        from fastNLP.core.controllers.evaluator import Evaluator
+        self.evaluator = Evaluator(**kwargs)
+        if self.num_eval_sanity_batch>0:
+            results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch)
+            self.topk_saver.get_monitor_value(results)
+
+    def on_evaluate_end(self, trainer, results):
+        # Watch-mode trigger: run our Evaluator whenever the Trainer's own results improve.
+        if self.is_better_results(results, keep_if_better=True):
+            results = self.evaluator.run()
+            self.topk_saver.save_topk(trainer, results)
+
+    def on_train_epoch_end(self, trainer):
+        if self.monitor is not None:
+            return
+        if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
+            evaluate_every = -self.evaluate_every
+            if trainer.cur_epoch_idx % evaluate_every == 0:
+                results = self.evaluator.run()
+                self.topk_saver.save_topk(trainer, results)
+
+    def on_train_batch_end(self, trainer):
+        if self.monitor is not None:
+            return
+        if callable(self.evaluate_every):
+            if self.evaluate_every(trainer):
+                results = self.evaluator.run()
+                self.topk_saver.save_topk(trainer, results)
+        elif self.evaluate_every > 0 and trainer.global_forward_batches % self.evaluate_every == 0:
+            results = self.evaluator.run()
+            self.topk_saver.save_topk(trainer, results)
+
+    def on_save_checkpoint(self, trainer) -> Dict:
+        states = {'topk_saver': self.topk_saver.state_dict()}
+        if isinstance(self._real_monitor, str):
+            states['_real_monitor'] = self._real_monitor
+            states['monitor_value'] = self.monitor_value
+        return states
+
+    def on_load_checkpoint(self, trainer, states: Optional[Dict]):
+        topk_saver_states = states['topk_saver']
+        self.topk_saver.load_state_dict(topk_saver_states)
+        if '_real_monitor' in states:
+            self._real_monitor = states["_real_monitor"]
+            self.monitor_value = states['monitor_value']
+
+    @property
+    def callback_name(self):
+        metric_names = '+'.join(sorted(self.metrics.keys()))
+        return f'more_evaluate_callback#metric_name-{metric_names}#monitor-{self.monitor_name}#topk_saver:{self.topk_saver}'
+
diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py
new file mode 100644
index 00000000..eda0f564
--- /dev/null
+++ b/fastNLP/core/callbacks/progress_callback.py
@@ -0,0 +1,337 @@
+import json
+from typing import Union
+
+__all__ = [
+ 'choose_progress_callback',
+ 'ProgressCallback',
+ 'RichCallback',
+ 'TqdmCallback',
+ 'RawTextCallback'
+]
+
+
+from .has_monitor_callback import HasMonitorCallback
+from fastNLP.core.utils import f_rich_progress, f_tqdm_progress
+from fastNLP.core.log import logger
+
+
+class ProgressCallback(HasMonitorCallback):
+    """Base class for progress-bar callbacks; records where the best monitored result occurred."""
+    def __init__(self, monitor, larger_better, must_have_monitor=False):
+        super(ProgressCallback, self).__init__(monitor=monitor, larger_better=larger_better,
+                                               must_have_monitor=must_have_monitor)
+        # -1 means no better result has been recorded yet.
+        self.best_monitor_epoch = -1
+        self.best_monitor_step = -1
+        self.best_results = None
+
+    def record_better_monitor(self, trainer, results):
+        # Remember at which epoch/global batch the best monitored result was achieved.
+        self.best_monitor_step = trainer.global_forward_batches
+        self.best_monitor_epoch = trainer.cur_epoch_idx
+        self.best_results = self.itemize_results(results)
+
+    def on_train_end(self, trainer):
+        if self.best_monitor_epoch != -1:
+            msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \
+                  f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}."
+            if self.best_results is not None:
+                msg = msg + ' The evaluation result: \n' + str(self.best_results)
+            logger.info(msg)
+
+    @property
+    def name(self):  # name of this progress bar
+        return 'auto'
+
+
+def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Union[ProgressCallback, None]:
+    """
+    Map a progress-bar specifier to a :class:`ProgressCallback` instance.
+
+    :param progress_bar: one of ``'auto'``/``'rich'``/``'raw'``/``'tqdm'``, or an
+        existing :class:`ProgressCallback` instance (returned unchanged).
+    :return: the chosen callback, or ``None`` for any unrecognized value.
+    """
+    if progress_bar == 'auto':
+        # Prefer rich when it is usable in this environment, else plain text output.
+        if not f_rich_progress.dummy:
+            progress_bar = 'rich'
+        else:
+            progress_bar = 'raw'
+    if progress_bar == 'rich':
+        return RichCallback()
+    elif progress_bar == 'raw':
+        return RawTextCallback()
+    elif progress_bar == 'tqdm':
+        return TqdmCallback()
+    elif isinstance(progress_bar, ProgressCallback):
+        return progress_bar
+    else:
+        return None
+
+
+class RichCallback(ProgressCallback):
+    """
+    Display a *rich* progress bar during training; the Trainer uses this callback
+    by default to show progress. To customize its parameters, instantiate this
+    Callback yourself and pass it to the Trainer. When printing evaluate results,
+    keys whose name starts with ``_`` are omitted.
+
+    :param print_every: update the display every this many batches.
+    :param loss_round_ndigit: number of digits kept when displaying the loss.
+    :param monitor: the metric to watch; when this key's result improves, it is
+        highlighted in a different color:
+
+        * ``None``
+          try to use the ``monitor`` configured on the
+          :class:`~fastNLP.core.controllers.Trainer` (if any).
+        * ``str``
+          look the name up directly in the ``evaluation`` results; when no exact
+          match exists, the longest-common-substring algorithm picks the best
+          matching key as the ``monitor``.
+        * :class:`Callable`
+          a function taking the ``evaluation`` results (a dict) and returning a
+          ``float`` as the monitor value; return ``None`` when no relevant value
+          is present in the current results.
+
+    :param larger_better: whether a larger monitor value is better.
+    :param format_json: whether to format results as JSON before printing.
+    """
+    def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
+                 format_json=True):
+        super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
+        self.print_every = print_every
+        self.progress_bar = f_rich_progress
+        self.task2id = {}
+        self.loss = 0
+        self.loss_round_ndigit = loss_round_ndigit
+        self.format_json = format_json
+
+    def on_after_trainer_initialized(self, trainer, driver):
+        # Only the local rank-0 process renders the bar.
+        if not self.progress_bar.disable:
+            self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0)
+        super(RichCallback, self).on_after_trainer_initialized(trainer, driver)
+
+    def on_train_begin(self, trainer):
+        self.task2id['epoch'] = self.progress_bar.add_task(description=f'Epoch:{trainer.cur_epoch_idx}',
+                                                           total=trainer.n_epochs,
+                                                           completed=trainer.global_forward_batches/(trainer.n_batches+1e-6)*
+                                                                     trainer.n_epochs)
+
+    def on_train_epoch_begin(self, trainer):
+        self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)
+        if 'batch' in self.task2id:
+            self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch)
+        else:
+            self.task2id['batch'] = self.progress_bar.add_task(description=f'Batch:{trainer.batch_idx_in_epoch}',
+                                                               total=trainer.num_batches_per_epoch,
+                                                               completed=trainer.batch_idx_in_epoch)
+
+    def on_train_epoch_end(self, trainer):
+        self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}',
+                                 advance=None, completed=trainer.cur_epoch_idx, refresh=True)
+
+    def on_train_end(self, trainer):
+        super(RichCallback, self).on_train_end(trainer)
+        self.clear_tasks()
+
+    def on_before_backward(self, trainer, outputs):
+        # Accumulate the loss so it can be averaged over `print_every` batches.
+        loss = trainer.extract_loss_from_outputs(outputs)
+        loss = trainer.driver.tensor_to_numeric(loss, reduce='sum')
+        self.loss += loss
+
+    def on_train_batch_end(self, trainer):
+        if trainer.global_forward_batches % self.print_every == 0:
+            loss = self.loss/self.print_every
+            self.loss = 0
+            self.progress_bar.update(self.task2id['batch'], description=f'Batch:{trainer.batch_idx_in_epoch}',
+                                     advance=self.print_every,
+                                     post_desc=f'Loss:{round(loss, self.loss_round_ndigit)}', refresh=True)
+            self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}',
+                                     advance=self.epoch_bar_update_advance, refresh=True)
+
+    def on_evaluate_end(self, trainer, results):
+        if len(results)==0:
+            return
+        rule_style = ''
+        text_style = ''
+        characters = '-'
+        if self.monitor is not None:
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer, results)
+                if abs(self.monitor_value) != float('inf'):
+                    # Highlight the separator rule when a better monitor was found.
+                    rule_style = 'spring_green3'
+                    text_style = '[bold]'
+                    characters = '+'
+        self.progress_bar.print()
+        self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, "
+                                       f"Batch:{trainer.batch_idx_in_epoch}",
+                                       style=rule_style, characters=characters)
+        results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
+                   not key.startswith('_')}
+        if self.format_json:
+            results = json.dumps(results)
+            self.progress_bar.console.print_json(results)
+        else:
+            self.progress_bar.print(results)
+
+    def clear_tasks(self):
+        for key, taskid in self.task2id.items():
+            self.progress_bar.destroy_task(taskid)
+        self.progress_bar.stop()
+        self.task2id = {}
+        self.loss = 0
+
+    @property
+    def name(self):  # name of this progress bar
+        return 'rich'
+
+
+class RawTextCallback(ProgressCallback):
+    """
+    Report progress by printing plain text lines to the console. When printing
+    evaluate results, keys whose name starts with ``_`` are omitted.
+
+    :param print_every: update the display every this many batches.
+    :param loss_round_ndigit: number of digits kept when displaying the loss.
+    :param monitor: the metric to watch; improvements are marked in the output:
+
+        * ``None``
+          try to use the ``monitor`` configured on the
+          :class:`~fastNLP.core.controllers.Trainer` (if any).
+        * ``str``
+          look the name up directly in the ``evaluation`` results; when no exact
+          match exists, the longest-common-substring algorithm picks the best
+          matching key as the ``monitor``.
+        * :class:`Callable`
+          a function taking the ``evaluation`` results (a dict) and returning a
+          ``float`` as the monitor value; return ``None`` when no relevant value
+          is present in the current results.
+    :param larger_better: whether a larger monitor value is better.
+    :param format_json: whether to format results as JSON before printing.
+    """
+    def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
+                 format_json=True):
+        super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
+        self.print_every = print_every
+        self.task2id = {}
+        self.loss = 0
+        self.loss_round_ndigit = loss_round_ndigit
+        # NOTE(review): super().__init__ already configures the monitor; this
+        # extra call looks redundant — confirm set_monitor is idempotent.
+        self.set_monitor(monitor, larger_better)
+        self.format_json = format_json
+        self.num_signs = 10
+
+    def on_train_epoch_begin(self, trainer):
+        logger.info('\n' + "*"*self.num_signs + f'Epoch:{trainer.cur_epoch_idx} starts' + '*'*self.num_signs)
+
+    def on_before_backward(self, trainer, outputs):
+        # Accumulate the loss so it can be averaged over `print_every` batches.
+        loss = trainer.extract_loss_from_outputs(outputs)
+        loss = trainer.driver.tensor_to_numeric(loss, reduce='sum')
+        self.loss += loss
+
+    def on_train_batch_end(self, trainer):
+        if trainer.global_forward_batches % self.print_every == 0:
+            loss = self.loss/self.print_every
+            self.loss = 0
+            text = f'Epoch:{trainer.cur_epoch_idx}/{trainer.n_epochs}, Batch:{trainer.batch_idx_in_epoch}, ' \
+                   f'loss:{round(loss, self.loss_round_ndigit)}, ' \
+                   f'finished {round(trainer.global_forward_batches/trainer.n_batches*100, 2)}%.'
+            logger.info(text)
+
+    def on_evaluate_end(self, trainer, results):
+        if len(results)==0:
+            return
+        base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
+        text = ''
+        if self.monitor is not None:
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer, results)
+                if abs(self.monitor_value) != float('inf'):
+                    # '+' banner marks an improved monitor; '-' banner is the default.
+                    text = '+'*self.num_signs + base_text + '+'*self.num_signs
+        if len(text) == 0:
+            text = '-'*self.num_signs + base_text + '-'*self.num_signs
+
+        logger.info(text)
+        results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
+                   not key.startswith('_')}
+        if self.format_json:
+            results = json.dumps(results)
+        logger.info(results)
+
+    @property
+    def name(self):  # name of this progress bar
+        return 'raw'
+
+
+class TqdmCallback(ProgressCallback):
+    """
+    Display a *tqdm* progress bar during training; used when the Trainer is
+    configured with ``progress_bar='tqdm'``. To customize its parameters,
+    instantiate this Callback yourself and pass it to the Trainer. When printing
+    evaluate results, keys whose name starts with ``_`` are omitted.
+
+    :param print_every: update the display every this many batches.
+    :param loss_round_ndigit: number of digits kept when displaying the loss.
+    :param monitor: the metric to watch; improvements are marked in the output:
+
+        * ``None``
+          try to use the ``monitor`` configured on the
+          :class:`~fastNLP.core.controllers.Trainer` (if any).
+        * ``str``
+          look the name up directly in the ``evaluation`` results; when no exact
+          match exists, the longest-common-substring algorithm picks the best
+          matching key as the ``monitor``.
+        * :class:`Callable`
+          a function taking the ``evaluation`` results (a dict) and returning a
+          ``float`` as the monitor value; return ``None`` when no relevant value
+          is present in the current results.
+    :param larger_better: whether a larger monitor value is better.
+    :param format_json: whether to format results as JSON before printing.
+    """
+    def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
+                 format_json=True):
+        super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
+        self.print_every = print_every
+        self.progress_bar = f_tqdm_progress
+        self.task2id = {}
+        self.loss = 0
+        self.loss_round_ndigit = loss_round_ndigit
+        self.format_json = format_json
+        self.num_signs = 10
+
+    def on_train_begin(self, trainer):
+        self.task2id['epoch'] = self.progress_bar.add_task(description=f'Epoch:{trainer.cur_epoch_idx}',
+                                                           total=trainer.n_epochs,
+                                                           bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]',
+                                                           initial=trainer.global_forward_batches/(trainer.n_batches+1e-6)*trainer.n_epochs)
+
+    def on_train_epoch_begin(self, trainer):
+        self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)
+        if 'batch' in self.task2id:
+            self.progress_bar.reset(self.task2id['batch'])
+        else:
+            self.task2id['batch'] = self.progress_bar.add_task(description='Batch', total=trainer.num_batches_per_epoch,
+                                                               initial=trainer.batch_idx_in_epoch)
+        self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True)
+
+    def on_train_end(self, trainer):
+        super(TqdmCallback, self).on_train_end(trainer)
+        self.clear_tasks()
+
+    def on_before_backward(self, trainer, outputs):
+        # Accumulate the loss so it can be averaged over `print_every` batches.
+        loss = trainer.extract_loss_from_outputs(outputs)
+        loss = trainer.driver.tensor_to_numeric(loss, reduce='sum')
+        self.loss += loss
+
+    def on_train_batch_end(self, trainer):
+        if trainer.global_forward_batches % self.print_every == 0:
+            loss = self.loss/self.print_every
+            self.loss = 0
+            self.progress_bar.update(self.task2id['batch'], advance=self.print_every, refresh=True)
+            self.progress_bar.set_postfix_str(self.task2id['batch'], f'Loss:{round(loss, self.loss_round_ndigit)}')
+            self.progress_bar.update(self.task2id['epoch'], advance=self.epoch_bar_update_advance, refresh=True)
+
+    def on_evaluate_end(self, trainer, results):
+        if len(results) == 0:
+            return
+        base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
+        text = ''
+        if self.monitor is not None:
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer, results)
+                if abs(self.monitor_value) != float('inf'):
+                    # '+' banner marks an improved monitor; '-' banner is the default.
+                    text = '+'*self.num_signs + base_text + '+'*self.num_signs
+        if len(text) == 0:
+            text = '-'*self.num_signs + base_text + '-'*self.num_signs
+
+        logger.info(text)
+        results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
+                   not key.startswith('_')}
+        if self.format_json:
+            results = json.dumps(results)
+        logger.info(results)
+
+    def clear_tasks(self):
+        for key, taskid in self.task2id.items():
+            self.progress_bar.destroy_task(taskid)
+        self.task2id = {}
+        self.loss = 0
+
+    @property
+    def name(self):  # name of this progress bar
+        return 'tqdm'
diff --git a/fastNLP/core/callbacks/timer_callback.py b/fastNLP/core/callbacks/timer_callback.py
new file mode 100644
index 00000000..27dbe538
--- /dev/null
+++ b/fastNLP/core/callbacks/timer_callback.py
@@ -0,0 +1,152 @@
+import time
+from .callback import Callback
+from ..log import logger
+__all__ = ['TimerCallback']
+
+
+class _Timer:
+    """A simple cumulative stopwatch: start/stop pairs accumulate into `elapsed_`."""
+
+    def __init__(self, name):
+        self.name_ = name
+        self.elapsed_ = 0.0       # accumulated elapsed seconds across start/stop cycles
+        self.started_ = False     # whether the timer is currently running
+        self.start_time = time.time()
+
+    def start(self):
+        """Start the timer."""
+        assert not self.started_, f'{self.name_} timer has already been started'
+        self.start_time = time.time()
+        self.started_ = True
+
+    def stop(self):
+        """Stop the timer, adding the current interval to the accumulated total."""
+        assert self.started_, f'{self.name_} timer is not started'
+        self.elapsed_ += (time.time() - self.start_time)
+        self.started_ = False
+
+    def reset(self):
+        """Reset the accumulated time and the running flag."""
+        self.elapsed_ = 0.0
+        self.started_ = False
+
+    def elapsed(self, reset=True):
+        """Return the accumulated elapsed time, optionally resetting the counter."""
+        started_ = self.started_
+        # If timing is in progress, end it first so the current interval is counted.
+        if self.started_:
+            self.stop()
+        # Get the elapsed time.
+        elapsed_ = self.elapsed_
+        # Reset the elapsed time.
+        if reset:
+            self.reset()
+        # If timing was in progress, restart it.
+        if started_:
+            self.start()
+        return elapsed_
+
+
+class Timers:
+    """Group of named timers, created lazily on first access."""
+
+    def __init__(self):
+        self.timers = {}
+
+    def __call__(self, name):
+        # Get-or-create the timer with the given name.
+        if name not in self.timers:
+            self.timers[name] = _Timer(name)
+        return self.timers[name]
+
+    def __contains__(self, item):
+        return item in self.timers
+
+    def reset(self):
+        for timer in self.timers.values():
+            timer.reset()
+
+
+class TimerCallback(Callback):
+    """
+    Print timing information during training, e.g. training time, evaluation
+    time and total time.
+    """
+    def __init__(self, print_every=-1, time_ndigit=3):
+        """
+        :param print_every: when to print timing information.
+
+            * *negative*: print every ``|print_every|`` epochs;
+            * *0*: print only once, when training ends;
+            * *positive*: print every ``print_every`` steps;
+
+        :param time_ndigit: number of decimal digits to keep
+        """
+        assert isinstance(print_every, int), "print_every must be an int number."
+        self.timers = Timers()
+        self.print_every = print_every
+        self.time_ndigit = time_ndigit
+
+    def on_train_begin(self, trainer):
+        self.timers('total').start()
+        self.timers('train').start()
+
+    def on_fetch_data_begin(self, trainer):
+        self.timers('fetch-data').start()
+
+    def on_fetch_data_end(self, trainer):
+        self.timers('fetch-data').stop()
+
+    def on_train_batch_begin(self, trainer, batch, indices):
+        self.timers('forward').start()
+
+    def on_before_backward(self, trainer, outputs):
+        self.timers('forward').stop()
+        self.timers('backward').start()
+
+    def on_after_backward(self, trainer):
+        self.timers('backward').stop()
+
+    def on_before_optimizers_step(self, trainer, optimizers):
+        self.timers('optimize').start()
+
+    def on_after_optimizers_step(self, trainer, optimizers):
+        self.timers('optimize').stop()
+
+    def on_evaluate_begin(self, trainer):
+        # Pause the train timer while evaluation runs.
+        self.timers('train').stop()
+        self.timers('evaluate').start()
+
+    def on_evaluate_end(self, trainer, results):
+        self.timers('evaluate').stop()
+        self.timers('train').start()
+
+    def format_timer(self, reset=True):
+        """Format all non-zero timers into one comma-separated string; resets them by default."""
+        line = ''
+        timers = ['fetch-data', 'forward', 'backward', 'optimize', 'evaluate', 'train', 'total']
+        for timer_name in timers:
+            if not timer_name in self.timers:
+                continue
+            timer = self.timers(timer_name)
+            elapsed = round(timer.elapsed(reset=reset), self.time_ndigit)
+            if elapsed != 0:
+                line = line + f', {timer_name}: {elapsed}s'
+        return line
+
+    def on_train_batch_end(self, trainer):
+        if self.print_every>0 and trainer.global_forward_batches % self.print_every == 0:
+            line = self.format_timer()
+            logger.info(f"Running {self.print_every} batches{line}")
+
+    def on_train_epoch_end(self, trainer):
+        if self.print_every < 0 and trainer.cur_epoch_idx % abs(self.print_every) == 0:
+            line = self.format_timer()
+            logger.info(f"Running {abs(self.print_every)} epochs{line}")
+
+    def on_train_end(self, trainer):
+        if self.print_every == 0:
+            line = self.format_timer()
+            logger.info(f"Training finished{line}")
+
+
+
diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py
new file mode 100644
index 00000000..21a8961f
--- /dev/null
+++ b/fastNLP/core/callbacks/topk_saver.py
@@ -0,0 +1,279 @@
+__all__ = [
+ 'TopkSaver'
+]
+import json
+import os
+from copy import deepcopy
+from pathlib import Path
+from typing import Optional, Dict, Tuple, Callable, Union
+
+from ...envs.distributed import rank_zero_rm
+from fastNLP.core.log import logger
+from fastNLP.envs import FASTNLP_LAUNCH_TIME
+from fastNLP.envs import rank_zero_call
+from fastNLP.envs.env import FASTNLP_EVALUATE_RESULT_FILENAME
+from .has_monitor_callback import ResultsMonitor
+
+
class Saver:
    """
    Object that performs the actual saving. The on-disk layout is::

        - folder                          # given at initialization
            - YYYY-mm-dd-HH_MM_SS_fffff/  # derived from the launch time of the current script
                - folder_name             # given when save() is called

    :param folder: directory to save into; defaults to the current working directory.
    :param save_object: either ``'trainer'`` or ``'model'``. With ``'trainer'`` the state of
        :class:`~fastNLP.core.controllers.Trainer` is saved as well and training can be resumed
        via :meth:`Trainer.load_checkpoint`; with ``'model'`` only the model is saved and can be
        loaded via :meth:`Trainer.load_model`.
    :param only_state_dict: whether to save only the weights; meaningless when
        ``model_save_fn`` is not ``None``.
    :param model_save_fn: custom save callable invoked with the target folder when a save is
        triggered; it should return nothing. When given, fastNLP performs no model saving of
        its own. In multi-card settings it only runs on rank 0.
    :param kwargs: extra arguments forwarded to :meth:`Trainer.save_checkpoint` or
        :meth:`Trainer.save_model`.
    """
    def __init__(self, folder:str=None, save_object:str='model', only_state_dict:bool=True,
                 model_save_fn:Callable=None, **kwargs):
        if folder is None:
            folder = Path.cwd().absolute()
        folder = Path(folder)
        if not folder.exists():
            folder.mkdir(parents=True, exist_ok=True)
        elif folder.is_file():
            raise ValueError("Parameter `folder` should be a directory instead of a file.")

        self.folder = folder
        self.only_state_dict = only_state_dict
        self.model_save_fn = model_save_fn
        self.kwargs = kwargs
        self.save_object = save_object
        # Dispatch by name to the matching Trainer method when save() is called.
        self.save_fn_name = 'save_checkpoint' if save_object == 'trainer' else 'save_model'

        self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
        # Log the timestamped folder: it is generated at runtime and easy to confuse.
        logger.info(f"The checkpoint will be saved in this folder for this time: {self.timestamp_path}.")

    def save(self, trainer, folder_name):
        """
        Perform the save. Data ends up in::

            - folder/
                - YYYY-mm-dd-HH_MM_SS_fffff/  # derived from the launch time of the script
                    - folder_name             # this function's argument

        :param trainer: Trainer object providing the save method.
        :param folder_name: name of the folder to create for this save.
        :return: absolute path of the folder actually written.
        """
        folder = self.timestamp_path.joinpath(folder_name)
        folder.mkdir(parents=True, exist_ok=True)

        save_fn = getattr(trainer, self.save_fn_name)
        save_fn(
            folder=folder,
            only_state_dict=self.only_state_dict,
            model_save_fn=self.model_save_fn,
            **self.kwargs
        )
        return str(os.path.abspath(folder))

    @rank_zero_call
    def save_json(self, results, path):
        """
        Dump ``results`` (typically evaluation metrics) to ``path`` as JSON.

        :param results: data to serialize.
        :param path: destination file name.
        :return:
        """
        with open(path, 'w', encoding='utf8') as f:
            json.dump(results, f, indent=2)

    @rank_zero_call
    def rm(self, folder_name):
        """
        Remove ``folder/timestamp/folder_name``, where ``folder`` was given at
        initialization and ``timestamp`` is the launch time of the current script.

        :param folder_name: path to remove.
        :return:
        """
        folder = self.timestamp_path.joinpath(folder_name)
        rank_zero_rm(folder)

    def state_dict(self):
        # Only the timestamped path needs to survive a resume.
        states = {
            'timestamp_path': str(self.timestamp_path),
        }
        return states

    def load_state_dict(self, states):
        timestamp_path = states['timestamp_path']
        if not os.path.exists(timestamp_path):
            # fixed message grammar ("is not exists" -> "does not exist") and double space
            logger.info(f"The resuming checkpoint folder {timestamp_path} does not exist, "
                        f"checkpoint will save to {self.timestamp_path.absolute()}.")
        else:
            logger.info(f"Resume to save checkpoint in path: {timestamp_path}.")
            self.timestamp_path = Path(timestamp_path)

    def __str__(self):
        return f'saver:{self.save_object}'
+
+
class TopkQueue:
    """
    Maintains the key/value pairs whose values currently rank within top-k.

    :param int topk: -1 keeps every entry; 0 keeps none; a positive number keeps that many.
    """
    def __init__(self, topk):
        assert isinstance(topk, int)
        self.topk = topk
        self.topk_dict = {}  # maps a saved key -> its monitored value

    def push(self, key, value) -> Optional[Tuple[Union[str, None], Union[float, None]]]:
        """
        Push ``key``/``value`` into the topk queue, ranking by ``value``. When the pair
        qualifies for topk it is kept; when it also evicts a previous entry, the evicted
        ``(key, value)`` is returned. ``(None, None)`` means it qualified without evicting
        anything. When it does not qualify, the pushed ``(key, value)`` itself is returned.
        Ranking only considers whether ``value`` is larger, so negate values that are
        better when smaller before pushing.

        NOTE(review): this body was reconstructed — the middle of the original method was
        corrupted in the patch (text between '<' and '>' was swallowed); verify upstream.

        :param str key:
        :param float value: when ``None``, nothing is done.
        :return: (1) the input ``(key, value)``: it did not qualify for topk; (2)
            ``(None, None)``: it qualified and nothing was evicted; (3) a different
            ``(key, value)``: it qualified and evicted that previous entry.
        """
        if value is None:  # nothing to rank on
            return key, value
        if self.topk < 0:  # unlimited: everything qualifies, nothing is tracked
            return None, None
        if self.topk == 0:  # disabled: nothing ever qualifies
            return key, value
        if len(self.topk_dict) < self.topk:  # still room: keep without evicting
            self.topk_dict[key] = value
            return None, None
        min_key = min(self.topk_dict, key=lambda x: self.topk_dict[x])
        if self.topk_dict[min_key] > value:  # new entry does not make the cut
            return key, value
        else:  # evict the current minimum in favour of the new entry
            min_value = self.topk_dict.pop(min_key)
            self.topk_dict[key] = value
            return min_key, min_value

    def state_dict(self):
        # Deep copy so later pushes do not mutate a saved checkpoint state.
        return deepcopy(self.topk_dict)

    def load_state_dict(self, states):
        self.topk_dict.update(states)

    def __str__(self):
        return f'topk-{self.topk}'

    def __bool__(self):
        # A topk of 0 means this queue is meaningless.
        return self.topk != 0
+
+
class TopkSaver(ResultsMonitor, Saver):
    """
    Identifies topk models and saves them; can also serve as a plain Saver. Layout::

        - folder/
            - YYYY-mm-dd-HH_MM_SS_fffff/  # derived from the launch time of the current script
                - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/  # written when topk is met

    :param topk: save the best ``topk`` models; -1 saves every model, 0 saves none,
        a positive number saves that many.
    :param monitor: the metric to watch:

        * ``None`` —
          fall back to the ``monitor`` configured on the
          :class:`~fastNLP.core.controllers.Trainer`, if any.
        * ``str`` —
          looked up directly in the ``evaluation`` results; when no exact match exists,
          the longest-common-substring match among the result keys is used instead.
        * :class:`Callable` —
          receives the ``evaluation`` result dict and returns a ``float``; return
          ``None`` when the results do not contain the monitored value.
    :param larger_better: whether larger monitor values are better.
    :param folder: directory to save into; defaults to the current working directory.
    :param save_object: either ``'trainer'`` or ``'model'``. With ``'trainer'`` the state of
        :class:`~fastNLP.core.controllers.Trainer` is saved as well and training can be
        resumed via :meth:`Trainer.load_checkpoint`; with ``'model'`` only the model is saved
        and can be loaded via :meth:`Trainer.load_model`.
    :param only_state_dict: whether to save only the weights; meaningless when
        ``model_save_fn`` is not ``None``.
    :param model_save_fn: custom save callable invoked with the target folder when a save is
        triggered; it should return nothing. When given, fastNLP performs no model saving of
        its own. In multi-card settings it only runs on rank 0.
    :param save_evaluate_results: whether to additionally write a
        ``fastnlp_evaluate_results.json`` with the current metric results into each topk
        folder. Only meaningful when ``topk`` is set; defaults to True.
    :param kwargs: extra arguments forwarded to :meth:`Trainer.save_checkpoint` or
        :meth:`Trainer.save_model`.
    """
    def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model',
                 only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True,
                 **kwargs):
        if topk is None:
            topk = 0
        ResultsMonitor.__init__(self, monitor, larger_better)
        Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs)

        # monitor and topk only make sense together.
        if monitor is not None and topk == 0:
            raise RuntimeError("`monitor` is set, but `topk` is 0.")
        if topk != 0 and monitor is None:
            raise RuntimeError("`topk` is set, but `monitor` is None.")

        self.topk_queue = TopkQueue(topk)
        self.save_evaluate_results = save_evaluate_results

    # NOTE: deliberately not decorated with ``@rank_zero_call`` so torch_fsdp keeps working.
    def save_topk(self, trainer, results: Dict) -> Optional[str]:
        """
        Decide from ``results`` whether the topk criterion is met; save and return the
        folder when it is. ``None`` means the criterion was not met and nothing was saved.

        :param trainer:
        :param results: evaluation results.
        :return:
        """
        if self.monitor is not None and self.topk_queue:
            monitor_value = self.get_monitor_value(results)
            if monitor_value is None:
                return
            key = f"{self.save_object}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
                  f"-{self.monitor_name}_{monitor_value}"
            # negate when smaller is better so TopkQueue can always rank by "larger wins"
            pop_key, pop_value = self.topk_queue.push(key, monitor_value if self.larger_better else -monitor_value)
            if pop_key == key:  # pushed straight back: not good enough for topk
                return None
            folder = self.save(trainer, key)
            if self.save_evaluate_results and folder:
                try:
                    self.save_json(self.itemize_results(results),
                                   os.path.join(folder, FASTNLP_EVALUATE_RESULT_FILENAME))
                except Exception:  # best effort only; was a bare except
                    logger.exception(f"Fail to save evaluate results to {folder}")

            if pop_key and pop_key != key:  # a previous entry fell out of topk
                self.rm(pop_key)
            return folder

    def state_dict(self):
        states = {
            'topk_queue': self.topk_queue.state_dict(),
            'timestamp_path': str(self.timestamp_path),
        }
        # _real_monitor is only meaningful once the fuzzy match has resolved to a str.
        if isinstance(self._real_monitor, str):
            states['_real_monitor'] = self._real_monitor

        return states

    def load_state_dict(self, states):
        topk_queue_states = states['topk_queue']
        self.topk_queue.load_state_dict(topk_queue_states)

        timestamp_path = states['timestamp_path']
        if not os.path.exists(timestamp_path):
            # fixed message grammar ("is not exists" -> "does not exist") and double space
            logger.info(f"The resuming checkpoint folder {timestamp_path} does not exist, "
                        f"checkpoint will save to {self.timestamp_path.absolute()}.")
        else:
            logger.info(f"Resume to save checkpoint in path: {timestamp_path}.")
            self.timestamp_path = Path(timestamp_path)

        if '_real_monitor' in states:
            self._real_monitor = states["_real_monitor"]

    def __str__(self):
        return f'topk-{self.topk_queue}#save_object-{self.save_object}'
diff --git a/fastNLP/core/callbacks/torch_callbacks/__init__.py b/fastNLP/core/callbacks/torch_callbacks/__init__.py
new file mode 100644
index 00000000..1cadd7f6
--- /dev/null
+++ b/fastNLP/core/callbacks/torch_callbacks/__init__.py
@@ -0,0 +1,8 @@
+__all__ = [
+ 'TorchWarmupCallback',
+ 'TorchGradClipCallback'
+]
+
+
+from .torch_lr_sched_callback import TorchWarmupCallback
+from .torch_grad_clip_callback import TorchGradClipCallback
\ No newline at end of file
diff --git a/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py b/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py
new file mode 100644
index 00000000..10ef7894
--- /dev/null
+++ b/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py
@@ -0,0 +1,61 @@
+__all__ = [
+ 'TorchGradClipCallback'
+]
+from typing import Union, List
+from ..callback import Callback
+from ...drivers.torch_driver.fairscale import FairScaleDriver
+from ...drivers.torch_driver import TorchDriver
+from fastNLP.envs.imports import _NEED_IMPORT_FAIRSCALE
+if _NEED_IMPORT_FAIRSCALE:
+ from fairscale.nn import FullyShardedDataParallel
+
class TorchGradClipCallback(Callback):
    r"""
    Clips parameter gradients right before every :func:`optimizer.step` call.

    :param clip_value: gradients are constrained to [-clip_value, clip_value];
        ``clip_value`` should be positive.
    :param clip_type: one of ``'norm'``, ``'value'``:

        1. ``'norm'``: the gradient *norm* is clipped to [-clip_value, clip_value];
        2. ``'value'``: each gradient entry is clamped into [-clip_value, clip_value];
           entries below ``-clip_value`` become ``-clip_value`` and entries above
           ``clip_value`` become ``clip_value``.

    :param parameters: parameters to clip, usually obtained via :func:`model.parameters`.
        When ``None``, all parameters found in the Trainer's optimizers are clipped.
    """
    def __init__(self, clip_value:float=1, clip_type:str='norm',
                 parameters:Union["torch.Tensor", List["torch.Tensor"]]=None):
        super().__init__()

        from torch import nn
        if clip_type == 'norm':
            self.clip_fun = nn.utils.clip_grad_norm_
        elif clip_type == 'value':
            self.clip_fun = nn.utils.clip_grad_value_
        else:
            raise ValueError("Only supports `norm` or `value` right now.")
        if parameters is not None:
            self.parameters = list(parameters)
        else:
            self.parameters = None
        self.clip_value = clip_value
        self.clip_type = clip_type

    def on_after_trainer_initialized(self, trainer, driver):
        assert isinstance(driver, TorchDriver), f"Callback:{self.__class__.__name__} only supports torch " \
                                                f"related drivers for now."
        # Only fall back to the optimizers' parameters when the user did not supply an
        # explicit set — the original unconditionally overwrote user-passed `parameters`,
        # contradicting the documented behaviour.
        if self.parameters is None:
            parameters = []
            for optimizer in trainer.driver.optimizers:
                for param_group in optimizer.param_groups:
                    parameters.extend(param_group['params'])
            self.parameters = parameters
        if isinstance(trainer.driver, FairScaleDriver):
            # FSDP shards parameters, so norm clipping must go through the model itself.
            if isinstance(trainer.driver.model, FullyShardedDataParallel) and self.clip_type == 'norm':
                self.clip_fun = trainer.driver.model.clip_grad_norm_

        assert len(self.parameters), "There is no parameters need to be clipped."

    def on_before_optimizers_step(self, trainer, optimizers):
        # Unscale AMP gradients first so clipping sees true gradient magnitudes.
        for optimizer in trainer.driver.optimizers:
            trainer.driver.grad_scaler.unscale_(optimizer)
        self.clip_fun(self.parameters, self.clip_value)
diff --git a/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py b/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py
new file mode 100644
index 00000000..97e1c544
--- /dev/null
+++ b/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py
@@ -0,0 +1,59 @@
+__all__ = [
+ 'TorchWarmupCallback'
+]
+import math
+from typing import Union
+
+from ..callback import Callback
+
+
class TorchWarmupCallback(Callback):
    r"""
    **Callback** that warms up and schedules the learning rate.

    :param warmup: an integer means the learning rate follows ``schedule`` for that many
        steps; a float such as 0.1 means the first 10% of all steps follow ``schedule``.
    :param schedule: learning-rate strategy:

        1. *linear* -- rise to the optimizer's configured learning rate during the first
           ``warmup`` steps, then decay linearly to 0 over the remaining steps;
        2. *constant* -- rise during the first ``warmup`` steps, then stay constant.
    """
    def __init__(self, warmup:Union[int, float]=0.1, schedule:str='linear'):
        super().__init__()
        self.warmup = max(warmup, 0.)

        self.initial_lrs = []  # initial learning rate of each param_group
        if schedule == 'constant':
            self.get_lr = self._get_constant_lr
        elif schedule == 'linear':
            self.get_lr = self._get_linear_lr
        else:
            raise RuntimeError("Only support 'linear', 'constant'.")

    # NOTE(review): the bodies of _get_constant_lr/_get_linear_lr and the head of
    # on_train_begin were corrupted in the patch (text between '<' and '>' was
    # swallowed); they are reconstructed below — verify against upstream fastNLP.
    def _get_constant_lr(self, progress):
        # Linear ramp during warmup, then hold at the configured lr.
        if progress < self.warmup:
            return progress / self.warmup
        return 1

    def _get_linear_lr(self, progress):
        # Linear ramp during warmup, then linear decay to 0 at progress == 1.
        if progress < self.warmup:
            return progress / self.warmup
        return max((progress - 1.) / (self.warmup - 1.), 0.)

    def on_train_begin(self, trainer):
        # An integer warmup is a step count: normalize it to a fraction of all batches.
        if self.warmup > 1:
            self.warmup = self.warmup / trainer.n_batches
        self.t_steps = max(2, trainer.n_batches)  # must not be smaller than 2
        # make t_steps divisible by accumulation_steps
        self.t_steps = math.ceil(self.t_steps/trainer.accumulation_steps) * trainer.accumulation_steps
        # record the initial learning rate of every param_group
        for optimizer in trainer.driver.optimizers:
            for group in optimizer.param_groups:
                self.initial_lrs.append(group['lr'])

    def on_before_optimizers_step(self, trainer, optimizers):
        # accumulation_steps is added so the lr does not start from 0
        progress = (trainer.global_forward_batches + trainer.accumulation_steps) / self.t_steps
        for optimizer in trainer.driver.optimizers:
            for lr, group in zip(self.initial_lrs, optimizer.param_groups):
                group['lr'] = lr * self.get_lr(progress)
diff --git a/fastNLP/core/callbacks/utils.py b/fastNLP/core/callbacks/utils.py
new file mode 100644
index 00000000..436560d9
--- /dev/null
+++ b/fastNLP/core/callbacks/utils.py
@@ -0,0 +1,60 @@
+from typing import Optional, Union, Tuple
+import os
+
+from fastNLP.core.log.logger import logger
+from difflib import SequenceMatcher
+from fastNLP.core.utils.utils import _get_fun_msg
+
+
+def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->Tuple[str, float]:
+ """
+ 从 ``res`` 中寻找 ``monitor`` 并返回。如果 ``monitor`` 没找到则尝试用 ``_real_monitor`` ,若 ``_real_monitor`` 为 ``None``
+ 则尝试使用 ``monitor`` 的值进行匹配。
+
+ :param monitor:
+ :param real_monitor:
+ :param res:
+ :return: 两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value。如果value为None说明当前results中没有
+ 找到对应的 monitor
+ """
+ if len(res) == 0 or monitor is None:
+ return monitor, None
+
+ if callable(monitor):
+ try:
+ monitor_value = monitor(res)
+ except BaseException as e:
+ logger.error(f"Exception happens when calling customized monitor function:{_get_fun_msg(monitor)}.")
+ raise e
+ return monitor, monitor_value
+
+ if monitor in res:
+ return monitor, res[monitor]
+
+ if real_monitor in res:
+ return real_monitor, res[real_monitor]
+
+ pairs = []
+ for idx, (key, value) in enumerate(res.items()):
+ match_size = _match_length(monitor, key)
+ pairs.append((key, value, match_size, idx))
+
+ pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True)
+ key, value, match_size = pairs[0][:3]
+
+ return key, value
+
+
+def _match_length(a:str, b:str)->int:
+ """
+ 需要把长度短的放在前面
+
+ :param a:
+ :param b:
+ :return:
+ """
+ short = a if len(a) < len(b) else b
+ long = a if len(a)>=len(b) else b
+ match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long))
+ return match.size
+
diff --git a/fastNLP/core/collate_fn.py b/fastNLP/core/collate_fn.py
deleted file mode 100644
index 403af270..00000000
--- a/fastNLP/core/collate_fn.py
+++ /dev/null
@@ -1,147 +0,0 @@
-r"""undocumented"""
-from builtins import sorted
-
-import torch
-import numpy as np
-from .field import _get_ele_type_and_dim
-from .utils import logger
-from copy import deepcopy
-
-__all__ = ['ConcatCollateFn']
-
-
-def _check_type(batch_dict, fields):
- if len(fields) == 0:
- raise RuntimeError
- types = []
- dims = []
- for f in fields:
- t, d = _get_ele_type_and_dim(batch_dict[f])
- types.append(t)
- dims.append(d)
- diff_types = set(types)
- diff_dims = set(dims)
- if len(diff_types) > 1 or len(diff_dims) > 1:
- raise ValueError
- return types[0]
-
-
-def batching(samples, max_len=0, padding_val=0):
- if len(samples) == 0:
- return samples
- if max_len <= 0:
- max_len = max(s.shape[0] for s in samples)
- batch = np.full((len(samples), max_len), fill_value=padding_val)
- for i, s in enumerate(samples):
- slen = min(s.shape[0], max_len)
- batch[i][:slen] = s[:slen]
- return batch
-
-
-class Collater:
- r"""
- 辅助DataSet管理collate_fn的类
-
- """
- def __init__(self):
- self.collate_fns = {}
-
- def add_fn(self, fn, name=None):
- r"""
- 向collater新增一个collate_fn函数
-
- :param callable fn:
- :param str,int name:
- :return:
- """
- if name in self.collate_fns:
- logger.warn(f"collate_fn:{name} will be overwritten.")
- if name is None:
- name = len(self.collate_fns)
- self.collate_fns[name] = fn
-
- def is_empty(self):
- r"""
- 返回是否包含collate_fn
-
- :return:
- """
- return len(self.collate_fns) == 0
-
- def delete_fn(self, name=None):
- r"""
- 删除collate_fn
-
- :param str,int name: 如果为None就删除最近加入的collate_fn
- :return:
- """
- if not self.is_empty():
- if name in self.collate_fns:
- self.collate_fns.pop(name)
- elif name is None:
- last_key = list(self.collate_fns.keys())[0]
- self.collate_fns.pop(last_key)
-
- def collate_batch(self, ins_list):
- bx, by = {}, {}
- for name, fn in self.collate_fns.items():
- try:
- batch_x, batch_y = fn(ins_list)
- except BaseException as e:
- logger.error(f"Exception:`{e}` happens when call collate_fn:`{name}`.")
- raise e
- bx.update(batch_x)
- by.update(batch_y)
- return bx, by
-
- def copy_from(self, col):
- assert isinstance(col, Collater)
- new_col = Collater()
- new_col.collate_fns = deepcopy(col.collate_fns)
- return new_col
-
-
-class ConcatCollateFn:
- r"""
- field拼接collate_fn,将不同field按序拼接后,padding产生数据。
-
- :param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field
- :param str output: 拼接后的field名称
- :param pad_val: padding的数值
- :param max_len: 拼接后最大长度
- :param is_input: 是否将生成的output设置为input
- :param is_target: 是否将生成的output设置为target
- """
-
- def __init__(self, inputs, output, pad_val=0, max_len=0, is_input=True, is_target=False):
- super().__init__()
- assert isinstance(inputs, list)
- self.inputs = inputs
- self.output = output
- self.pad_val = pad_val
- self.max_len = max_len
- self.is_input = is_input
- self.is_target = is_target
-
- @staticmethod
- def _to_numpy(seq):
- if torch.is_tensor(seq):
- return seq.numpy()
- else:
- return np.array(seq)
-
- def __call__(self, ins_list):
- samples = []
- for i, ins in ins_list:
- sample = []
- for input_name in self.inputs:
- sample.append(self._to_numpy(ins[input_name]))
- samples.append(np.concatenate(sample, axis=0))
- batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
- b_x, b_y = {}, {}
- if self.is_input:
- b_x[self.output] = batch
- if self.is_target:
- b_y[self.output] = batch
-
- return b_x, b_y
diff --git a/fastNLP/core/collators/__init__.py b/fastNLP/core/collators/__init__.py
new file mode 100644
index 00000000..3033c37e
--- /dev/null
+++ b/fastNLP/core/collators/__init__.py
@@ -0,0 +1,21 @@
+__all__ = [
+ 'Collator',
+
+ 'NumpyNumberPadder',
+ 'NumpySequencePadder',
+ "NumpyTensorPadder",
+ "Padder",
+ "NullPadder",
+ "RawNumberPadder",
+ "RawSequencePadder",
+ "RawTensorPadder",
+ 'TorchNumberPadder',
+ 'TorchSequencePadder',
+ 'TorchTensorPadder',
+ "PaddleNumberPadder",
+ "PaddleTensorPadder",
+ "PaddleSequencePadder",
+ "get_padded_numpy_array",
+]
+from .collator import Collator
+from .padders import *
diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py
new file mode 100644
index 00000000..5a3b1967
--- /dev/null
+++ b/fastNLP/core/collators/collator.py
@@ -0,0 +1,355 @@
+__all__ = [
+ 'Collator'
+]
+
+from typing import List, Union, Dict, Callable, Sequence, Mapping
+import os
+import sys
+import inspect
+import re
+
+from fastNLP.core.log import logger
+from .padders.get_padder import get_padder
+from ...envs import SUPPORT_BACKENDS
+from .padders import Padder
+
+
+from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, MappingPackerUnpacker, \
+ NestedMappingPackerUnpacker
+
# Matches list-style field names such as _0, _1.
sequence_idx_str = re.compile(r'^_\d+$')
# Backends a Collator may pad to; None disables padding, 'auto' defers detection to call time.
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'oneflow', 'numpy', 'raw', 'auto', None]
# jittor's DataLoader converts to jittor automatically, so collating to numpy is enough.
AUTO_BACKEND_MAPPING = {'jittor': 'numpy'}
+
def _get_backend() -> str:
    """
    Resolve the backend automatically when the Collator's backend is None. Two strategies:

    (1) walk up the call stack and look for a caller whose module file lives under
        '/site-packages/{backend}', i.e. some backend's dataloader;
    (2) if that finds nothing, scan sys.modules for a loaded backend.

    Falls back to 'numpy' when neither strategy succeeds.
    :return:
    """
    def _check_module(module):
        """
        Check whether ``module`` carries the signature of one of the supported backends.

        :param module: module object
        :return: ``[backend, file]`` on a hit, otherwise ``[]``.
        """
        catch_backend = []
        try:
            file = module.__file__
            for backend in SUPPORT_BACKENDS:
                if f'{os.sep}site-packages{os.sep}{backend}' in file:
                    catch_backend = [backend, file]
        except Exception:  # was a bare except; __file__ may be absent or None
            pass
        return catch_backend

    currentframe = inspect.currentframe()
    # strategy (1): inspect the call stack
    catch_backend = []
    for i in range(100):
        currentframe = currentframe.f_back
        if currentframe is not None:
            module = inspect.getmodule(currentframe)
            if module is not None:
                catch_backend = _check_module(module)
                if catch_backend:  # one hit is enough
                    break
        else:
            break
    if catch_backend:
        logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.")
        return AUTO_BACKEND_MAPPING.get(catch_backend[0], catch_backend[0])

    # strategy (2): scan the loaded modules
    for backend in SUPPORT_BACKENDS:
        if backend in sys.modules:
            logger.debug(f"sys.modules contains backend:{backend}.")
            return backend
    # iterate a snapshot: imports elsewhere may mutate sys.modules during the walk
    for module in list(sys.modules.values()):
        catch_backend = _check_module(module)
        if catch_backend:
            break
    if catch_backend:
        logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
        return catch_backend[0]

    return 'numpy'
+
+
class Collator:
    """
    Object used to pad batch data. Every field that fastNLP judges paddable is padded
    automatically, with a default pad value of 0. A field is judged paddable when:

    1. every sample has the same data type for the field (mixing float and int makes it
       un-paddable);
    2. every sample has the same nesting depth for the field (e.g. a batch of
       ``[1, [1, 2]]`` is un-paddable because the two samples differ in depth);
    3. the element type itself supports padding (``str`` does not). Set
       ``logger.setLevel('debug')`` to see why a field was judged un-paddable.

    .. note::

        ``Collator`` infers each field's ``Padder`` from the first ``batch``; when that
        batch happens to be unrepresentative, later pads may fail — configure the field
        explicitly through :meth:`set_pad` in that case.

    .. todo::

        Add a code example.

    Use :meth:`~fastNLP.Collator.set_pad` with ``pad_val=None`` to disable padding of an
    otherwise paddable field, and :meth:`~fastNLP.Collator.set_ignore` to exclude fields
    from the padded output entirely.

    On the first pad the Collator picks one padder per field from the configuration and
    the data, and reuses it for every later call. A Collator only pads within a single
    field; do not use it when padding has to look at several fields at once.

    :param backend: tensor type used for paddable fields, one of
        ``['torch','jittor','paddle','oneflow','numpy','raw', 'auto', None]``. With
        ``'auto'`` the backend is resolved from the calling environment at pad time.
        Un-paddable data is unaffected and is always returned as a :class:`list`.
    """
    def __init__(self, backend='auto'):
        # kept for backward compatibility; superseded by self.packer_unpacker
        self.unpack_batch_func = None
        self.pack_batch_func = None
        self.ignore_fields = set()
        self.padders = {}
        self.input_fields = {}
        self.batch_data_type = None  # 'd', 's' or 'l': each sample is a dict, a single object, or a list
        self.set_backend(backend)

    def __call__(self, batch)->Union[List, Dict]:
        """
        ``batch`` may be List[Dict], List[List] or List[Sample].

        Step 1: group each field's content across samples into one list (unpack).
        Step 2: pad every field with its own padder.
        Step 3: repack so the return type mirrors the input's sample type.

        The first call decides the packer/unpacker pair from the shape of the batch and
        instantiates one padder per field; later calls reuse both.
        """
        # Original guarded on self.unpack_batch_func, which is never assigned, so the
        # packer was re-derived on every call; guard on the attribute actually used.
        if self.packer_unpacker is None:
            # decide which unpacker to use; all of them normalize to a dict
            if self.batch_data_type is None:
                if isinstance(batch[0], Mapping):
                    self.batch_data_type = 'd'
                elif isinstance(batch[0], Sequence):  # this may misjudge, e.g. str samples
                    self.batch_data_type = 'l'
                else:
                    self.batch_data_type = 's'
                logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type "
                             f"is `{self.batch_data_type}`.")
            if self.batch_data_type == 's':
                self.packer_unpacker = SinglePackerUnpacker()  # no adjustment needed
            elif self.batch_data_type == 'l':
                self.packer_unpacker = SequencePackerUnpacker()
            elif self.batch_data_type == 'd':
                if any(isinstance(v, Mapping) for v in batch[0].values()):
                    # nested dicts possible: {'a': {'b': x}} -> {('a', 'b'): value}
                    self.packer_unpacker = NestedMappingPackerUnpacker()
                else:
                    self.packer_unpacker = MappingPackerUnpacker()

        # group each field into its own batch, dropping ignored fields
        unpack_batch = self.packer_unpacker.unpack_batch(batch, self.ignore_fields, self.input_fields)

        pad_batch = {}
        if not self.padders:  # first run: build the padders
            if self.backend == 'auto':  # resolve backend from the call stack etc.
                self.backend = _get_backend()

            for field_name, batch_field in unpack_batch.items():
                setting = self.input_fields.get(field_name, {'backend': self.backend, 'pad_val': 0,
                                                             'dtype': None, 'pad_fn': None})
                pad_fn = setting['pad_fn']
                if callable(pad_fn):
                    padder = pad_fn
                else:
                    backend = self.backend if setting['backend'] == 'auto' else setting['backend']
                    padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'],
                                        dtype=setting['dtype'], backend=backend,
                                        field_name=field_name)
                self.padders[field_name] = padder

            if self.batch_data_type == 'l':
                # sort so _0, _1, ... keep their positional order
                self.padders = dict(sorted(self.padders.items(), key=lambda x: int(x[0][1:])))
        try:
            for key, padder in self.padders.items():
                batch = unpack_batch.get(key)
                pad_batch[key] = padder(batch)
        except BaseException as e:
            try:
                logger.error(f"The following exception happens when try to pad the `{key}` field with padder:{padder}:")
            except Exception:  # logging is best effort here; was a bare except
                pass
            raise e

        return self.packer_unpacker.pack_batch(pad_batch)  # restore the input's sample type

    def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend='auto',
                pad_fn:Callable=None) -> "Collator":
        """
        Configure special padding for one field.

        :param field_name: field to adjust. When :meth:`Dataset.__getitem__` returns a
            dict, use the key directly; for nested dicts use a tuple of keys, e.g.
            ``('a', 'b')`` for ``{'a': {'b': 1}}``. When it returns a Sequence, use
            ``'_0'``, ``'_1'`` for the 0th or 1st element. An unknown field raises; when
            the returned item is the whole content, use ``'_single'``.
        :param pad_val: default pad value for this field. ``None`` disables padding;
            fastNLP only pads paddable fields anyway, so un-paddable fields need no
            explicit ``None``. Meaningless when ``backend`` is ``None``.
        :param dtype: dtype of the padded data for this field.
        :param backend: one of ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']``,
            producing :class:`list`, :class:`numpy.ndarray`, :class:`torch.Tensor`,
            :class:`paddle.Tensor`, :class:`jittor.Var` or :class:`oneflow.Tensor`
            output respectively. Meaningless when ``pad_val`` is ``None``.
        :param pad_fn: custom pad function for this field; when given, ``pad_val``,
            ``dtype`` and ``backend`` are ignored. It receives the field already grouped
            into batch form — the Collator unbatches the samples automatically.
        :return: the Collator itself.
        """
        self._renew()

        if self.batch_data_type == 's':
            logger.debug("Set as single field mode.")
            self.input_fields.clear()
        elif self.batch_data_type == 'd':
            if isinstance(field_name, str):
                assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \
                                                                   f"index, but other field is set as dict mode."
        elif self.batch_data_type == 'l':
            if isinstance(field_name, str):
                assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \
                                                                       f"field name is {field_name}."

        if field_name == '_single':
            self.batch_data_type = 's'
        elif isinstance(field_name, str) and sequence_idx_str.match(field_name):
            self.batch_data_type = 'l'
        else:
            self.batch_data_type = 'd'

        # check against earlier settings, including parent/child (tuple-prefix) relations
        ignore_fields = [(field, field) if isinstance(field, tuple) else ((field,), field)
                         for field in self.ignore_fields]
        input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field)
                             for field in self.input_fields.keys()]
        _field_name = field_name if isinstance(field_name, tuple) else (field_name,)
        for field, o_field in ignore_fields:
            d = _compare_tuple(field, _field_name)
            if d is None:
                continue
            if d == 0:
                logger.rank_zero_warning(f"Field:`{field_name}` has been set as ignored before. It will not be "
                                         f"ignored afterwards.")
                self.ignore_fields.remove(o_field)
            if d > 0:
                raise KeyError(f"Cannot set `{field_name}` as input, since its children `{o_field}` has been set "
                               f"as ignore field.")
            if d < 0:
                raise KeyError(f"Cannot set `{field_name}` as input, since its parent `{o_field}` has been set "
                               f"as ignore field.")
        for field, o_field in input_field_names:
            d = _compare_tuple(field, _field_name)
            if d is None:
                continue
            if d > 0:
                raise KeyError(f"Cannot set `{field_name}` as input, since its children `{o_field}` has been set "
                               f"pad.")
            if d < 0:
                raise KeyError(f"Cannot set `{field_name}` as input, since its parent `{o_field}` has been set "
                               f"pad.")

        if backend is None:
            backend = self.backend
        else:
            assert backend in SUPPORTED_BACKENDS

        self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn}

        return self

    def set_backend(self, backend:str):
        """
        Set the default tensor type that paddable fields are padded to.

        :param backend: one of ``['torch','jittor','paddle','oneflow','numpy','raw', 'auto', None]``;
            with ``'auto'`` the backend is resolved automatically at pad time.
        :return:
        """
        assert backend in SUPPORTED_BACKENDS
        self._renew()
        self.backend = backend

    def set_ignore(self, *field_names) -> "Collator":
        """
        Exclude fields from the padded batch output::

            >>> collator = Collator().set_ignore('field1', 'field2')

        :param field_names: fields to ignore. When :meth:`Dataset.__getitem__` returns a
            dict, use the key directly (tuples for nested dicts, e.g. ``('a', 'b')`` for
            ``{'a': {'b': 1}}``); when it returns a Sequence, use ``'_0'``, ``'_1'`` for
            the 0th or 1st element.
        :return: the Collator itself.
        """
        self._renew()
        input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field)
                             for field in self.input_fields.keys()]

        # handle parent/child relations with fields already configured for padding
        for field in field_names:
            _field = field if isinstance(field, tuple) else (field,)
            for _field_name, o_field_name in input_field_names:
                d = _compare_tuple(_field, _field_name)
                if d is None:
                    continue
                if d <= 0:  # identical to or a parent of a configured field: un-configure it
                    self.input_fields.pop(o_field_name)
                    logger.rank_zero_warning(f"Field:{o_field_name} has been set as pad before. It will be ignored afterwards.")
                if d > 0:
                    raise KeyError(f"Cannot ignore {field} since its parent key {o_field_name} has been set as pad.")
            self.ignore_fields.add(field)

        return self

    def _renew(self):
        # Any configuration change invalidates the inferred packer and the padders.
        self.packer_unpacker = None
        self.padders.clear()
+
+
+def _compare_tuple(t1, t2):
+ """
+ 检测 t1 和 t2 的关系。
+ 例如 (1, ) 和 (1, ) 关系为 0,表示两者完全没有差异
+ 例如 (1, ) 和 (2, ) 关系为 None,表示完全不同
+ 例如 (1, 2, 3) 和 (1, ) 关系为 2,表示前者比后者长 2 位
+ 但 例如 (1, 2, 3) 和 (2, ) 关系为 None,因为它们从前往后的key 不一样
+ 例如 (1, 2, 3) 和 (1, 3) 关系为 None,因为它们从前往后的key 不一样
+
+ 例如 (1, ) 和 (1, 2, 3) 关系为 -2,表示后者比前者长 2 位
+ 但 例如 (2, ) 和 (1, 2, 3) 关系为 None,因为它们从前往后的key 不一样
+ 例如 (1, 3) 和 (1, 2, 3) 关系为 None,因为它们从前往后的key 不一样
+ :param t1:
+ :param t2:
+ :return: None 没有关系; 0 两者完全一样; >0 t1比t2长,<0 t2比t1长
+ """
+ if t1 == t2:
+ return 0
+ for _t1, _t2 in zip(t1, t2): # 会按照最短的计算
+ if _t1 != _t2:
+ return None
+ return len(t1) - len(t2)
diff --git a/fastNLP/core/collators/packer_unpacker.py b/fastNLP/core/collators/packer_unpacker.py
new file mode 100644
index 00000000..7d9c23cd
--- /dev/null
+++ b/fastNLP/core/collators/packer_unpacker.py
@@ -0,0 +1,149 @@
+from collections import defaultdict
+from functools import reduce
+from typing import Sequence, Mapping, Dict
+
+__all__ = []
+
+class MappingPackerUnpacker:
+ @staticmethod
+ def unpack_batch(batch:Sequence[Mapping], ignore_fields:set, input_fields:Dict)->Dict:
+ """
+ 将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]}
+
+ :param batch: 需要 unpack 的 batch 数据。
+ :param ignore_fields: 需要忽略的 field 。
+ :param input_fields: 需要设置为 input 的 field 。
+ :return:
+ """
+ dict_batch = defaultdict(list)
+ for sample in batch:
+ for key, value in sample.items():
+ if key in ignore_fields:
+ continue
+ dict_batch[key].append(value)
+ return dict_batch
+
+ @staticmethod
+ def pack_batch(batch):
+ return batch
+
+
+class NestedMappingPackerUnpacker:
+ @staticmethod
+ def unpack_batch(batch:Sequence[Mapping], ignore_fields:set, input_fields:Dict)->Dict:
+ """
+ 将 nested 的 dict 中的内容展开到一个 flat dict 中
+
+ :param batch: 需要 unpack 的 batch 数据。
+ :param ignore_fields: 需要忽略的 field 。
+ :param input_fields: 需要设置为 input 的 field 。
+ :return:
+ """
+ dict_batch = defaultdict(list)
+ for sample in batch:
+ for key, value in sample.items():
+ if key in ignore_fields:
+ continue
+ if isinstance(value, Mapping) and key not in input_fields:
+ _dict_batch = _unpack_batch_nested_mapping(value, ignore_fields, input_fields, _parent=(key,))
+ for key, value in _dict_batch.items():
+ dict_batch[key].append(value)
+ else:
+ dict_batch[key].append(value)
+ return dict_batch
+
+ @staticmethod
+ def pack_batch(batch):
+ if len(batch) == 0:
+ return []
+ dicts = []
+ for key, value in batch.items():
+ if not isinstance(key, tuple):
+ key = [key]
+ d = {key[-1]: value}
+ for k in key[:-1:][::-1]:
+ d = {k: d}
+ dicts.append(d)
+ return reduce(_merge_dict, dicts)
+
+
+class SequencePackerUnpacker:
+ @staticmethod
+ def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields)->Dict:
+ """
+ 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [2, 2]}
+
+ :param batch: 需要 unpack 的 batch 数据。
+ :param ignore_fields: 需要忽略的 field 。
+ :param input_fields: 需要设置为 input 的 field 。
+ :return:
+ """
+ dict_batch = defaultdict(list)
+ for sample in batch:
+ for i, content in enumerate(sample):
+ field_name = f'_{i}'
+ if field_name in ignore_fields:
+ continue
+ dict_batch[field_name].append(content)
+ return dict_batch
+
+ @staticmethod
+ def pack_batch(batch):
+ return list(batch.values())
+
+
+class SinglePackerUnpacker:
+ @staticmethod
+ def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields):
+ return {'_single': batch}
+
+ @staticmethod
+ def pack_batch(batch):
+ return batch['_single']
+
+
+def _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent)->Dict:
+ _dict = {}
+ for k, v in value.items():
+ _k = _parent + (k,)
+ if _k in ignore_fields:
+ continue
+ if isinstance(v, Mapping) and _k not in stop_deep_fields:
+ __dict = _unpack_batch_nested_mapping(v, ignore_fields, stop_deep_fields, _parent=_k)
+ _dict.update(__dict)
+ else:
+ _dict[_k] = v
+ return _dict
+
+
+def pack_batch_nested_mapping(batch:Mapping) -> Dict:
+ """
+ 需要恢复出 nested 的 dict 原来的样式
+
+ :param batch:
+ :return:
+ """
+ dicts = []
+
+ for key, value in batch.items():
+ if not isinstance(key, tuple):
+ key = [key]
+ d = {key[-1]: value}
+ for k in key[:-1:][::-1]:
+ d = {k: d}
+ dicts.append(d)
+ return reduce(_merge_dict, dicts)
+
+
+def _merge_dict(a, b, path=None):
+ "merges b into a"
+ if path is None: path = []
+ for key in b:
+ if key in a:
+ if isinstance(a[key], dict) and isinstance(b[key], dict):
+ _merge_dict(a[key], b[key], path + [str(key)])
+ else:
+ raise Exception('Conflict at %s' % '.'.join(path + [str(key)]))
+ else:
+ a[key] = b[key]
+ return a
diff --git a/fastNLP/core/collators/padders/__init__.py b/fastNLP/core/collators/padders/__init__.py
new file mode 100644
index 00000000..11ffc07b
--- /dev/null
+++ b/fastNLP/core/collators/padders/__init__.py
@@ -0,0 +1,31 @@
+
+__all__ = [
+ 'NumpyNumberPadder',
+ 'NumpySequencePadder',
+ "NumpyTensorPadder",
+
+ "Padder",
+ "NullPadder",
+
+ "RawNumberPadder",
+ "RawSequencePadder",
+ "RawTensorPadder",
+
+ 'TorchNumberPadder',
+ 'TorchSequencePadder',
+ 'TorchTensorPadder',
+
+ "PaddleNumberPadder",
+ "PaddleTensorPadder",
+ "PaddleSequencePadder",
+
+ "get_padded_numpy_array",
+]
+
+
+from .numpy_padder import *
+from .padder import Padder, NullPadder
+from .raw_padder import *
+from .torch_padder import *
+from .paddle_padder import *
+from .utils import get_padded_numpy_array
\ No newline at end of file
diff --git a/fastNLP/core/collators/padders/exceptions.py b/fastNLP/core/collators/padders/exceptions.py
new file mode 100644
index 00000000..a2b97cbf
--- /dev/null
+++ b/fastNLP/core/collators/padders/exceptions.py
@@ -0,0 +1,51 @@
+__all__ = [
+ 'InconsistencyError',
+ 'EleDtypeUnsupportedError',
+ 'EleDtypeDtypeConversionError',
+ 'DtypeUnsupportedError',
+ "DtypeError",
+ "NoProperPadderError"
+]
+
+
+class InconsistencyError(BaseException):
+ """
+ 当一个 batch 中的数据存在 shape,dtype 之类的不一致时的报错。
+
+ """
+ def __init__(self, msg, *args):
+ super(InconsistencyError, self).__init__(msg, *args)
+
+
+class DtypeError(BaseException):
+ def __init__(self, msg, *args):
+ super(DtypeError, self).__init__(msg, *args)
+ self.msg = msg
+
+
+class NoProperPadderError(BaseException):
+ def __init__(self, msg, *args):
+ super(NoProperPadderError, self).__init__(msg, *args)
+ self.msg = msg
+
+
+class EleDtypeUnsupportedError(DtypeError):
+ """
+ 当 batch 中的 element 的类别本身无法 pad 的时候报错。
+ 例如要求 str 类型的数据进行 padding 。
+
+ """
+
+
+class EleDtypeDtypeConversionError(DtypeError):
+ """
+ 当 batch 中的 element 的类别无法转换为 dtype 类型时报错。
+
+ """
+
+
+class DtypeUnsupportedError(DtypeError):
+ """
+ 当当前 backend 不支持这种类型的 dtype 时报错。
+
+ """
\ No newline at end of file
diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py
new file mode 100644
index 00000000..41bcd8c0
--- /dev/null
+++ b/fastNLP/core/collators/padders/get_padder.py
@@ -0,0 +1,201 @@
+from typing import Sequence, Any, Union, Dict
+from abc import ABC
+
+from fastNLP.core.log import logger
+
+
+from .padder import Padder, NullPadder
+from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
+from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
+from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder
+from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
+from .jittor_padder import JittorTensorPadder, JittorSequencePadder, JittorNumberPadder
+from .oneflow_padder import OneflowTensorPadder, OneflowSequencePadder, OneflowNumberPadder
+from .exceptions import *
+
+
+def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder:
+ """
+ 根据 参数 与 ``batch_field`` ,返回适合于当前 ``batch_field`` 的 *padder* 。
+
+ :param batch_field: 将某 field 的内容组合成一个 batch 传入;
+ :param pad_val:
+ :param backend:
+ :param dtype:
+ :param field_name: field 名称,方便在报错时显示;
+ :return:
+ """
+ try:
+ assert len(batch_field)!=0, "Empty batch encountered."
+ logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field))
+ if pad_val is None:
+ logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
+ return NullPadder()
+ if backend is None:
+ logger.debug(f"The backend for field:{field_name} is None, not padding this field.")
+ return NullPadder()
+
+ # 首先判断当前 field 是否是必须要 pad ,根据用户设置的 pad_val、dtype 等判断。
+ must_pad = False
+ if pad_val != 0 or dtype is not None:
+ must_pad = True
+
+ catalog = _get_element_shape_dtype(batch_field) # 首先获取数据的基本信息。
+
+ # 根据 catalog 来判定当前是否可以进行 pad 。
+ # 首先检查是否所有的 key 是一样长的,表明深度是一致的
+ depths = set(map(len, catalog.keys()))
+ num_depth = len(depths)
+ if num_depth != 1:
+ msg = f'Field:`{field_name}` cannot pad, since it has various depths({depths}) of data. To view more ' \
+ f"information please set logger's level to DEBUG."
+ if must_pad:
+ raise InconsistencyError(msg)
+ raise NoProperPadderError(msg)
+
+ # 再检查所有的元素 shape 是否一致?
+ shape_lens = set([len(v[0]) for v in catalog.values()])
+ num_shape = len(shape_lens)
+ if num_shape != 1:
+ msg = f'Field:`{field_name}` cannot pad, since it has various shape length({shape_lens}) of data. To view more ' \
+ f"information please set logger's level to DEBUG."
+ if must_pad:
+ raise InconsistencyError(msg)
+ raise NoProperPadderError(msg)
+
+ # 再检查所有的元素 type 是否一致
+ try:
+ ele_dtypes = set([v[1] for v in catalog.values()])
+ except TypeError:
+ ele_dtypes = set([str(v[1]) for v in catalog.values()])
+ num_eletypes = len(ele_dtypes)
+ if num_eletypes != 1:
+ msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \
+ f"information please set logger's level to DEBUG."
+ if must_pad:
+ raise InconsistencyError(msg)
+ raise NoProperPadderError(msg)
+
+ depth = depths.pop()
+ shape_len = shape_lens.pop()
+ ele_dtype = list(catalog.values())[0][1] # 因为上面有except的情况,所以这样处理了
+
+ # 需要由 padder 自己决定是否能够 pad 。
+ if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True]
+ if backend == 'raw':
+ return RawNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'numpy':
+ return NumpyNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'torch':
+ return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'paddle':
+ return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'jittor':
+ return JittorNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'oneflow':
+ return OneflowNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ else:
+ raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).")
+
+ if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种
+ if backend == 'raw':
+ return RawSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'numpy':
+ return NumpySequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'torch':
+ return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'paddle':
+ return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'jittor':
+ return JittorSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'oneflow':
+ return OneflowSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ else:
+ raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")
+
+ # 如果有有 shape 的话,只有当该对象拥有 tolist() 方法才行
+ if depth == 1 and shape_len != 0 and callable(getattr(batch_field[0], 'tolist', None)):
+ if backend == 'raw':
+ return RawTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
+ elif backend == 'numpy':
+ return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
+ elif backend == 'torch':
+ # 这里 ele_dtype 传入为 None 的原因是防止出现 paddle tensor 转换为 torch tensor
+ return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
+ elif backend == 'paddle':
+ return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'jittor':
+ return JittorTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ elif backend == 'oneflow':
+ return OneflowTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+ else:
+ raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")
+
+ if shape_len != 0 and depth>1:
+ msg = "Does not support pad tensor under nested list. If you need this, please report."
+ if must_pad:
+ raise RuntimeError(msg)
+ raise NoProperPadderError(msg)
+
+ except DtypeError as e:
+ msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \
+ "information please set logger's level to DEBUG."
+ if must_pad:
+ logger.error(msg)
+ raise type(e)(msg=msg)
+
+ except NoProperPadderError as e:
+ logger.debug(f"{e.msg}")
+
+ except BaseException as e:
+ raise e
+
+ return NullPadder()
+
+
+class HasShapeDtype(ABC):
+ """
+ 检测拥有 shape 和 dtype 属性的对象。一般就是 np.ndarray 或者各类 tensor 。
+
+ """
+
+ @classmethod
+ def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
+ if cls is HasShapeDtype:
+ if hasattr(subclass, 'shape') and hasattr(subclass, 'dtype'):
+ return True
+ return False
+ return NotImplemented
+
+
+def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict:
+ """
+ 获取对象的中 element 的基本信息,用于判断是否可以 padding。
+
+ :param content:
+ :param tuple parent:
+ :param dict catalog: 记录元素信息的 dict。其中的 index 记录的是每一个元素的 拓扑 结构。
+ 例如: [1, 2, 3] -> {(0,): ((), ), (1,): ((), ), (2,): ((), )}
+ 例如: [1, [2, 3], 4] -> {(0,): ((), ), (1, 0): ((), ), (1, 1): ((), ), (2,): ((), )}
+ 例如: [[1, 2], [3], [4, 5]] -> {(0, 0): ((), ), (0, 1): ((), ), (1, 0): ((), ), (2, 0): ((), ), (2, 1): ((), )}
+ 例如: [torch.ones(3, 4), torch.ones(3, 4), torch.ones(3, 4)]
+ -> {(0,): (torch.Size([3, 4]), torch.float32), (1,): (torch.Size([3, 4]), torch.float32), (2,): (torch.Size([3, 4]), torch.float32)}
+
+ :return:
+ """
+ if catalog is None:
+ catalog = {}
+
+ if parent is None:
+ parent = ()
+
+ if isinstance(content, HasShapeDtype): # 各类 tensor 或者 np.ndarray
+ shape = content.shape
+ dtype = content.dtype
+ catalog[parent] = (shape, dtype)
+ elif isinstance(content, (tuple, list)):
+ for i, c in enumerate(content):
+ _get_element_shape_dtype(c, parent=parent + (i,), catalog=catalog)
+ else: # 包括 int/float/bool/dict 以及 其它无法pad 的等
+ catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别
+ return catalog
diff --git a/fastNLP/core/collators/padders/jittor_padder.py b/fastNLP/core/collators/padders/jittor_padder.py
new file mode 100644
index 00000000..6b37d61c
--- /dev/null
+++ b/fastNLP/core/collators/padders/jittor_padder.py
@@ -0,0 +1,219 @@
+__all__ = [
+ 'JittorNumberPadder',
+ 'JittorSequencePadder',
+ 'JittorTensorPadder'
+]
+
+from inspect import isclass
+import numpy as np
+
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+
+if _NEED_IMPORT_JITTOR:
+ import jittor
+
+ numpy_to_jittor_dtype_dict = {
+ np.bool_: 'bool',
+ np.uint8: 'uint8',
+ np.int8: "int8",
+ np.int16: "int16",
+ np.int32: "int32",
+ np.int64: "int64",
+ np.float16: "float16",
+ np.float32: 'float32',
+ np.float64: 'float32', # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了
+ }
+ # number_to_jittor_dtype_dict = {
+ # float: 'float32', # 因为 paddle.tensor([1], dtype=float)是paddle.float64
+ # int: 'int64',
+ # bool: 'bool'
+ # }
+
+from .padder import Padder
+from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
+from .exceptions import *
+
+
+def is_jittor_tensor(dtype):
+ if not isclass(dtype) and isinstance(dtype, jittor.jittor_core.Var):
+ return True
+ return False
+
+
+def is_jittor_dtype_str(dtype):
+ """
+ 判断数据类型是否为 jittor 使用的字符串类型
+
+    :param dtype: 数据类型
+ """
+ try:
+ if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
+ 'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
+ u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
+ u'int16', u'int32', u'int64', u'uint8'}:
+ return True
+ except:
+ pass
+ return False
+
+
+def _get_dtype(ele_dtype, dtype, class_name):
+ """
+ 用于检测数据的 dtype 类型, 根据内部和外部数据判断。
+
+    :param ele_dtype: 内部数据的类型
+    :param dtype: 数据外部类型
+    :param class_name: 类的名称
+ """
+ if not (ele_dtype is None or (
+ is_number_or_numpy_number(ele_dtype) or is_jittor_tensor(ele_dtype) or is_jittor_dtype_str(dtype))):
+ raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
+ f"or numpy numbers or jittor.Var but get `{ele_dtype}`.")
+
+ if dtype is not None:
+ if not (is_jittor_tensor(dtype) or is_number(dtype) or is_jittor_dtype_str(dtype)):
+ raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
+ f"or jittor.dtype but get `{dtype}`.")
+ else:
+ if is_numpy_generic_class(ele_dtype):
+ dtype = numpy_to_jittor_dtype_dict.get(ele_dtype)
+ else:
+ dtype = ele_dtype
+
+ return dtype
+
+
+class JittorNumberPadder(Padder):
+ """
+ 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``jittor.Var([1, 2, 3])``
+
+ :param pad_val: 该值无意义
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型;
+ :param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+        :param batch_field: 输入的某个 field 的 batch 数据。
+        :param pad_val: 需要填充的值
+        :param dtype: 数据的类型
+ """
+ return jittor.Var(np.array(batch_field, dtype=dtype))
+
+
+class JittorSequencePadder(Padder):
+ """
+ 可以将形如 ``[[1], [1, 2]]`` 这类的数据转为 ``jittor.Var([[1], [1, 2]])``
+
+ :param pad_val: 该值无意义
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型;
+ :param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val)
+ return tensor
+
+
+class JittorTensorPadder(Padder):
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ """
+        目前支持 ``[jittor.Var([3, 2]), jittor.Var([1])]`` 类似的输入。若内部元素不为 :class:`jittor.Var` ,则必须含有 :meth:`tolist` 方法。
+
+ :param pad_val: 需要 pad 的值;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型;
+ :param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等
+ """
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ 将 ``batch_field`` 数据 转为 :class:`jittor.Var` 并 pad 到相同长度。
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ try:
+ if not isinstance(batch_field[0], jittor.Var):
+ batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field]
+ except AttributeError:
+ raise RuntimeError(f"If the field is not a jittor.Var (it is {type(batch_field[0])}), "
+ f"it must have tolist() method.")
+
+ shapes = [field.shape for field in batch_field]
+ if len(batch_field) < 2:
+ max_shape = [len(batch_field)] + list(shapes[0])
+ else:
+ max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
+
+ tensor = jittor.full(max_shape, pad_val, dtype=dtype)
+ for i, field in enumerate(batch_field):
+ slices = (i,) + tuple(slice(0, s) for s in shapes[i])
+ tensor[slices] = field
+ return tensor
+
+
+def fill_tensor(batch_field, padded_batch, dtype):
+ """
+ 将 batch_field 中的值填入到 tensor 中。
+
+ :param batch_field: 需要填充进入 array 中的内容
+ :param padded_batch: 待填充的 tensor
+ :param dtype: 数据的类别
+
+ :return:
+ """
+ if padded_batch.ndim == 2:
+ for i, content_i in enumerate(batch_field):
+ padded_batch[i, :len(content_i)] = jittor.Var(np.array(content_i, dtype=dtype))
+ elif padded_batch.ndim == 3:
+ for i, content_i in enumerate(batch_field):
+ for j, content_ii in enumerate(content_i):
+ padded_batch[i, j, :len(content_ii)] = jittor.Var(np.array(content_ii, dtype=dtype))
+ elif padded_batch.ndim == 4:
+ try: # 应该是图像,所以直接应该就 ok 了。
+ padded_batch = jittor.Var(batch_field)
+ except:
+ for i, content_i in enumerate(batch_field):
+ for j, content_ii in enumerate(content_i):
+ for k, content_iii in enumerate(content_ii):
+ padded_batch[i, j, k, :len(content_iii)] = jittor.Var(np.array(content_iii, dtype=dtype))
+ elif padded_batch.ndim == 1:
+ padded_batch[:] = jittor.Var(np.array(batch_field, dtype=dtype))
+ else:
+ raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
+ "report.")
+ return padded_batch
+
+
+def get_padded_jittor_tensor(batch_field, dtype=None, pad_val=0):
+ """
+ 例如:
+ [[1,2], [3]] -> jittor.LongTensor([[1, 2], [3, 0]])
+
+ :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列)
+ /4d(多为图片)。
+ :param dtype: 目标类别是什么
+ :param pad_val: pad 的 value
+ :return:
+ """
+ shapes = get_shape(batch_field)
+ tensor = jittor.full(shapes, pad_val, dtype=dtype)
+ tensor = fill_tensor(batch_field, tensor, dtype=dtype)
+ return tensor
diff --git a/fastNLP/core/collators/padders/numpy_padder.py b/fastNLP/core/collators/padders/numpy_padder.py
new file mode 100644
index 00000000..2f386978
--- /dev/null
+++ b/fastNLP/core/collators/padders/numpy_padder.py
@@ -0,0 +1,138 @@
+__all__ = [
+ 'NumpyNumberPadder',
+ 'NumpySequencePadder',
+ "NumpyTensorPadder"
+]
+
+from numbers import Number
+from abc import ABC
+from typing import Any, Union
+import numpy as np
+
+from .padder import Padder
+from .utils import get_padded_numpy_array, is_number_or_numpy_number
+from .exceptions import *
+
+
+def _get_dtype(ele_dtype, dtype, class_name):
+ """
+ 用于检测数据的 dtype 类型, 根据内部和外部数据判断。
+
+ :param ele_dtype: 内部数据的类型
+ :param dtype: 数据外部类型
+ :param class_name: 类的名称
+ """
+ if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype):
+ raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
+ f"or numpy numbers but get `{ele_dtype}`.")
+
+ if dtype is None:
+ dtype = ele_dtype
+ else:
+ if not is_number_or_numpy_number(dtype):
+ raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
+ f"or numpy numbers but get `{dtype}`.")
+ dtype = dtype
+ return dtype
+
+
+class NumpyNumberPadder(Padder):
+ """
+ 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``np.array([1, 2, 3])`` 。可以通过::
+
+ >>> NumpyNumberPadder.pad([1, 2, 3])
+
+ 使用。
+
+ :param pad_val: 该值无意义;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型;
+ :param dtype: 输出的数据的 dtype ;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ 将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ return np.array(batch_field, dtype=dtype)
+
+
+class NumpySequencePadder(Padder):
+ """
+ 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``np.array([[1, 0], [1, 2]])``, 可以 pad 多重嵌套的数据。
+ 可以通过以下的方式直接使用:
+
+ >>> NumpySequencePadder.pad([[1], [1, 2]], pad_val=-100, dtype=float)
+ [[ 1. -100.]
+ [ 1. 2.]]
+
+ :param pad_val: pad 的值是多少;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型;
+ :param dtype: 输出的数据的 dtype ;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ 将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val)
+
+
+class NumpyTensorPadder(Padder):
+ """
+ pad 类似于 ``[np.array([3, 4]), np.array([1])]`` 的 field 。若内部元素不为 :class:`np.ndarray` ,则必须含有 :meth:`tolist` 方法。
+
+ >>> NumpyTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100)
+ [[ 3. 4.]
+ [ 1. -100.]]
+ :param pad_val: pad 的值是多少。
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型。
+ :param dtype: 输出的数据的 dtype 是什么
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ 将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ try:
+ if not isinstance(batch_field[0], np.ndarray):
+ batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field]
+ except AttributeError:
+ raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), "
+ f"it must have tolist() method.")
+
+ shapes = [field.shape for field in batch_field]
+ if len(batch_field) < 2:
+ max_shape = [len(batch_field)] + list(shapes[0])
+ else:
+ max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
+
+ array = np.full(max_shape, fill_value=pad_val, dtype=dtype)
+ for i, field in enumerate(batch_field):
+ slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
+ array[slices] = field
+ return array
+
diff --git a/fastNLP/core/collators/padders/oneflow_padder.py b/fastNLP/core/collators/padders/oneflow_padder.py
new file mode 100644
index 00000000..a990e87f
--- /dev/null
+++ b/fastNLP/core/collators/padders/oneflow_padder.py
@@ -0,0 +1,226 @@
+__all__ = [
+ 'OneflowNumberPadder',
+ 'OneflowSequencePadder',
+ 'OneflowTensorPadder'
+]
+from inspect import isclass
+import numpy as np
+
+from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
+from fastNLP.envs.utils import _module_available
+
+if _NEED_IMPORT_ONEFLOW:
+ import oneflow
+ numpy_to_oneflow_dtype_dict = {
+ np.bool_: oneflow.bool,
+ np.uint8: oneflow.uint8,
+ np.int8: oneflow.int8,
+ np.int32: oneflow.int32,
+ np.int64: oneflow.int64,
+ np.float16: oneflow.float16,
+ np.float32: oneflow.float32,
+ np.float64: oneflow.float32, # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了
+ }
+ number_to_oneflow_dtype_dict = {
+ float: oneflow.float32, # 因为 oneflow.tensor([1], dtype=float)是oneflow.float64
+ int: oneflow.int64,
+ bool: oneflow.bool
+ }
+
+from .padder import Padder
+from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
+from .exceptions import *
+
+
+def is_oneflow_tensor(dtype):
+ """
+ 判断是否为 oneflow 的 tensor
+
+    :param dtype: 数据的 dtype 类型
+ """
+ if not isclass(dtype) and isinstance(dtype, oneflow.dtype):
+ return True
+ return False
+
+
+def _get_dtype(ele_dtype, dtype, class_name):
+ """
+ 用于检测数据的 dtype 类型, 根据内部和外部数据判断。
+
+ :param ele_dtype: 内部数据的类型
+ :param dtype: 数据外部类型
+ :param class_name: 类的名称
+ """
+ if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_oneflow_tensor(ele_dtype))):
+ raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
+ f"or numpy numbers or oneflow.Tensor but get `{ele_dtype}`.")
+
+ if dtype is not None:
+ if not (is_oneflow_tensor(dtype) or is_number(dtype)):
+ raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
+ f"or oneflow.dtype but get `{dtype}`.")
+ dtype = number_to_oneflow_dtype_dict.get(dtype, dtype)
+ else:
+ if ele_dtype is not None:
+ if (is_number(ele_dtype) or is_oneflow_tensor(ele_dtype)):
+ ele_dtype = number_to_oneflow_dtype_dict.get(ele_dtype, ele_dtype)
+ dtype = ele_dtype
+ elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了
+ dtype = numpy_to_oneflow_dtype_dict.get(ele_dtype.type)
+ elif is_numpy_generic_class(ele_dtype):
+ dtype = numpy_to_oneflow_dtype_dict.get(ele_dtype)
+
+ return dtype
+
+
+class OneflowNumberPadder(Padder):
+ """
+ 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``oneflow.Tensor([1, 2, 3])``。
+
+ :param pad_val: 该值无意义;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型;
+    :param dtype: 输出的数据的 dtype。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ 将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ return oneflow.tensor(batch_field, dtype=dtype)
+
+
+class OneflowSequencePadder(Padder):
+ """
+ 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``oneflow.Tensor([[1, 0], [1, 2]])``, 可以 pad 多重嵌套的数据。
+
+ :param pad_val: 需要 pad 的值;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型;
+    :param dtype: 输出的数据的 dtype。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ 将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ tensor = get_padded_oneflow_tensor(batch_field, dtype=dtype, pad_val=pad_val)
+ return tensor
+
+
+class OneflowTensorPadder(Padder):
+ """
+    目前支持 ``[oneflow.tensor([3, 2]), oneflow.tensor([1])]`` 类似的输入,若内部元素不为 :class:`oneflow.Tensor` ,则必须含有 :meth:`tolist` 方法。
+
+ >>> OneflowTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100)
+ [[ 3. 4.]
+ [ 1. -100.]]
+ >>> OneflowTensorPadder.pad([oneflow.LongTensor([3, 4]), oneflow.LongTensor([1])], pad_val=-100)
+ tensor([[ 3, 4],
+ [ 1, -100]])
+
+ :param pad_val: 需要 pad 的值。
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型。
+    :param dtype: 输出的数据的 dtype。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ 将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ device = None
+ try:
+ if not isinstance(batch_field[0], oneflow.Tensor):
+ batch_field = [oneflow.tensor(field.tolist(), dtype=dtype) for field in batch_field]
+ else:
+ batch_field = [field.to(dtype) for field in batch_field]
+ device = batch_field[0].device
+ if dtype is None:
+ dtype = batch_field[0].dtype
+ except AttributeError:
+ raise RuntimeError(f"If the field is not a oneflow.Tensor (it is {type(batch_field[0])}), "
+ f"it must have tolist() method.")
+
+ shapes = [field.shape for field in batch_field]
+ if len(batch_field) < 2:
+ max_shape = [len(batch_field)] + list(shapes[0])
+ else:
+ max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
+
+ tensor = oneflow.full(max_shape, pad_val, dtype, device)
+ for i, field in enumerate(batch_field):
+ slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
+ tensor[slices] = field
+ return tensor
+
+
+def fill_tensor(batch_field, padded_batch, dtype):
+ """
+ 将 batch_field 中的值填入到 tensor 中。
+
+ :param batch_field: 需要填充进入 array 中的内容
+ :param padded_batch: 待填充的 tensor
+ :param dtype: 数据的类别
+
+ :return:
+ """
+ if padded_batch.ndim == 2:
+ for i, content_i in enumerate(batch_field):
+ padded_batch[i, :len(content_i)] = oneflow.tensor(content_i, dtype=dtype)
+ elif padded_batch.ndim == 3:
+ for i, content_i in enumerate(batch_field):
+ for j, content_ii in enumerate(content_i):
+ padded_batch[i, j, :len(content_ii)] = oneflow.tensor(content_ii, dtype=dtype)
+ elif padded_batch.ndim == 4:
+ try: # 应该是图像,所以直接应该就 ok 了。
+ padded_batch = oneflow.tensor(batch_field)
+ except:
+ for i, content_i in enumerate(batch_field):
+ for j, content_ii in enumerate(content_i):
+ for k, content_iii in enumerate(content_ii):
+ padded_batch[i, j, k, :len(content_iii)] = oneflow.tensor(content_iii, dtype=dtype)
+ elif padded_batch.ndim == 1:
+ padded_batch[:] = oneflow.tensor(batch_field, dtype=dtype)
+ else:
+ raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
+ "report.")
+ return padded_batch
+
+
+def get_padded_oneflow_tensor(batch_field, dtype=None, pad_val=0):
+ """
+ 例如:
+ [[1,2], [3]] -> oneflow.LongTensor([[1, 2], [3, 0]])
+
+ :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列)
+ /4d(多为图片)。
+ :param dtype: 目标类别是什么
+ :param pad_val: pad 的 value
+ :return:
+ """
+ shapes = get_shape(batch_field)
+ tensor = oneflow.full(shapes, pad_val, dtype)
+ tensor = fill_tensor(batch_field, tensor, dtype=dtype)
+ return tensor
diff --git a/fastNLP/core/collators/padders/padder.py b/fastNLP/core/collators/padders/padder.py
new file mode 100644
index 00000000..783d8fa2
--- /dev/null
+++ b/fastNLP/core/collators/padders/padder.py
@@ -0,0 +1,32 @@
+
+class Padder:
+ """
+ 所有 **Padder** 对象的父类,所有的 Padder 对象都会实现静态函数 ``pad(batch_field, pad_val=0, dtype=None)`` 。
+
+ """
+ def __init__(self, pad_val, dtype):
+ self.pad_val = pad_val
+ self.dtype = dtype
+
+ def __call__(self, batch_field):
+ return self.pad(batch_field=batch_field, pad_val=self.pad_val, dtype=self.dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ raise NotImplementedError()
+
+
+class NullPadder(Padder):
+ """
+ 不进行任何 检查 与 pad 的空 padder 。
+
+ :param ele_dtype:
+ :param pad_val:
+ :param dtype:
+ """
+ def __init__(self, ele_dtype=None, pad_val=None, dtype=None):
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ def __call__(self, batch_field):
+ # 直接返回,不调用 pad() 方法加快速度。
+ return batch_field
diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py
new file mode 100644
index 00000000..57d31967
--- /dev/null
+++ b/fastNLP/core/collators/padders/paddle_padder.py
@@ -0,0 +1,217 @@
+__all__ = [
+ "PaddleNumberPadder",
+ "PaddleTensorPadder",
+ "PaddleSequencePadder"
+]
+from inspect import isclass
+import numpy as np
+
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+
+if _NEED_IMPORT_PADDLE:
+ import paddle
+ numpy_to_paddle_dtype_dict = {
+ np.bool_: 'bool',
+ np.uint8: 'uint8',
+ np.int8: "int8",
+ np.int16: "int16",
+ np.int32: "int32",
+ np.int64: "int64",
+ np.float16: "float16",
+ np.float32: 'float32',
+ np.float64: 'float32', # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了
+ np.complex64: 'complex64',
+ np.complex128: "complex128"
+ }
+ number_to_paddle_dtype_dict = {
+ float: 'float32', # 因为 paddle.tensor([1], dtype=float)是paddle.float64
+ int: 'int64',
+ bool: 'bool'
+ }
+
+from .padder import Padder
+from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, is_numpy_generic_class, \
+ get_padded_numpy_array
+from .exceptions import *
+
+
+def is_paddle_tensor(dtype):
+ """
+ 判断 dtype 是否为 paddle 的 tensor
+
+ :param dtype 数据的 dtype 类型
+ """
+ if not isclass(dtype) and isinstance(dtype, paddle.dtype):
+ return True
+
+ return False
+
+
+def is_paddle_dtype_str(dtype):
+ """
+ 判断 dtype 是 str 类型 且属于 paddle 支持的 str 类型
+
+ :param dtype 数据的 dtype 类型
+ """
+
+ try:
+ if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
+ 'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
+ u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
+ u'int16', u'int32', u'int64', u'uint8', u'complex64',
+ u'complex128'}:
+ return True
+ except:
+ pass
+ return False
+
+
+def _get_dtype(ele_dtype, dtype, class_name):
+ """
+ 用于检测数据的 dtype 类型, 根据内部和外部数据判断。
+
+ :param ele_dtype 内部数据的类型
+ :param dtype 数据外部类型
+ :param class_name 类的名称
+ """
+ if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
+ raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
+ f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.")
+
+ if dtype is not None:
+ if not (is_paddle_tensor(dtype) or is_number(dtype) or is_paddle_dtype_str(dtype)):
+ raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
+ f"or paddle.dtype but get `{dtype}`.")
+ dtype = number_to_paddle_dtype_dict.get(dtype, dtype)
+ else:
+ if (is_number(ele_dtype) or is_paddle_tensor(ele_dtype)):
+ ele_dtype = number_to_paddle_dtype_dict.get(ele_dtype, ele_dtype)
+ dtype = ele_dtype
+ elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了
+ dtype = numpy_to_paddle_dtype_dict.get(ele_dtype.type)
+ elif is_numpy_generic_class(ele_dtype):
+ dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
+ else:
+ dtype = ele_dtype
+
+ return dtype
+
+
+class PaddleNumberPadder(Padder):
+ """
+ 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``paddle.Tensor([1, 2, 3])``
+
+ :param pad_val: 该值无意义;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型;
+ :param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ 将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ return paddle.to_tensor(batch_field, dtype=dtype)
+
+
+class PaddleSequencePadder(Padder):
+ """
+ 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``paddle.Tensor([[1, 0], [1, 2]])`` 可以 pad 多重嵌套的数据。
+
+ :param pad_val: pad 的值。
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型;
+ :param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等;
+ """
+ def __init__(self, ele_dtype=None, pad_val=0, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ 将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val)
+ return tensor
+
+
+class PaddleTensorPadder(Padder):
+ """
+ 目前支持 ``[paddle.tensor([3, 2], paddle.tensor([2, 1])]`` 类似的输入,若内部元素不为 :class:`paddle.Tensor` ,则必须含有 :meth:`tolist` 方法。
+
+ >>> PaddleTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100)
+ [[ 3. 4.]
+ [ 1. -100.]]
+ >>> PaddleTensorPadder.pad([paddle.to_tensor([3, 4]), paddle.to_tensor([1])], pad_val=-100)
+ tensor([[ 3, 4],
+ [ 1, -100]])
+ :param pad_val: pad 的值。
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型;
+ :param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+ 将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 数据的类型
+ """
+ try:
+ if not isinstance(batch_field[0], paddle.Tensor):
+ batch_field = [np.array(field.tolist()) for field in batch_field]
+ else:
+ if dtype is None:
+ dtype = batch_field[0].dtype
+ except AttributeError:
+ raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), "
+ f"it must have tolist() method.")
+
+ shapes = [field.shape for field in batch_field]
+ if len(batch_field) < 2:
+ max_shape = [len(batch_field)] + list(shapes[0])
+ else:
+ max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
+
+ if isinstance(batch_field[0], paddle.Tensor):
+ array = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
+ else:
+ array = np.full(max_shape, fill_value=pad_val, dtype=batch_field[0].dtype)
+ for i, field in enumerate(batch_field):
+ slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
+ array[slices] = field
+ tensor = paddle.to_tensor(array, dtype=dtype)
+ return tensor
+
+
+def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0):
+ """
+ 例如:
+ [[1,2], [3]] -> paddle.LongTensor([[1, 2], [3, 0]])
+
+ :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列)
+ /4d(多为图片)。
+ :param dtype: 目标类别是什么
+ :param pad_val: pad 的 value
+ :return:
+ """
+ array = get_padded_numpy_array(batch_field=batch_field, dtype=None, pad_val=pad_val)
+ tensor = paddle.to_tensor(array, dtype=dtype)
+ return tensor
diff --git a/fastNLP/core/collators/padders/raw_padder.py b/fastNLP/core/collators/padders/raw_padder.py
new file mode 100644
index 00000000..52ba6617
--- /dev/null
+++ b/fastNLP/core/collators/padders/raw_padder.py
@@ -0,0 +1,107 @@
+__all__ = [
+ "RawNumberPadder",
+ "RawSequencePadder",
+ "RawTensorPadder"
+]
+
+from .padder import Padder
+from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number
+from .exceptions import *
+
+
+def _get_dtype(ele_dtype, dtype, class_name):
+ """
+ 用于检测数据的 dtype 类型, 根据内部和外部数据判断。
+
+ :param ele_dtype: 内部数据的类型
+ :param dtype: 数据外部类型
+ :param class_name: 类的名称
+ """
+ if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype):
+ raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
+ f"or numpy numbers but get `{ele_dtype}`.")
+
+ if dtype is None:
+ dtype = ele_dtype
+ else:
+ if not is_number_or_numpy_number(dtype):
+ raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
+ f"or numpy numbers but get `{dtype}`.")
+ dtype = dtype
+
+ return dtype
+
+
+class RawNumberPadder(Padder):
+ """
+ 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``[1, 2, 3]`` 。实际上该 padder 无意义。
+
+ :param pad_val:
+ :param ele_dtype:
+ :param dtype:
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ def __call__(self, batch_field):
+ return batch_field
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ raise NotImplementedError()
+
+
+class RawSequencePadder(Padder):
+ """
+ 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``[[1, 0], [1, 2]]`` 。可以 pad 多重嵌套的数据。
+
+ :param pad_val: pad 的值;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型;
+ :param dtype: 输出的数据的 dtype ;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 该参数无意义。
+ :return:
+ """
+ return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()
+
+
+class RawTensorPadder(Padder):
+ """
+ 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``[[1, 0], [1, 2]]`` 。可以 pad 多重嵌套的数据。
+
+ :param pad_val: pad 的值;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型;
+ :param dtype: 输出的数据的 dtype ;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ """
+
+ :param batch_field: 输入的某个 field 的 batch 数据。
+ :param pad_val: 需要填充的值
+ :param dtype: 该参数无意义。
+ :return:
+ """
+ try:
+ if not isinstance(batch_field[0], (list, tuple)):
+ batch_field = [field.tolist() for field in batch_field]
+ except AttributeError:
+            raise RuntimeError(f"If the field is not a list or tuple (it is {type(batch_field[0])}), "
+                               f"it must have tolist() method.")
+
+ return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()
diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py
new file mode 100644
index 00000000..a5ab9149
--- /dev/null
+++ b/fastNLP/core/collators/padders/torch_padder.py
@@ -0,0 +1,206 @@
+__all__ = [
+ 'TorchNumberPadder',
+ 'TorchSequencePadder',
+ 'TorchTensorPadder'
+]
+from inspect import isclass
+import numpy as np
+
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+ import torch
+ numpy_to_torch_dtype_dict = {
+ np.bool_: torch.bool,
+ np.uint8: torch.uint8,
+ np.int8: torch.int8,
+ np.int16: torch.int16,
+ np.int32: torch.int32,
+ np.int64: torch.int64,
+ np.float16: torch.float16,
+ np.float32: torch.float32,
+ np.float64: torch.float32, # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了
+ np.complex64: torch.complex64,
+ np.complex128: torch.complex128
+ }
+ number_to_torch_dtype_dict = {
+ float: torch.float32, # 因为 torch.tensor([1], dtype=float)是torch.float64
+ int: torch.int64,
+ bool: torch.bool
+ }
+
+from .padder import Padder
+from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
+from .exceptions import *
+
+
+def is_torch_tensor(dtype):
+ """
+ 判断是否为 torch 的 tensor
+
+ :param dtype: 数据的 dtype 类型
+ """
+ if not isclass(dtype) and isinstance(dtype, torch.dtype):
+ return True
+ return False
+
+
+def _get_dtype(ele_dtype, dtype, class_name):
+ """
+ 用于检测数据的 dtype 类型, 根据内部和外部数据判断。
+
+ :param ele_dtype: 内部数据的类型
+ :param dtype: 数据外部类型
+ :param class_name: 类的名称
+ """
+ if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
+ raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
+ f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")
+
+ if dtype is not None:
+ if not (is_torch_tensor(dtype) or is_number(dtype)):
+ raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
+ f"or torch.dtype but get `{dtype}`.")
+ dtype = number_to_torch_dtype_dict.get(dtype, dtype)
+ else:
+ if ele_dtype is not None:
+ if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)):
+ ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype)
+ dtype = ele_dtype
+ elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了
+ dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type)
+ elif is_numpy_generic_class(ele_dtype):
+ dtype = numpy_to_torch_dtype_dict.get(ele_dtype)
+
+ return dtype
+
+
+class TorchNumberPadder(Padder):
+ """
+ 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``torch.Tensor([1, 2, 3])``
+
+ :param pad_val: 该值无意义;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型;
+ :param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ return torch.tensor(batch_field, dtype=dtype)
+
+
+class TorchSequencePadder(Padder):
+ """
+ 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``torch.Tensor([[1, 0], [1, 2]])`` 可以 pad 多重嵌套的数据。
+
+ :param pad_val: 需要 pad 的值;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型;
+ :param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ tensor = get_padded_torch_tensor(batch_field, dtype=dtype, pad_val=pad_val)
+ return tensor
+
+
+class TorchTensorPadder(Padder):
+ """
+ 目前支持 ``[torch.tensor([3, 2], torch.tensor([1])]`` 类似的输入。若内部元素不为 :class:`torch.Tensor` ,则必须含有 :meth:`tolist` 方法。
+
+ >>> TorchTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100)
+ [[ 3. 4.]
+ [ 1. -100.]]
+ >>> TorchTensorPadder.pad([torch.LongTensor([3, 4]), torch.LongTensor([1])], pad_val=-100)
+ tensor([[ 3, 4],
+ [ 1, -100]])
+
+ :param pad_val: 需要 pad 的值;
+ :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型;
+ :param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等;
+ """
+ def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+ super().__init__(pad_val=pad_val, dtype=dtype)
+
+ @staticmethod
+ def pad(batch_field, pad_val=0, dtype=None):
+ device = None
+ try:
+ if not isinstance(batch_field[0], torch.Tensor):
+ batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field]
+ else:
+ device = batch_field[0].device
+ if dtype is None:
+ dtype = batch_field[0].dtype
+ except AttributeError:
+ raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), "
+ f"it must have tolist() method.")
+
+ shapes = [field.shape for field in batch_field]
+ if len(batch_field) < 2:
+ max_shape = [len(batch_field)] + list(shapes[0])
+ else:
+ max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
+
+ tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype, device=device)
+ for i, field in enumerate(batch_field):
+ slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
+ tensor[slices] = field
+ return tensor
+
+
+def fill_tensor(batch_field, padded_batch, dtype):
+ """
+ 将 batch_field 中的值填入到 tensor 中。
+
+ :param batch_field: 需要填充进入 array 中的内容
+ :param padded_batch: 待填充的 tensor
+ :param dtype: 数据的类别
+
+ :return:
+ """
+ if padded_batch.ndim == 2:
+ for i, content_i in enumerate(batch_field):
+ padded_batch[i, :len(content_i)] = torch.tensor(content_i, dtype=dtype)
+ elif padded_batch.ndim == 3:
+ for i, content_i in enumerate(batch_field):
+ for j, content_ii in enumerate(content_i):
+ padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype)
+ elif padded_batch.ndim == 4:
+ try: # 应该是图像,所以直接应该就 ok 了。
+ padded_batch = torch.tensor(batch_field)
+ except:
+ for i, content_i in enumerate(batch_field):
+ for j, content_ii in enumerate(content_i):
+ for k, content_iii in enumerate(content_ii):
+ padded_batch[i, j, k, :len(content_iii)] = torch.tensor(content_iii, dtype=dtype)
+ elif padded_batch.ndim == 1:
+ padded_batch[:] = torch.tensor(batch_field, dtype=dtype)
+ else:
+        raise RuntimeError("fastNLP does not support padding for more than 4 dimensions. If you need this, please "
+                           "report.")
+ return padded_batch
+
+
+def get_padded_torch_tensor(batch_field, dtype=None, pad_val=0):
+ """
+ 例如:
+ [[1,2], [3]] -> torch.LongTensor([[1, 2], [3, 0]])
+
+ :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列)
+ /4d(多为图片)。
+ :param dtype: 目标类别是什么
+ :param pad_val: pad 的 value
+ :return:
+ """
+ shapes = get_shape(batch_field)
+ tensor = torch.full(shapes, dtype=dtype, fill_value=pad_val)
+ tensor = fill_tensor(batch_field, tensor, dtype=dtype)
+ return tensor
diff --git a/fastNLP/core/collators/padders/torch_utils.py b/fastNLP/core/collators/padders/torch_utils.py
new file mode 100644
index 00000000..d1887b36
--- /dev/null
+++ b/fastNLP/core/collators/padders/torch_utils.py
@@ -0,0 +1,20 @@
+
+
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+ import torch
+
+__all__ = []
+
+def is_torch_tensor_dtype(dtype) -> bool:
+ """
+ 返回当前 dtype 是否是 torch 的 dtype 类型
+
+ :param dtype: 类似与 torch.ones(3).dtype
+ :return:
+ """
+ try:
+ return isinstance(dtype, torch.dtype)
+ except:
+ return False
diff --git a/fastNLP/core/collators/padders/utils.py b/fastNLP/core/collators/padders/utils.py
new file mode 100644
index 00000000..6a50b33d
--- /dev/null
+++ b/fastNLP/core/collators/padders/utils.py
@@ -0,0 +1,182 @@
+
+__all__ = [
+ 'get_padded_numpy_array'
+]
+
+
+from typing import Sequence, List
+import re
+from inspect import isclass
+
+import numpy as np
+np_str_obj_array_pattern = re.compile(r'[SaUO]')
+
+
+def get_shape(batch_field:List, shape=None):
+ """
+ 给定 field 返回这个 field pad 完成之后的 shape 。
+ 例如: [[1, 2, 3], [3]] -> [2, 3]
+ [[[1], [2], [3, 4]], [[2, 3, 4]]] -> [2, 3, 3]
+
+ :param batch_field: list,第 0 维一般为 batch 维度。
+ :param shape: 无需传入。
+ :return:
+ """
+ if shape is None:
+ shape = []
+ if isinstance(batch_field, Sequence):
+ num_ele = len(batch_field)
+ _shape = shape + [num_ele]
+ try:
+ shapes = []
+ if isinstance(batch_field[0], Sequence):
+ for _field in batch_field:
+ shapes.append(get_shape(_field, _shape))
+ if len(shapes) == 1:
+ max_shape = shapes[0]
+ else:
+ max_shape = [max(_) for _ in zip(*shapes)]
+
+ return max_shape
+ except IndexError: # 空的shape
+ pass
+ return _shape # 说明是一个空的 sequence
+ else:
+ return shape
+
+
+def fill_array(batch_field:List, padded_batch:np.ndarray):
+ """
+ 将 batch_field 中的值填入到 array 中。
+
+ :param batch_field: 需要填充进入 array 中的内容
+ :param padded_batch: 待填充的 np.ndarray
+ :return:
+ """
+ if padded_batch.ndim == 2:
+ for i, content_i in enumerate(batch_field):
+ padded_batch[i, :len(content_i)] = content_i
+ elif padded_batch.ndim == 3:
+ for i, content_i in enumerate(batch_field):
+ for j, content_ii in enumerate(content_i):
+ padded_batch[i, j, :len(content_ii)] = content_ii
+ elif padded_batch.ndim == 4:
+ try: # 应该是图像,所以直接应该就 ok 了。
+ padded_batch = np.array(batch_field)
+ except:
+ for i, content_i in enumerate(batch_field):
+ for j, content_ii in enumerate(content_i):
+ for k, content_iii in enumerate(content_ii):
+ padded_batch[i, j, k, :len(content_iii)] = content_iii
+ elif padded_batch.ndim == 1:
+ padded_batch[:] = batch_field
+ else:
+        raise RuntimeError("fastNLP does not support padding for more than 4 dimensions. If you need this, please "
+                           "report.")
+ return padded_batch
+
+
+def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray:
+ """
+    将输入 pad 为 :class:`numpy.ndarray` 类型,如:``[[1,2], [3]] -> np.array([[1, 2], [3, 0]])``
+
+ :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 **1d** (多为句子长度)/ **2d** (多为文本序列)/ **3d** (多为字符序列)
+ /4d(多为图片);
+ :param dtype: 输出数据的 dtype 类型;
+ :param pad_val: 填充值;
+ :return:
+ """
+ shapes = get_shape(batch_field)
+ array = np.full(shapes, dtype=dtype, fill_value=pad_val)
+ array = fill_array(batch_field, array)
+ return array
+
+
+def get_padded_nest_list(batch_field: List, pad_val=0) -> List:
+ """
+ 例如:
+ [[1,2], [3]] -> [[1, 2], [3, 0]]
+
+ :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列)
+ /4d(多为图片)。
+ :param pad_val: pad 的 value
+ :return:
+ """
+
+ array = get_padded_numpy_array(batch_field, pad_val=pad_val, dtype=None).tolist()
+ return array
+
+
+def is_number_or_numpy_number(dtype):
+ """
+ 判断 dtype 是否是数字类型,或者 numpy 的数字类型。
+ is_number_or_numpy_number(type(3)) # True
+ is_number_or_numpy_number(type(3.1)) # True
+ is_number_or_numpy_number(type('3')) # False
+ is_number_or_numpy_number(type(True)) # True
+ is_number_or_numpy_number(type(np.zeros(3)[0])) # True
+ is_number_or_numpy_number(np.zeros(3, dtype=float).dtype) # True
+ is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) # True
+ is_number_or_numpy_number(np.zeros(3, dtype=str).dtype) # False
+ is_number_or_numpy_number(np.array([1, [2]]).dtype) # False
+
+ :param dtype:
+ :return:
+ """
+ if is_number(dtype):
+ return True
+ else:
+ if isclass(dtype):
+ return is_numpy_generic_class(dtype)
+ elif isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None:
+ return True
+ return False
+
+
+def is_numpy_number_dtype(dtype):
+ if not isclass(dtype) and isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None:
+ return True
+ return False
+
+
+def is_numpy_generic_class(dtype):
+ """
+ 形如 np.int64,或者 np.zeros(1).dtype.type 的值
+
+ :param dtype:
+ :return:
+ """
+ if isclass(dtype) and issubclass(dtype, np.generic):
+ return True
+ return False
+
+
+def is_number(dtype):
+ try:
+ if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \
+ and not is_numpy_number_dtype(dtype):
+ return True
+ return False
+ except:
+ return False
+
+
+
+if __name__ == '__main__':
+ # a = [[[1]], [1, 2, 3], [3]]
+ # a = [[[1], [2], [3, 4]], [[2, 3, 4]]]
+ # b = get_padded_nest_list(a)
+ # print(type(b[0]))
+ # print(b)
+ # import torch
+ print(is_number(type('a')))
+ print(is_number_or_numpy_number(type(3))) # True
+ print(is_number_or_numpy_number(type(3.1))) # True
+ print(is_number_or_numpy_number(type('3'))) # False
+ print(is_number_or_numpy_number(type(True))) # True
+ print(is_number_or_numpy_number(type(np.zeros(3)[0]))) # True
+ print(is_number_or_numpy_number(np.zeros(3, dtype=float).dtype)) # True
+ print(is_number_or_numpy_number(np.zeros(3, dtype=int).dtype)) # True
+ print(is_number_or_numpy_number(np.zeros(3, dtype=str).dtype)) # False
+ print(is_number_or_numpy_number(np.array([1, [2]]).dtype)) # False
+
diff --git a/fastNLP/core/const.py b/fastNLP/core/const.py
deleted file mode 100644
index e53c1f92..00000000
--- a/fastNLP/core/const.py
+++ /dev/null
@@ -1,84 +0,0 @@
-r"""
-fastNLP包当中的field命名均符合一定的规范,该规范由fastNLP.Const类进行定义。
-"""
-
-__all__ = [
- "Const"
-]
-
-
-class Const:
- r"""
- fastNLP中field命名常量。
-
- .. todo::
- 把下面这段改成表格
-
- 具体列表::
-
- INPUT 模型的序列输入 words(具有多列words时,依次使用words1, words2, )
- CHAR_INPUT 模型character输入 chars(具有多列chars时,依次使用chars1, chars2)
- INPUT_LEN 序列长度 seq_len(具有多列seq_len时,依次使用seq_len1,seq_len2)
- OUTPUT 模型输出 pred(具有多列pred时,依次使用pred1, pred2)
- TARGET 真实目标 target(具有多列target时,依次使用target1,target2)
- LOSS 损失函数 loss (具有多列loss时,依次使用loss1,loss2)
- RAW_WORD 原文的词 raw_words (具有多列raw_words时,依次使用raw_words1, raw_words2)
- RAW_CHAR 原文的字 raw_chars (具有多列raw_chars时,依次使用raw_chars1, raw_chars2)
-
- """
- INPUT = 'words'
- CHAR_INPUT = 'chars'
- INPUT_LEN = 'seq_len'
- OUTPUT = 'pred'
- TARGET = 'target'
- LOSS = 'loss'
- RAW_WORD = 'raw_words'
- RAW_CHAR = 'raw_chars'
-
- @staticmethod
- def INPUTS(i):
- r"""得到第 i 个 ``INPUT`` 的命名"""
- i = int(i) + 1
- return Const.INPUT + str(i)
-
- @staticmethod
- def CHAR_INPUTS(i):
- r"""得到第 i 个 ``CHAR_INPUT`` 的命名"""
- i = int(i) + 1
- return Const.CHAR_INPUT + str(i)
-
- @staticmethod
- def RAW_WORDS(i):
- r"""得到第 i 个 ``RAW_WORDS`` 的命名"""
- i = int(i) + 1
- return Const.RAW_WORD + str(i)
-
- @staticmethod
- def RAW_CHARS(i):
- r"""得到第 i 个 ``RAW_CHARS`` 的命名"""
- i = int(i) + 1
- return Const.RAW_CHAR + str(i)
-
- @staticmethod
- def INPUT_LENS(i):
- r"""得到第 i 个 ``INPUT_LEN`` 的命名"""
- i = int(i) + 1
- return Const.INPUT_LEN + str(i)
-
- @staticmethod
- def OUTPUTS(i):
- r"""得到第 i 个 ``OUTPUT`` 的命名"""
- i = int(i) + 1
- return Const.OUTPUT + str(i)
-
- @staticmethod
- def TARGETS(i):
- r"""得到第 i 个 ``TARGET`` 的命名"""
- i = int(i) + 1
- return Const.TARGET + str(i)
-
- @staticmethod
- def LOSSES(i):
- r"""得到第 i 个 ``LOSS`` 的命名"""
- i = int(i) + 1
- return Const.LOSS + str(i)
diff --git a/fastNLP/core/controllers/__init__.py b/fastNLP/core/controllers/__init__.py
new file mode 100644
index 00000000..ec47f254
--- /dev/null
+++ b/fastNLP/core/controllers/__init__.py
@@ -0,0 +1,13 @@
+__all__ = [
+ 'Loop',
+ 'EvaluateBatchLoop',
+ 'TrainBatchLoop',
+ 'Evaluator',
+ 'Trainer',
+]
+
+from .loops import Loop, EvaluateBatchLoop, TrainBatchLoop
+from .utils import State, TrainerState
+from .evaluator import Evaluator
+from .trainer import Trainer
+
diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py
new file mode 100644
index 00000000..597d13de
--- /dev/null
+++ b/fastNLP/core/controllers/evaluator.py
@@ -0,0 +1,530 @@
+r"""
+``Evaluator`` 是新版 **fastNLP** 中用来进行评测模型的评测器,其与 ``Trainer`` 相对应,二者共同构建起了 **fastNLP** 中 **训练** 和 **评测** 的框架。
+``Evaluator`` 的整体架构与 ``Trainer`` 类似,也是利用 ``Driver`` 来负责底层的评测逻辑。通过使用 ``Evaluator``,您可以快速、方便、准确地
+对您的模型进行全方位地评测。
+
+.. note::
+
+    ``Trainer`` 通过在自己内部内置一个 ``Evaluator`` 实例来支持在训练过程中进行验证的功能;
+"""
+
+from typing import Union, List, Optional, Dict, Callable, BinaryIO
+import os
+from pathlib import Path
+import io
+from dataclasses import is_dataclass
+
+__all__ = [
+ 'Evaluator'
+]
+
+from fastNLP.core.drivers import Driver, TorchDriver
+from ..drivers.choose_driver import choose_driver
+from .loops import Loop, EvaluateBatchLoop
+from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
+ match_and_substitute_params, f_rich_progress, flat_nest_dict, f_tqdm_progress
+from fastNLP.core.metrics import Metric
+from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
+from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
+from fastNLP.core.utils.utils import _check_valid_parameters_number
+from fastNLP.core.log import logger
+from fastNLP.envs import FASTNLP_MODEL_FILENAME
+
+
+
+class Evaluator:
+ """
+ 用于评测模型性能好坏的评测器;
+
+ .. note::
+
+ ``Evaluator`` 与 ``Trainer`` 类似,都是使用 ``Driver`` 作为底层来实现评测或者训练,因此大多数与 ``Trainer`` 同名的参数的意义和使用都与
+ ``Trainer`` 中的参数相同,对于这些参数,您可以参考 ``Trainer`` 的文档来获取更详细的信息;详见 :class:`~fastNLP.core.controllers.trainer.Trainer`;
+
+ :param model: 训练所需要的模型,例如 ``torch.nn.Module``,等价于 ``Trainer`` 中的 ``model`` 参数;
+ :param dataloaders: 用于评测的数据集。如果为多个,您需要使用 ``dict`` 传入,即对每一个数据集标上用于标识它们的标签;也可以使用 evaluate_dataloaders
+ 作为参数的名称。
+ :param metrics: 评测时使用的指标。注意该参数必须为 ``dict`` 类型,其中 ``key`` 为一个 ``metric`` 的名称,``value`` 为具体的 ``Metric`` 对象。目前支持以下 metrics:
+
+ 1. fastNLP 自己的 ``metric``:详见 :class:`~fastNLP.core.metrics.Metric`;
+ 2. torchmetrics;
+ 3. allennlp.training.metrics;
+ 4. paddle.metric;
+
+ :param driver: 等价于 ``Trainer`` 中的 ``driver`` 参数;
+
+ .. note::
+
+ 如果在您的脚本中在初始化 ``Evaluator`` 前也初始化了 ``Trainer`` 进行训练,那么强烈建议您直接将 ``trainer.driver`` 传入 ``Evaluator`` 当做该参数的值;
+
+ .. code-block::
+
+ # 初始化 Trainer
+ trainer = Trainer(
+ ...
+ driver='torch',
+ device=[0,1]
+ )
+ trainer.run()
+
+ # 此时再初始化 Evaluator 时应当直接使用 trainer.driver;
+ evaluator = Evaluator(
+ ...
+ driver=trainer.driver
+ )
+
+ :param device: 等价于 ``Trainer`` 中的 ``device`` 参数;
+ :param evaluate_batch_step_fn: 您可以传入该参数来定制每次评测一个 batch 的数据时所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``,
+ 不需要有返回值;可以参考 :meth:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop.batch_step_fn`;
+ :param evaluate_fn: 用来控制 ``Evaluator`` 在评测的前向传播过程中调用的是哪一个函数,例如对于 pytorch 而言,通过该参数确定使用的是 :meth:`model.evaluate_step` 还是
+ :meth:`model.forward` (不同训练框架所使用的的前向传播函数的方法名称不同);
+
+ 1. 如果该值是 ``None``,那么我们会默认使用 ``evaluate_step`` 当做前向传播的函数,如果在模型中没有找到该方法,则使用训练框架默认的前向传播函数;
+ 2. 如果为 ``str`` 类型,例如为 ``'my_evaluate_step_fn'``,则尝试寻找 :meth:`model.my_evaluate_step_fn`,如果找不到则直接报错;
+
+ :param input_mapping: 等价于 ``Trainer`` 中的 ``input_mapping`` 参数;对具体的用于评测一个 batch 的数据使用 ``input_mapping`` 处理之后再输入到 ``model`` 以及 ``metric`` 中。如果针对
+ ``model`` 和 ``metric`` 需要不同的 ``mapping``,请考虑使用 ``evaluate_batch_step_fn`` 参数定制;也可以使用 evaluate_input_mapping 参数名传入。
+
+ .. todo::
+
+ 之后链接上 参数匹配 的文档;
+
+ :param output_mapping: 等价于 ``Trainer`` 中的 ``output_mapping`` 参数;对 ``model`` 输出的内容,将通过 ``output_mapping`` 处理之后再输入到 ``metric`` 中;
+ 也可以使用 evaluate_output_mapping 参数名传入。
+ :param model_wo_auto_param_call: 等价于 ``Trainer`` 中的 ``model_wo_auto_param_call`` 参数;
+
+ .. note::
+
+        一个十分需要注意的问题在于 ``model_wo_auto_param_call`` 只会关闭部分的参数匹配,即只会关闭前向传播时的参数匹配,但是由于 ``Evaluator`` 中
+ ``metric`` 的计算都是自动化的,因此其一定需要参数匹配:根据 ``metric.update`` 的函数签名直接从字典数据中抽取其需要的参数传入进去;
+
+
+ :param fp16: 是否在评测时使用 fp16 混合精度;
+ :param verbose: 是否打印 evaluate 的结果;
+ :kwargs:
+ * *torch_kwargs* -- 等价于 ``Trainer`` 中的 ``torch_kwargs`` 参数;
+ * *paddle_kwargs* -- 等价于 ``Trainer`` 中的 ``paddle_kwargs`` 参数;
+ * *fairscale_kwargs* -- 等价于 ``Trainer`` 中的 ``fairscale_kwargs`` 参数;
+ * *deepspeed_kwargs* -- 等价于 ``Trainer`` 中的 ``deepspeed_kwargs`` 参数;
+ * *oneflow_kwargs* -- 等价于 ``Trainer`` 中的 ``oneflow_kwargs`` 参数;
+ * *data_device* -- 等价于 ``Trainer`` 中的 ``data_device`` 参数;
+ * *model_use_eval_mode* (``bool``) --
+ 是否在评测的时候将 ``model`` 的状态设置成 ``eval`` 状态。在 ``eval`` 状态下,``model`` 的
+ ``dropout`` 与 ``batch normalization`` 将会关闭。默认为 ``True``。如果为 ``False``,``fastNLP`` 不会对 ``model`` 的 ``evaluate`` 状态做任何设置。无论
+ 该值是什么,``fastNLP`` 都会在评测后将 ``model`` 的状态设置为 ``train``;
+ * *use_dist_sampler* --
+ 表示在 ``Evaluator`` 中在使用分布式的时候是否将保证 dataloader 的 ``sampler`` 替换为
+ 分布式的 ``sampler``,其特点是每个卡上的数据之间不重叠,所有卡上数据的加起来是整个数据集。若传入的 dataloader
+ 的 sampler 为:
+
+ - 深度学习框架自带的默认 sampler ;
+ - fastNLP 的 Sampler ;
+ 则将替换为 :class:`~fastNLP.UnrepeatedSequentialSampler`,如果这个行为不是期待的,请本参数设置为 ``False``,并针对每个卡控制其可以
+ 用到的数据。如果不是以上两类 sampler ,fastNLP 将报错。
+ * *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数;
+ * *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数;
+ * *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。
+
+ """
+
+ driver: Driver
+ _evaluate_batch_loop: Loop
+
+ def __init__(self, model, dataloaders=None, metrics: Optional[Dict] = None,
+ driver: Union[str, Driver] = 'auto', device: Optional[Union[int, List[int], str]] = None,
+ evaluate_batch_step_fn: Optional[Callable] = None, evaluate_fn: Optional[str] = None,
+ input_mapping: Optional[Union[Callable, Dict]] = None,
+ output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False,
+ fp16: bool = False, verbose: int = 1, **kwargs):
+ self.model = model
+ self.metrics = metrics
+ self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call,
+ **kwargs)
+
+ dataloaders = dataloaders if dataloaders is not None else kwargs.get('evaluate_dataloaders')
+ if dataloaders is None:
+ raise ValueError("Parameter `dataloaders` can not be None.")
+ self.dataloaders = dataloaders
+ self.device = device
+ self.verbose = verbose
+
+ self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn)
+
+ if evaluate_batch_step_fn is not None:
+ _check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn')
+ self.evaluate_batch_step_fn = evaluate_batch_step_fn
+
+ self.input_mapping = input_mapping if input_mapping is not None else kwargs.get('evaluate_input_mapping')
+ self.output_mapping = output_mapping if output_mapping is not None else kwargs.get('evaluate_output_mapping')
+
+ # check dataloader
+ if not isinstance(dataloaders, dict):
+ if kwargs.get('check_dataloader_legality', True):
+ try:
+ self.driver.check_dataloader_legality(dataloader=dataloaders)
+ except TypeError as e:
+ logger.error("`dataloaders` is invalid.")
+ raise e
+ dataloaders = {None: dataloaders}
+ else:
+ if kwargs.get('check_dataloader_legality', True):
+ for key, dataloader in dataloaders.items():
+ try:
+ self.driver.check_dataloader_legality(dataloader=dataloader)
+ except TypeError as e:
+ logger.error(f"The dataloader named:{key} is invalid.")
+ raise e
+
+ self.driver.setup()
+ self.driver.barrier()
+
+ self.separator = kwargs.get('separator', '#')
+ self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True)
+ use_dist_sampler = kwargs.get("use_dist_sampler", None)
+ if use_dist_sampler is None:
+ use_dist_sampler = self.driver.is_distributed()
+ if use_dist_sampler:
+ self._dist_sampler = "unrepeatdist"
+ else:
+ self._dist_sampler = None
+ self._metric_wrapper = None
+ _ = self.metrics_wrapper # 触发检查
+
+ if evaluate_fn is not None and not isinstance(evaluate_fn, str):
+ raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.")
+ self._evaluate_step, self._evaluate_step_signature_fn = \
+ self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn)
+ self.evaluate_fn = evaluate_fn
+
+ self.dataloaders = {}
+ for name, dl in dataloaders.items(): # 替换为正确的 sampler
+ dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False)
+ self.dataloaders[name] = dl
+
+ self.progress_bar = kwargs.get('progress_bar', 'auto')
+ assert self.progress_bar in [None, 'rich', 'auto', 'tqdm', 'raw']
+ if self.progress_bar == 'auto':
+ self.progress_bar = 'raw' if f_rich_progress.dummy else 'rich'
+
+ self.driver.barrier()
+
+ def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = True,
+ model_load_fn: Optional[Callable] = None, **kwargs):
+ """
+ 用于帮助您加载模型的辅助函数;
+
+ :param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 ``fastnlp_model.pkl.tar`` 文件。在 ``model_load_fn`` 不为空时,
+ 直接将该 folder 传递到 ``model_load_fn`` 中;
+ :param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn`` 不为 ``None`` 时,该参数无意义;
+ :param model_load_fn: :class:`Callable` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容;
+ :param kwargs: 理论上您不需要使用到该参数;
+
+ .. note::
+
+ 注意您需要在初始化 ``Evaluator`` 后再通过 ``evaluator`` 实例来调用该函数;这意味着您需要保证在保存和加载时使用的 ``driver`` 是属于同一个
+ 训练框架的,例如都是 **pytorch** 或者 **PaddlePaddle** ;
+ """
+ self.driver.barrier()
+ if not isinstance(folder, (io.BytesIO, BinaryIO)):
+ try:
+ if model_load_fn is not None:
+ if not callable(model_load_fn):
+ raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
+ model_load_fn(folder)
+ else:
+ if isinstance(folder, str):
+ folder = Path(folder)
+ self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
+ except FileNotFoundError as e:
+ if FASTNLP_MODEL_FILENAME not in os.listdir(folder):
+ logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.")
+ raise e
+ else:
+ if model_load_fn is not None:
+ raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being "
+ "`io.BytesIO` type.")
+ self.driver.load_model(folder, only_state_dict, **kwargs)
+ self.driver.barrier()
+
+ def run(self, num_eval_batch_per_dl: int = -1) -> Dict:
+ """
+ 该函数是在 ``Evaluator`` 初始化后用于真正开始评测的函数;
+
+ 返回一个字典类型的数据,其中 key 为 metric 的名字,value 为对应 metric 的结果。
+
+ 1. 如果存在多个 metric ,一个 dataloader 的情况,key 的命名规则是
+ ``metric_indicator_name#metric_name``;
+ 2. 如果存在多个数据集,一个metric的情况,key的命名规则是
+ ``metric_indicator_name#metric_name#dataloader_name`` (其中 **#** 是默认的 separator ,可以通过 Evaluator 初始化参数修改);
+ 3. 如果存在多个metric,多个dataloader的情况,key的命名规则是
+ ``metric_indicator_name#metric_name#dataloader_name``,其中 metric_indicator_name 可能不存在;
+
+ :param num_eval_batch_per_dl: 每个 dataloader 测试前多少个 batch 的数据,-1 为测试所有数据。
+ :return: 评测得到的结果,是一个没有嵌套的字典;
+ """
+ assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
+ assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0."
+
+ metric_results = {}
+ self.reset()
+ evaluate_context = self.driver.get_evaluate_context()
+ self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train')
+ with evaluate_context():
+ try:
+ for dataloader_name, dataloader in self.dataloaders.items():
+ self.driver.barrier()
+ if num_eval_batch_per_dl != -1:
+ dataloader = _TruncatedDataLoader(dataloader, num_eval_batch_per_dl)
+ self.driver.set_sampler_epoch(dataloader, -1)
+ self.start_progress_bar(total=len(dataloader), dataloader_name=dataloader_name)
+ self.cur_dataloader_name = dataloader_name
+ results = self.evaluate_batch_loop.run(self, dataloader)
+ self.remove_progress_bar(dataloader_name)
+ metric_results[dataloader_name] = results
+ self.reset()
+ self.driver.barrier()
+ except BaseException as e:
+ self.driver.on_exception()
+ raise e
+ finally:
+ self.finally_progress_bar()
+ metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False)
+ if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。
+ # metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False)
+ if self.verbose:
+ if self.progress_bar == 'rich':
+ f_rich_progress.print(metric_results)
+ else:
+ logger.info(metric_results)
+ self.driver.set_model_mode(mode='train')
+
+ return metric_results
+
+ def start_progress_bar(self, total: int, dataloader_name):
+ if self.progress_bar in ('rich', 'tqdm'):
+ if dataloader_name is None:
+ desc = f'Eval. Batch'
+ else:
+ desc = f'Eval. on {dataloader_name} Batch'
+ if self.progress_bar == 'rich':
+ self._task_id = f_rich_progress.add_task(description=desc, total=total)
+ else:
+ self._task_id = f_tqdm_progress.add_task(description=desc, total=total)
+ elif self.progress_bar == 'raw':
+ desc = 'Evaluation starts'
+ if dataloader_name is not None:
+ desc += f' on {dataloader_name}'
+ logger.info('\n' + "*" * 10 + desc + '*' * 10)
+
+ def update_progress_bar(self, batch_idx, dataloader_name, **kwargs):
+ if dataloader_name is None:
+ desc = f'Eval. Batch:{batch_idx}'
+ else:
+ desc = f'Eval. on {dataloader_name} Batch:{batch_idx}'
+ if self.progress_bar == 'rich':
+ assert hasattr(self, '_task_id'), "You must first call `start_progress_bar()` before calling " \
+ "update_progress_bar()"
+ f_rich_progress.update(self._task_id, description=desc, post_desc=kwargs.get('post_desc', ''),
+ advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True),
+ visible=kwargs.get('visible', True))
+ elif self.progress_bar == 'raw':
+ if self.verbose > 1:
+ logger.info(desc)
+ elif self.progress_bar == 'tqdm':
+ f_tqdm_progress.update(self._task_id, advance=1)
+
+ def remove_progress_bar(self, dataloader_name):
+ if self.progress_bar == 'rich' and hasattr(self, '_task_id'):
+ f_rich_progress.destroy_task(self._task_id)
+ delattr(self, '_task_id')
+
+ elif self.progress_bar == 'tqdm' and hasattr(self, '_task_id'):
+ f_tqdm_progress.destroy_task(self._task_id)
+ delattr(self, '_task_id')
+
+ elif self.progress_bar == 'raw':
+ desc = 'Evaluation ends'
+ if dataloader_name is not None:
+ desc += f' on {dataloader_name}'
+ logger.info("*" * 10 + desc + '*' * 10 + '\n')
+
+ def finally_progress_bar(self):
+ if self.progress_bar == 'rich' and hasattr(self, '_task_id'):
+ f_rich_progress.destroy_task(self._task_id)
+ delattr(self, '_task_id')
+ elif self.progress_bar == 'tqdm' and hasattr(self, '_task_id'):
+ f_tqdm_progress.destroy_task(self._task_id)
+ delattr(self, '_task_id')
+
+ @property
+ def evaluate_batch_loop(self):
+ return self._evaluate_batch_loop
+
+ @evaluate_batch_loop.setter
+ def evaluate_batch_loop(self, loop: Loop):
+ if getattr(self, 'evaluate_step_fn', None) is not None:
+ logger.rank_zero_warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored "
+ "when the `evaluate_batch_loop` is also customized.")
+ self._evaluate_batch_loop = loop
+
+ def reset(self):
+ """
+ 调用所有 metric 的 :meth:`reset` 方法,清除累积的状态。
+
+ :return:
+ """
+ self.metrics_wrapper.reset()
+
+ def update(self, batch, outputs):
+ """
+ 自动调用所有 metric 的 :meth:`update` 方法,会根据不同 metric 的参数列表进行匹配传参。
+
+ :param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。
+ :param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。
+ :return:
+ """
+ self.metrics_wrapper.update(batch, outputs)
+
+ def get_metric(self) -> Dict:
+ """
+ 调用所有 metric 的 :meth:`get_metric` 方法,并返回结果。其中 key 为 metric 的名称,value 是各个 metric 的结果。
+
+ :return:
+ """
+ return self.metrics_wrapper.get_metric()
+
+ @property
+ def metrics_wrapper(self):
+ """
+ 由于需要保持 Evaluator 中 ``metrics`` 对象与用户传入的 ``metrics`` 保持完全一致(方便在 ``evaluate_batch_step_fn`` )中使用,同时也为了支持
+ 不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper
+ 进行操作。
+ """
+ if self._metric_wrapper is None:
+ self._metric_wrapper = _MetricsWrapper(self.metrics, evaluator=self)
+ return self._metric_wrapper
+
+ def evaluate_step(self, batch):
+ """
+ 将 ``batch`` 传递到 model 中进行处理,根据当前 ``evaluate_fn`` 选择进行 evaluate 。会将返回结果经过 ``output_mapping``
+        处理后再返回。
+
+
+ :param batch: ``evaluate_fn`` 函数支持的输入类型
+ :return: ``evaluate_fn`` 函数的输出结果,如果有设置 ``output_mapping`` ,将是 ``output_mapping`` 之后的结果。
+ """
+ outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn)
+ outputs = match_and_substitute_params(self.output_mapping, outputs)
+ return outputs
+
+ @property
+ def metrics(self):
+ """
+ 返回用户传入的 ``metrics`` 对象。
+
+ :return:
+ """
+ return self._metrics
+
+ @metrics.setter
+ def metrics(self, metrics):
+ self._metrics = metrics
+
+ def move_data_to_device(self, batch):
+ return self.driver.move_data_to_device(batch)
+
+
+class _MetricsWrapper:
+ """
+ 注意 metrics 的输入只支持:Dict[str, Metric];
+ 并且通过对 update() , reset() , get_metric() 函数的封装,实现支持 fastNLP 的 metric 以及 torchmetrics 或者更多。
+
+ """
+
+ def __init__(self, metrics, evaluator):
+ self.evaluator = evaluator
+ self._metrics = []
+ self._metric_names = []
+ if metrics is not None:
+ if not isinstance(metrics, Dict):
+ raise TypeError("Parameter `metrics` can only be `Dict` type.")
+ for metric_name, metric in metrics.items():
+ # 因为 torchmetrics 是一个 nn.Module,因此我们需要先将其移到对应的机器上;
+ if _is_torchmetrics_metric(metric) and isinstance(evaluator.driver, TorchDriver):
+ # torchmetrics 是默认自动开启了多卡的
+ evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device)
+ elif isinstance(metric, Metric):
+ # 如果数据是分布式的,但是不aggregate的话可能有问题
+ if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False:
+ logger.rank_zero_warning(
+ "You have replaced the sampler as distributed sampler when evaluation, but your metric "
+ f"{metric_name}:{metric.__class__.__name__}'s `aggregate_when_get_metric` is False.", once=True)
+ if metric.aggregate_when_get_metric is None:
+ metric.aggregate_when_get_metric = evaluator._dist_sampler is not None
+
+ metric.to(evaluator.driver.data_device)
+ self._metric_names.append(metric_name)
+ self._metrics.append(metric)
+
+ def update(self, batch, outputs):
+ if is_dataclass(outputs):
+ outputs = dataclass_to_dict(outputs)
+ for metric in self._metrics:
+ args = []
+ if not isinstance(batch, dict):
+ logger.rank_zero_warning(
+ f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
+ f"the output of model to update metric.", once=True)
+ else:
+ args.append(batch)
+ if not isinstance(outputs, dict):
+ raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
+ f" return a dict from your model or use `output_mapping` to convert it into dict "
+ f"type.")
+ if isinstance(metric, Metric):
+ # 这样在 auto_param_call 报错的时候才清晰。
+ auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
+ elif _is_torchmetrics_metric(metric):
+ auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
+ elif _is_allennlp_metric(metric):
+ auto_param_call(metric.__call__, outputs, *args)
+ elif _is_paddle_metric(metric):
+ res = auto_param_call(metric.compute, outputs, *args)
+ metric.update(res)
+
+ def reset(self):
+ """
+ 将 Metric 中的状态重新设置。
+
+ :return:
+ """
+ for metric in self._metrics:
+ if _is_allennlp_metric(metric):
+ metric.get_metric(reset=True)
+ elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric):
+ metric.reset()
+
+ def get_metric(self) -> Dict:
+ """
+ 调用各个 metric 得到 metric 的结果。并使用 {'metric_name1': metric_results, 'metric_name2': metric_results} 的形式
+ 返回。
+
+ :return:
+ """
+ results = {}
+ for metric_name, metric in zip(self._metric_names, self._metrics):
+ if isinstance(metric, Metric):
+ _results = metric.get_metric()
+ elif _is_allennlp_metric(metric):
+ _results = metric.get_metric(reset=False)
+ elif _is_torchmetrics_metric(metric):
+ _results = metric.compute()
+ elif _is_paddle_metric(metric):
+ _results = metric.accumulate()
+ else:
+ raise RuntimeError(f"Not support `{type(metric)}` for now.")
+ if _results is not None:
+ results[metric_name] = _results
+ else:
+ logger.warning_once(f"Metric:{metric_name} returns None when getting metric results.")
+ return results
diff --git a/fastNLP/core/controllers/loops/__init__.py b/fastNLP/core/controllers/loops/__init__.py
new file mode 100644
index 00000000..88b7d881
--- /dev/null
+++ b/fastNLP/core/controllers/loops/__init__.py
@@ -0,0 +1,9 @@
+__all__ = [
+ 'EvaluateBatchLoop',
+ 'Loop',
+ 'TrainBatchLoop'
+]
+
+from .loop import Loop
+from .evaluate_batch_loop import EvaluateBatchLoop
+from .train_batch_loop import TrainBatchLoop
diff --git a/fastNLP/core/controllers/loops/evaluate_batch_loop.py b/fastNLP/core/controllers/loops/evaluate_batch_loop.py
new file mode 100644
index 00000000..c31cfa0e
--- /dev/null
+++ b/fastNLP/core/controllers/loops/evaluate_batch_loop.py
@@ -0,0 +1,67 @@
+from typing import Optional, Callable, Dict
+
+__all__ = [
+ 'EvaluateBatchLoop'
+]
+
+from .loop import Loop
+from fastNLP.core.log import logger
+from fastNLP.core.utils import match_and_substitute_params
+
+
+class EvaluateBatchLoop(Loop):
+ r"""
+ ``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程;
+
+    :param batch_step_fn: 您可以传入该参数来替换默认的 ``batch_step_fn``;
+ """
+ def __init__(self, batch_step_fn:Optional[Callable]=None):
+ if batch_step_fn is not None:
+ self.batch_step_fn = batch_step_fn
+
+ def run(self, evaluator, dataloader) -> Dict:
+ r"""
+ 需要返回在传入的 ``dataloader`` 中的 evaluation 结果
+
+ :param evaluator: :class:`~fastNLP.core.controllers.Evaluator` 对象
+ :param dataloader: 当前需要进行评测的 ``dataloader``
+ :return:
+ """
+ iterator = iter(dataloader)
+ batch_idx = 0
+ while True:
+ try:
+ batch = next(iterator)
+ except StopIteration:
+ break
+ try:
+ batch = match_and_substitute_params(evaluator.input_mapping, batch)
+ batch = evaluator.move_data_to_device(batch)
+
+ self.batch_step_fn(evaluator, batch)
+ batch_idx += 1
+ evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
+
+ except BaseException as e:
+ if callable(getattr(dataloader, 'get_batch_indices', None)):
+ indices = dataloader.get_batch_indices()
+ if evaluator.cur_dataloader_name is not None:
+ logger.error(f"Exception happens when evaluating on samples in dataloader:"
+ f"{evaluator.cur_dataloader_name}: {indices}")
+ else:
+ logger.error(f"Exception happens when evaluating on samples: {indices}")
+ raise e
+ # 获取metric结果。返回的dict内容示例为{'metric_name1': metric_results, 'metric_name2': metric_results, ...}
+ results = evaluator.get_metric()
+ return results
+
+ @staticmethod
+ def batch_step_fn(evaluator, batch):
+ r"""
+ 针对一个 ``batch`` 的数据的评测过程;
+
+ :param evaluator: :class:`~fastNLP.core.controllers.Evaluator` 对象
+ :param batch: 当前需要评测的一个 ``batch`` 的数据;
+ """
+ outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果
+ evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值
diff --git a/fastNLP/core/controllers/loops/loop.py b/fastNLP/core/controllers/loops/loop.py
new file mode 100644
index 00000000..dc149587
--- /dev/null
+++ b/fastNLP/core/controllers/loops/loop.py
@@ -0,0 +1,38 @@
+r"""
+``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 **fastNLP** 主要功能的同时保证 **fastNLP** 的易用性和代码的易读性,我们只对
+训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程;
+
+更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和
+:class:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop` ;
+"""
+
+from typing import Union
+
+__all__ = [
+ 'Loop'
+]
+
+
+class Loop:
+ r"""
+ ``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,您可以继承此类来定制自己的训练或者评测 ``loop``;
+ """
+
+ def run(self, controller: Union["Trainer", "Evaluator"], dataloader):
+ r"""
+ 遍历参数 ``dataloader`` 的所有数据,使用 ``controller`` 进行训练或者评测;
+
+ .. note::
+
+ ``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``, ``Trainer.backward`` 等;
+
+ 在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop`
+ 中对于 callback 函数的调用;
+
+ """
+
+ @staticmethod
+ def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch):
+ r"""
+ 对于具体的一个 ``batch`` 的数据,实现训练或者评测过程中的一步;
+ """
\ No newline at end of file
diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py
new file mode 100644
index 00000000..ca97fe9e
--- /dev/null
+++ b/fastNLP/core/controllers/loops/train_batch_loop.py
@@ -0,0 +1,83 @@
+__all__ = [
+ 'TrainBatchLoop'
+]
+
+from typing import Optional, Callable
+
+from .loop import Loop
+from fastNLP.core.log import logger
+from fastNLP.core.utils import match_and_substitute_params
+from fastNLP.core.utils.exceptions import EarlyStopException
+
+
+class TrainBatchLoop(Loop):
+ r"""
+ ``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程;
+
+    :param batch_step_fn: 您可以传入该参数来替换默认的 ``batch_step_fn``;
+ """
+
+ def __init__(self, batch_step_fn: Optional[Callable] = None):
+ if batch_step_fn is not None:
+ self.batch_step_fn = batch_step_fn
+
+ def run(self, trainer, dataloader):
+ r"""
+ 对传入的 ``dataloader`` 进行一个 epoch 的主要的训练的循环过程;
+
+ .. note::
+
+ 您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param dataloader: 当前训练所使用的 ``dataloader``;
+ """
+ get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\
+ else lambda *args, **kwargs: None
+ dataloader = iter(dataloader)
+ while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch:
+ try:
+ trainer.on_fetch_data_begin()
+ batch = next(dataloader)
+ indices = get_batch_indices()
+ except StopIteration:
+ trainer.on_fetch_data_end()
+ break
+
+ trainer.on_fetch_data_end()
+
+ try:
+ batch = match_and_substitute_params(trainer.input_mapping, batch)
+ batch = trainer.move_data_to_device(batch)
+
+ trainer.on_train_batch_begin(batch, indices)
+ with trainer.get_no_sync_context(): # 在多卡的时候可能需要关闭 sync
+ self.batch_step_fn(trainer, batch)
+ trainer.global_forward_batches += 1
+ trainer.batch_idx_in_epoch += 1
+
+ trainer.check_batch_step_fn()
+ trainer.on_train_batch_end()
+ except BaseException as e:
+ if indices is not None and not isinstance(e, (EarlyStopException, KeyboardInterrupt)):
+ logger.error(f"Exception happens when training on samples: {indices}")
+ raise e
+ trainer.step_evaluate()
+ trainer.batch_idx_in_epoch = 0
+
+ @staticmethod
+ def batch_step_fn(trainer, batch):
+ r"""
+ 针对一个 ``batch`` 的数据的训练过程;
+
+ :param trainer: :class:`~fastNLP.core.controllers.Trainer` 实例;
+ :param batch: 一个 ``batch`` 的数据;
+ """
+ outputs = trainer.train_step(batch)
+ trainer.backward(outputs)
+ trainer.step()
+ trainer.zero_grad()
+
+
+
+
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
new file mode 100644
index 00000000..41ad2f7b
--- /dev/null
+++ b/fastNLP/core/controllers/trainer.py
@@ -0,0 +1,1540 @@
+"""
+``Trainer`` 是 **fastNLP** 用于训练模型的专门的训练器,其支持多种不同的驱动模式 ``Driver``,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产
+的训练框架;新版的 **fastNLP** 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需要自己实现
+模型部分,而将训练层面的逻辑完全地交给 **fastNLP**;
+"""
+
+from typing import Union, Optional, List, Callable, Dict, BinaryIO
+from functools import partial
+from collections import defaultdict
+import copy
+from contextlib import contextmanager
+from dataclasses import is_dataclass
+import os
+from pathlib import Path
+import io
+
+__all__ = [
+ 'Trainer',
+]
+
+from .loops import Loop, TrainBatchLoop
+from .utils import State, TrainerState
+from .utils.utils import check_evaluate_every
+from .evaluator import Evaluator
+from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
+from fastNLP.core.callbacks import Callback, CallbackManager
+from fastNLP.core.callbacks.callback import _CallbackWrapper
+from fastNLP.core.callbacks.callback_manager import prepare_callbacks
+from fastNLP.core.callbacks.callback_event import Event
+from fastNLP.core.drivers import Driver
+from ..drivers.choose_driver import choose_driver
+from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
+from fastNLP.core.utils.utils import _check_valid_parameters_number
+from fastNLP.envs import rank_zero_call
+from fastNLP.core.log import logger
+from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
+from fastNLP.core.utils.exceptions import EarlyStopException
+from fastNLP.core.dataloaders import OverfitDataLoader
+from fastNLP.core.callbacks.progress_callback import ProgressCallback
+
+
+class Trainer(TrainerEventTrigger):
+ r"""
+ 用于支持快速训练的训练器。
+
+ :param model: 训练所需要的模型,例如 :class:`torch.nn.Module`;
+
+ .. note::
+
+ 当使用 pytorch 时,注意参数 ``model`` 在大多数情况下为 ``nn.Module``。但是您仍能够通过使用一些特定的组合来使用情况,如下所示:
+
+ 1. 当希望使用 ``DataParallel`` 时,您应当使用 ``TorchSingleDriver``,意味着您在初始化 ``Trainer`` 时参数 ``device`` 不应当为
+ 一个 ``List``;
+
+ 2. 当您选择自己初始化 ``init_process_group`` 时(这种情况要求您传入的 ``model`` 参数一定为 ``DistributedDataParallel``),
+ 您应当使用 ``TorchDDPDriver``,意味着您需要通过 ``python -m torch.distributed.launch`` 的方式来启动训练,此时参数 ``device``
+ 应当设置为 None(此时我们会忽略该参数),具体见下面对于参数 ``device`` 的更详细的解释。
+
+ :param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:``["auto", "torch", "paddle", "jittor", "fairscale", "deepspeed", "oneflow"]``:
+
+ 1. 值为 ``"auto"`` 时,**fastNLP** 会根据传入模型的类型自行判断使用哪一种模式;
+ 2. 其值为 ``"torch"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchSingleDriver` 或者 :class:`~fastNLP.core.drivers.TorchDDPDriver`;
+ 3. 其值为 ``"torch_fsdp"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchFSDPDriver`;
+ 4. 其值为 ``"paddle"`` 时,表示使用 :class:`~fastNLP.core.drivers.PaddleSingleDriver` 或者 :class:`~fastNLP.core.drivers.PaddleFleetDriver`;
+ 5. 其值为 ``"jittor"`` 时,表示使用 :class:`~fastNLP.core.drivers.JittorSingleDriver` 或者 :class:`~fastNLP.core.drivers.JittorMPIDriver`;
+ 6. 其值为 ``"fairscale"`` 时,表示使用 :class:`~fastNLP.core.drivers.FairScaleDriver`;
+ 7. 其值为 ``"deepspeed"`` 时,表示使用 :class:`~fastNLP.core.drivers.DeepSpeedDriver`;
+ 8. 其值为 ``"oneflow"`` 时,表示使用 :class:`~fastNLP.core.drivers.OneflowSingleDriver` 或者 :class:`~fastNLP.core.drivers.OneflowDDPDriver`;
+
+ 在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置;
+
+ .. warning::
+
+ 因为设计上的原因,您可以直接传入一个初始化好的 ``driver`` 实例,但是需要注意的是一个 ``Driver`` 在初始化时需要 ``model`` 这一参数,
+ 这意味着当您传入一个 ``Driver`` 实例时,您传入给 ``Trainer`` 的 ``model`` 参数将会被忽略;也就是说模型在训练时使用的真正的模型是
+ 您传入的 ``Driver`` 实例中的模型;
+
+ .. note::
+
+ 如果您选择使用 :mod:`deepspeed` 、:mod:`fairscale` 或 :mod:`torch.distributed.fsdp` 进行训练,请不要将 ``driver`` 的值设为 ``'auto'`` 。
+
+ :param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 :class:`List` 或者 :class:`Dict`;
+
+ .. warning::
+
+ 当使用分布式训练时, **fastNLP** 会默认将 ``dataloader`` 中的 ``Sampler`` 进行处理,以使得在一个 epoch 中,不同卡
+ 用以训练的数据是不重叠的。如果您对 sampler 有特殊处理,那么请将 ``use_dist_sampler`` 参数设置为 ``False`` ,此刻需要由
+ 您自身保证每张卡上所使用的数据是不同的。
+
+ :param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List;
+ :param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 ``torch.distributed.launch/run`` 启动时可以为 ``None``,
+ 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是您可以通过参数 ``input_mapping`` 和 ``output_mapping`` 来实现设备之间
+ 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时您也可以通过在 kwargs 添加参数 ``data_device`` 来让我们帮助您将数据
+ 迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前自己构造 DDP 的场景);
+
+ device 的可选输入如下所示:
+
+ * *str*: 例如 ``'cpu'``, ``'cuda'``, ``'cuda:0'``, ``'cuda:1'``, ``'gpu:0'`` 等;
+ * *torch.device*: 例如 ``torch.device("cuda:0")``;
+ * *oneflow.device*:例如 ``oneflow.device("cuda", 0)``;
+ * *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver` 这类
+ 执行分布式训练的 Driver
+ * *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用分布式训练的 Driver ,不管您传入的列表的长度是 1 还是其它值;
+ * *None*: 仅当用户自己通过训练框架提供的并行训练启动脚本开启 ddp 进程时为 None;
+
+ .. note::
+
+ 如果希望使用 ``TorchDDPDriver``,在初始化 ``Trainer`` 时您应当使用::
+
+ Trainer(driver="torch", device=[0, 1])
+
+ 注意如果这时 ``device=[0]``,我们仍旧会使用 ``TorchDDPDriver``。
+
+ 如果希望使用 ``TorchSingleDriver``,则在初始化 ``Trainer`` 时您应当使用::
+
+ Trainer(driver="torch", device=0)
+
+ .. warning::
+
+ 注意参数 ``device`` 仅当您通过训练框架自身的并行训练启动脚本启动 ddp 训练时才允许为 ``None``!
+
+ 例如,在 pytorch 中,当您使用::
+
+ python -m torch.distributed.launch --nproc_per_node 2 train.py
+
+ 来使用 ``TorchDDPDriver`` 时,此时参数 ``device`` 不再有效(不管您是否自己初始化 ``init_process_group``),我们将直接
+ 通过 ``torch.device(f"cuda:{local_rank}")`` 来获取当前进程所使用的的具体的 gpu 设备。因此此时您需要使用 ``os.environ["CUDA_VISIBLE_DEVICES"]``
+ 来指定要使用的具体的 gpu 设备。
+
+ 另一点需要注意的是,当您没有选择自己初始化 ``init_process_group`` 时,我们仍旧会帮助您把模型和数据迁移到当前进程所使用的
+ 具体的 gpu 设备上。但是如果您选择自己在 ``Trainer`` 初始化前(意味着在 ``driver`` 的 ``setup`` 前)初始化 ``init_process_group``,
+ 那么对于模型的迁移应当完全由您自己来完成。此时对于数据的迁移,如果您在 ``Trainer`` 初始化时指定了参数 ``data_device``,那么
+ 我们会将数据迁移到 ``data_device`` 上;如果其为 None,那么将数据迁移到正确的设备上应当由您自己来完成。
+
+ 对于使用 ``TorchDDPDriver`` 的更多细节,请见 :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`。
+
+ :param n_epochs: 训练总共的 epoch 的数量,默认为 20;也可以通过 ``n_batches`` 参数设置总共迭代多少个 ``batch`` 。
+ :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认
+ 为 ``None``;
+ :param batch_step_fn: 定制每次训练时前向运行一个 batch 的数据所执行的函数。该函数应接受两个参数为 ``trainer`` 和 ``batch``,
+ 不需要要返回值;更详细的使用位置和说明请见 :meth:`~fastNLP.core.controllers.TrainBatchLoop.batch_step_fn`;
+ :param evaluate_batch_step_fn: 定制每次验证时前向运行一个 batch 的数据所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``,
+ 不需要有返回值;可以参考 :meth:`~fastNLP.core.controllers.EvaluateBatchLoop.batch_step_fn`;
+ :param train_fn: 用来控制 ``Trainer`` 在训练的前向传播过程中是调用模型的哪一个函数,例如是 ``train_step`` 还是框架默认的前向接口;
+ 默认为 ``None``,如果该值是 ``None``,那么我们会默认使用 ``train_step`` 当做前向传播的函数,如果在模型的定义类中没有找到该方法,
+ 则使用模型默认的前向传播函数,例如对于 pytorch 来说就是 ``forward``。
+
+ .. note::
+ 在 fastNLP 中,对于训练时使用的前向传播函数的查找逻辑如下所示:
+
+ 1. 如果 ``train_fn`` 为 None,那么在 model 的类 Model 中寻找方法 :meth:`Model.train_step` ;如果没有找到,那么默认使用 :meth:`Model.forward`;
+ 2. 如果 ``train_fn`` 为一个字符串,例如 ``'my_step_fn'``,那么我们首先会在 model 的类 Model 中寻找方法 :meth:`Model.my_step_fn`,
+ 如果没有找到,那么会直接报错;
+
+ :param evaluate_fn: 用来控制 ``Trainer`` 中内置的 ``Evaluator`` 在验证的前向传播过程中是调用模型的哪一个函数,应当为 ``None``
+ 或者一个字符串;其使用方式和 train_fn 类似;具体可见 :class:`~fastNLP.core.controllers.Evaluator`;
+ :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 ``Callback`` 类;具体可见
+ :class:`~fastNLP.core.callbacks.Callback`;
+ :param metrics: 用于传给 ``Trainer`` 内部的 ``Evaluator`` 实例来进行训练过程中的验证。其应当为一个字典,其中 key 表示 monitor,
+ 例如 ``{"acc1": AccMetric(), "acc2": AccMetric()}``;
+
+ 目前我们支持的 ``metric`` 的种类有以下几种:
+
+ 1. fastNLP 自己的 ``metric``:详见 :class:`~fastNLP.core.metrics.Metric`;
+ 2. torchmetrics;
+ 3. allennlp.training.metrics;
+ 4. paddle.metric;
+
+ :param evaluate_every: 用来控制 ``Trainer`` 内部的 ``Evaluator`` 验证的频率,其可以为负数、正数或者函数:
+
+ 1. 为负数时表示每隔几个 ``epoch`` evaluate 一次;
+ 2. 为正数则表示每隔几个 ``batch`` evaluate 一次;
+ 3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并
+ 返回一个 bool 值,返回为 ``True`` 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate;
+
+ .. note::
+
+ 如果参数 ``evaluate_every`` 为函数,其应当类似:
+
+ >>> def my_evaluate_every(trainer) -> bool:
+ ... if (trainer.global_forward_batches+1) % 1000 == 0:
+ ... return True
+ ... else:
+ ... return False
+
+ 该函数表示当每经过 1000 个 batch,``Trainer`` 中内置的 ``Evaluator`` 就会验证一次;
+
+ 另一个需要注意的事情在于该函数会在每一次 batch 的结尾进行调用,当该函数返回 ``True`` 时,``Evaluator`` 才会进行验证;
+
+ :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理:
+
+ 1. 如果 ``input_mapping`` 是一个字典:
+
+ 1. 如果此时 batch 也是一个 ``Dict``,那么我们会把 batch 中同样在 ``input_mapping`` 中的 key 修改为 ``input_mapping`` 的对应 ``key`` 的 ``value``;
+ 2. 如果此时 batch 是一个 ``dataclass``,那么我们会先将其转换为一个 ``Dict``,然后再进行上述转换;
+ 3. 如果此时 batch 此时是其它类型,那么我们将会直接报错;
+ 2. 如果 ``input_mapping`` 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里;
+
+ 注意该参数会被传进 ``Evaluator`` 中;因此您可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 ``device`` 为 ``None`` 时);
+ 如果 ``Trainer`` 和 ``Evaluator`` 需要使用不同的 ``input_mapping``, 请使用 ``train_input_mapping`` 与 ``evaluate_input_mapping`` 分别进行设置。
+
+ :param output_mapping: 应当为一个字典或者函数。作用和 ``input_mapping`` 类似,区别在于其用于转换输出:
+
+ 1. 如果 ``output_mapping`` 是一个 ``Dict``,那么我们需要模型的输出必须是 ``Dict`` 或者 ``dataclass`` 类型:
+
+ 1. 如果此时模型的输出是一个 ``Dict``,那么我们会把输出中同样在 ``output_mapping`` 中的 key 修改为 ``output_mapping`` 的对应 key 的 value;
+ 2. 如果此时模型的输出是一个 ``dataclass``,那么我们会先将其转换为一个 Dict,然后再进行上述转换;
+ 2. 如果 ``output_mapping`` 是一个函数,那么我们将会直接将模型的输出传给该函数;
+
+ 如果 ``Trainer`` 和 ``Evaluator`` 需要使用不同的 ``output_mapping``, 请使用 ``train_output_mapping`` 与 ``evaluate_output_mapping`` 分别进行设置;
+
+ .. note::
+
+ ``input_mapping`` 和 ``output_mapping`` 与 fastNLP 的一个特殊的概念 **'参数绑定'** 高度相关,它们的存在也是为了 fastNLP
+ 中的参数匹配能够正确地运行;
+
+ .. todo::
+ 之后链接上 参数匹配 的文档;
+
+ .. warning::
+
+ 如果 ``Trainer`` 的参数 ``output_mapping`` 不为 ``None``,请保证其返回的一定是一个字典,并且其中含有关键字 **'loss'**;
+
+ :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为;
+
+ 1. 如果该值为 ``False``,并且当 batch 为字典时,我们会根据 **前向函数** 所需要的参数从 batch 中提取对应的对象,然后传入到 **前向函数** 中;
+ 2. 如果该值为 ``True``,那么我们会将 batch 直接透传给模型;
+
+ .. todo::
+ 之后链接上 参数匹配 的文档;
+
+ 函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`;
+
+ :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 才让优化器迭代一次,默认为 1;
+ :param fp16: 是否开启混合精度训练,默认为 False;
+ :param monitor: 对于一些特殊的 ``Callback``,例如 :class:`~fastNLP.core.callbacks.CheckpointCallback`,它们需要参数 ``monitor``
+ 来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。例如,对于 :class:`~fastNLP.core.callbacks.CheckpointCallback`
+ 而言,如果我们想要每隔一个 epoch 让 ``Evaluator`` 进行一次验证,然后保存训练以来的最好的结果;那么我们需要这样设置:
+
+ .. code-block::
+
+ trainer = Trainer(
+ ...,
+ metrics={'acc': accMetric()},
+ callbacks=[CheckpointCallback(
+ ...,
+ monitor='acc',
+ topk=1
+ )]
+ )
+
+ 这意味着对于 :class:`~fastNLP.core.callbacks.CheckpointCallback` 来说,*'acc'* 就是一个监测的指标,用于在 ``Evaluator`` 验证后取出其需要监测的那个指标的值。
+
+ ``Trainer`` 中的参数 ``monitor`` 的作用在于为没有设置 ``monitor`` 参数但是需要该参数的 *callback* 实例设置该值。关于 ``monitor``
+ 参数更详细的说明,请见 :class:`~fastNLP.core.callbacks.CheckpointCallback`;
+
+ 注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效;
+
+ :param larger_better: 对于需要参数 ``monitor`` 的 *callback* 来说,``monitor`` 的值是否是越大越好;类似于 ``monitor``,其作用
+ 在于为没有设置 ``larger_better`` 参数但是需要该参数的 *callback* 实例设置该值;
+
+ 注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效;
+
+ :param n_batches: 总共迭代多少个 ``batch`` 的训练结束。当该值不为 -1 时,将直接忽略 ``n_epochs`` 的值。
+ :param overfit_batches: 使用该参数来支持 **'过拟合'** 的功能;支持的值为 ``-1``、``0`` 或者 大于 0 的整数,表示使用多少个 batch 的数据
+ 来进行过拟合训练;其中 0 为表示不进行任何操作;-1 表示使用所有的数据进行训练;
+
+ .. note::
+
+ 您可以使用该参数来简单地查看您的模型是否是 '正确的',即您的模型是否能够在少量的数据上快速进行收敛,从而说明损失函数以及优化器等
+ 没有问题。当使用该参数时,我们会直接从 ``train_dataloader`` 中提取固定数量的 batch,然后在所有 epoch 中都是用这些数据
+ 来进行训练;
+
+ .. warning::
+
+ 在使用该参数时,您同样可以指定 ``metrics`` 参数来进行简单的验证,当该参数和 ``metrics`` 同时出现时,我们会将 ``evaluate_dataloaders``
+ 直接替换为在过拟合中所使用的训练数据;因此您需要保证您的 ``metrics`` 是能够在 ``train_dataloader`` 上使用的;
+
+ :param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 ``None``;
+
+ .. note::
+
+ marker 的使用场景主要在于如果一个脚本中含有多个 ``Trainer`` 实例,并且含有多个使用 ``Trainer.on`` 修饰的函数时,不同的函数属于
+ 不同的 ``Trainer`` 实例;
+
+ 此时,通过将修饰器 ``Trainer.on`` 的参数 ``marker`` 和 ``Trainer`` 的参数 ``marker`` 置为相同,就可以使得该函数只会在这一
+ ``Trainer`` 实例中被调用;例如,
+
+ .. code-block::
+
+ @Trainer.on(Event.on_train_begin(), marker='trainer1')
+ def fn(trainer):
+ ...
+
+ trainer = Trainer(
+ ...,
+ marker='trainer1'
+ )
+
+ 另一点需要说明的是,如果一个被 ``Trainer.on`` 修饰的函数,其修饰时没有指明 ``marker``,那么会将该函数传给代码位于其之后的
+ 第一个 ``Trainer`` 实例,即使该 ``Trainer`` 实例的 marker 不为 ``None``;这一点详见 :meth:`~fastNLP.core.controllers.Trainer.on`
+
+ :kwargs:
+ * *torch_kwargs* -- ``TorchDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver` 和
+ :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`;
+
+ .. note::
+
+ 注意如果对于 ``TorchDDPDriver`` 中初始化 ``DistributedDataParallel`` 时有特别的参数,您可以通过在 ``torch_kwargs`` 中传入
+ ``ddp_kwargs`` 来实现,例如:
+
+ .. code-block::
+
+ trainer = Trainer(
+ ...,
+ torch_kwargs = {'ddp_kwargs': {'find_unused_parameters': True, ...}}
+ )
+
+ 对于 ``TorchFSDPDriver`` 也是类似,只是对应的 ``**_kwargs`` 修改为 ``fsdp_kwargs``;
+
+ * *paddle_kwargs* -- ``PaddleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver` 和
+ :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver`;
+ * *fairscale_kwargs* -- ``FairScaleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver`;
+ * *deepspeed_kwargs* -- ``DeepSpeedDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.DeepSpeedDriver`;
+ * *oneflow_kwargs* -- ``OneflowDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.oneflow_driver.OneflowSingleDriver` 和
+ :class:`~fastNLP.core.drivers.oneflow_driver.OneflowDDPDriver`;
+ * *data_device* -- 一个具体的 driver 实例中,有 ``model_device`` 和 ``data_device``,前者表示模型所在的设备,后者表示
+ 当 ``model_device`` 为 None 时应当将数据迁移到哪个设备;
+
+ .. note::
+
+ **注意您在绝大部分情况下不会用到该参数!**
+
+ 1. 当 driver 实例的 ``model_device`` 不为 None 时,该参数无效;
+ 2. 对于 **pytorch**,仅当用户自己通过 ``python -m torch.distributed.launch`` 并且自己初始化 ``init_process_group`` 时,
+ driver 实例的 ``model_device`` 才会为 None;
+ 3. 对于 **deepspeed**,仅当用户自己通过 ``deepspeed xxx.py`` 并且自己初始化 ``model.initialize`` 时,
+ driver 实例的 ``model_device`` 才会为 None;
+ 4. 对于 **paddle** 和 **oneflow**,该参数无效;
+
+ * *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch
+ 内所有卡的 sample 加起来为一整个数据集的 sample,同时为了保证所有卡上拥有相同数量的 sample ,有的卡上可能会有重复的 sample ,例如
+ 8卡训练,只有9个 sample ,如果 batch_size 为 1,那么第二个 batch 时,有7张卡将没有 sample 可用,因此只有 **重复** 使用 sample 来 pad 到第二个
+ batch 中。如果不希望 fastNLP 对 dataloader 的 sampler 做特殊设置,请将该值设置为 False ,若确实需要分布式的训练,请在 Trainer 外
+ 对 ``train_dataloader`` 做的数据做特殊处理使得其在不同的卡之间 sample 是不同的。
+ * *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将保证 dataloader 的 ``sampler`` 替换为
+ evaluate 时使用的分布式的 ``sampler``,其特点是每个卡上的数据之间不重叠,所有卡上数据的加起来是整个数据集。若传入的 dataloader
+ 的 sampler 为:
+
+ - 深度学习框架自带的默认 sampler ;
+ - fastNLP 的 Sampler ;
+ 则将替换为 :class:`~fastNLP.UnrepeatedSequentialSampler`,如果这个行为不是期待的,请本参数设置为 ``False``,并针对每个卡控制其可以
+ 用到的数据。
+ * *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
+ ``["all", "ignore", "only_error"]`` ,分别代表 *全部输出*、 *全部忽略* 和 *仅输出错误* ,而 rank0 的 **所有信息** 都将被打印出来;
+ 当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 ``"only_error"``;
+
+ 注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``;
+ * *progress_bar* -- 显示进度条的方式,目前支持 ``[None, 'raw', 'rich', 'auto', 'tqdm']`` 或者 :class:`~fastNLP.RichCallback` 、 :class:`~fastNLP.RawTextCallback` 等对象,
+ 默认为 ``'auto'`` , ``'auto'`` 表示如果检测到当前 terminal 为交互型则使用 :class:`~fastNLP.RichCallback`,否则使用 :class:`~fastNLP.RawTextCallback` 对象。如果
+ 需要定制 progress bar 的参数,例如打印频率等,可以传入 :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback` 等对象。
+ * *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。
+ * *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。
+ * *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。
+ * *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Evaluator`` 中。与 output_mapping 互斥。
+ * *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。
+
+ .. note::
+ ``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证;
+ ``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,您需要保证这几个参数得到正确的传入:
+
+ 必须的参数:``metrics`` 与 ``evaluate_dataloaders``;
+
+ 可选的其它参数:``evaluate_batch_step_fn``、 ``evaluate_fn``、``evaluate_every``、``input_mapping``、
+ ``output_mapping``、``model_wo_auto_param_call``、``fp16``、``monitor``、``larger_better``;
+
+ .. warning::
+
+ 如果 ``Trainer`` 中内置的 ``Evaluator`` 实例不为 ``None``,那么需要注意 ``Trainer`` 中的一些参数是与 ``Evaluator`` 一致的,它们分别为:
+
+ 1. ``Evaluator`` 在初始化时的 ``driver`` 参数是 ``Trainer`` 中已经实例化过的 driver;这一点使得一些参数对于 ``Trainer`` 内部的
+ ``Evaluator`` 没有用处,例如 ``device``,``torch_kwargs``,``data_device`` 和 ``output_from_new_proc`` 等;
+ 2. ``input_mapping``,``output_mapping``,``model_wo_auto_param_call`` 和 ``fp16`` 是 ``Trainer`` 和其内部默认的
+ ``Evaluator`` 是一致的;
+
+ 当然,对于 ``input_mapping`` 和 ``output_mapping``,您可以通过添加 ``kwargs`` 中的参数 ``evaluate_input_mapping`` 和
+ ``evaluate_output_mapping`` 来单独为 ``Evaluator`` 进行更细致的订制。
+
+ 另一方面,注意一些专门独属于 ``Evaluator`` 的参数仅当 ``Evaluator`` 不为 None 时才会生效。
+
+ """
+
+ _custom_callbacks: dict = defaultdict(list)
+
    def __init__(
        self,
        model,
        train_dataloader,
        optimizers,
        driver: str = "auto",
        device: Optional[Union[int, List[int], str]] = "cpu",
        n_epochs: int = 20,
        evaluate_dataloaders=None,
        batch_step_fn: Optional[Callable] = None,
        evaluate_batch_step_fn: Optional[Callable] = None,
        train_fn: Optional[str] = None,
        evaluate_fn: Optional[str] = None,
        callbacks: Union[List[Callback], Callback, None] = None,
        metrics: Optional[dict] = None,
        evaluate_every: Optional[Union[int, Callable]] = -1,
        input_mapping: Optional[Union[Callable, Dict]] = None,
        output_mapping: Optional[Union[Callable, Dict]] = None,
        model_wo_auto_param_call: bool = False,
        accumulation_steps: int = 1,
        fp16: bool = False,
        monitor: Union[str, Callable] = None,
        larger_better: bool = True,
        n_batches: int = -1,
        overfit_batches: int = 0,
        marker: Optional[str] = None,
        **kwargs
    ):
        # See the class docstring for the meaning of every parameter.
        self.model = model
        self.marker = marker
        # Record a printable driver name (used later e.g. for the torch-specific
        # `find_unused_parameters` error hint in `run`).
        if isinstance(driver, str):
            self.driver_name = driver
        else:
            self.driver_name = driver.__class__.__name__
        self.device = device
        if train_dataloader is None:
            raise ValueError("Parameter `train_dataloader` can not be None.")
        self.train_dataloader = train_dataloader
        self.evaluate_dataloaders = evaluate_dataloaders
        self.evaluate_dataloaders = evaluate_dataloaders if False else evaluate_dataloaders
        self.optimizers = optimizers
        self.fp16 = fp16

        # Resolve the train/evaluate-specific mappings against the shared
        # `input_mapping` / `output_mapping` (the two forms are mutually exclusive).
        train_input_mapping = kwargs.get('train_input_mapping', None)
        train_output_mapping = kwargs.get('train_output_mapping', None)
        evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None)
        evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None)

        train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \
            _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
                                      evaluate_input_mapping, evaluate_output_mapping)

        self.input_mapping = train_input_mapping
        self.output_mapping = train_output_mapping
        self.evaluate_fn = evaluate_fn

        self.batch_step_fn = batch_step_fn
        if batch_step_fn is not None:
            _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
            # A custom step fn must itself trigger the expected callbacks /
            # state updates; verify that on the first batch (check_mode=True).
            self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True)
        else:
            self.check_batch_step_fn = lambda *args, **kwargs: ...
        # Whether `train_batch_loop` has already been validated. This matters when
        # the user swaps in a custom `train_batch_loop` via attribute assignment:
        # we then need to verify that the loop calls the callback functions and
        # updates `trainer_state` correctly. The default loop is trusted, hence
        # the initial value True; the check only runs once, after the first epoch.
        self.has_checked_train_batch_loop = True
        self._train_batch_loop = TrainBatchLoop(batch_step_fn=batch_step_fn)

        if not isinstance(accumulation_steps, int):
            raise ValueError("Parameter `accumulation_steps` can only be `int` type.")
        elif accumulation_steps < 0:
            # NOTE(review): the message says "bigger than 0" but this check
            # still lets accumulation_steps == 0 through — confirm intent.
            raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.")
        self.accumulation_steps = accumulation_steps

        # TODO: each driver could expose which of these kwargs it actually
        # consumes so that trainer/evaluator can warn about mismatched values.
        self.driver = choose_driver(
            model=model,
            driver=driver,
            train_dataloader=train_dataloader,
            optimizers=optimizers,
            device=device,
            n_epochs=n_epochs,
            evaluate_dataloaders=evaluate_dataloaders,
            batch_step_fn=batch_step_fn,
            evaluate_batch_step_fn=evaluate_batch_step_fn,
            evaluate_fn=evaluate_fn,
            callbacks=callbacks,
            metrics=metrics,
            evaluate_every=evaluate_every,
            input_mapping=train_input_mapping,
            output_mapping=train_output_mapping,
            model_wo_auto_param_call=model_wo_auto_param_call,
            accumulation_steps=accumulation_steps,
            fp16=fp16,
            n_batches=n_batches,
            marker=marker,
            **kwargs
        )
        self.driver.set_optimizers(optimizers=optimizers)

        # Pick the progress-bar callback according to the `progress_bar` kwarg.
        self.progress_bar = kwargs.get('progress_bar', 'auto')
        callbacks = prepare_callbacks(callbacks, self.progress_bar)
        # Initialize the callback manager;
        self.callback_manager = CallbackManager(callbacks)
        # Register all function-style callbacks (added via `Trainer.on`);
        self._fetch_matched_fn_callbacks()
        # Register all class-style callbacks;
        self.callback_manager.initialize_class_callbacks()

        # Initialize state containers: `state` is the user-facing scratch space,
        # `trainer_state` tracks our own bookkeeping.
        self.state = State()
        self.trainer_state = TrainerState(
            n_epochs=n_epochs if n_batches==-1 else None,
            cur_epoch_idx=0,
            global_forward_batches=0,
            batch_idx_in_epoch=0,
            num_batches_per_epoch=None,  # filled in by the concrete train_batch_loop;
            n_batches=n_batches
        )

        if metrics is not None and evaluate_dataloaders is None:
            raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")

        self.metrics = metrics
        self.evaluate_every = evaluate_every

        self.driver.setup()
        self.driver.barrier()

        # Check that `train_dataloader` is something the driver can consume.
        if kwargs.get('check_dataloader_legality', True):
            try:
                self.driver.check_dataloader_legality(dataloader=train_dataloader)
            except TypeError as e:
                logger.error("`train_dataloader` is invalid.")
                raise e

        use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed())
        if use_dist_sampler:
            _dist_sampler = "dist"
        else:
            _dist_sampler = None

        self.dataloader = self.train_dataloader
        self.driver.set_deterministic_dataloader(self.dataloader)

        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
                                                                reproducible=self.callback_manager._need_reproducible_sampler)
        # Overfit support: restrict training to a fixed subset of batches.
        if overfit_batches != 0:
            self.dataloader = OverfitDataLoader(self.dataloader, overfit_batches)
        self.overfit_batches = overfit_batches

        self.evaluator = None
        self.monitor = monitor
        self.larger_better = larger_better
        if metrics is not None:
            if overfit_batches != 0:
                # When overfitting, evaluate on the very data we train on.
                evaluate_dataloaders = self.dataloader
            if evaluate_dataloaders is not None:
                check_evaluate_every(evaluate_every)
                # Reuse the same progress-bar flavour the trainer is using.
                progress_bar_name = None
                for callback in self.callback_manager.class_callbacks:
                    if isinstance(callback, ProgressCallback):
                        progress_bar_name = callback.name
                self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
                                           driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
                                           evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
                                           output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
                                           use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
                                           progress_bar=progress_bar_name,
                                           check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
            else:
                # NOTE(review): this branch looks unreachable (the
                # metrics-without-dataloaders case already raised above) and its
                # message appears swapped with that earlier one — confirm.
                raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")

        if train_fn is not None and not isinstance(train_fn, str):
            raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
        self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
        self.train_fn = train_fn

        self.evaluate_batch_step_fn = evaluate_batch_step_fn
        self.kwargs = kwargs

        self.on_after_trainer_initialized(self.driver)
        self.driver.barrier()

+
    def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
            num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
            catch_KeyboardInterrupt = None):
        r"""Start the actual training after the ``Trainer`` has been initialized.

        :param num_train_batch_per_epoch: stop each epoch after this many batches;
            *-1* means use the full length of ``train_dataloader``. Handy for
            quickly validating the overall training flow before a real run.
        :param num_eval_batch_per_dl: evaluate at most this many batches per
            ``evaluate_dataloader``; *-1* means the dataloader's full length.
        :param num_eval_sanity_batch: number of evaluation batches run once
            before training starts to sanity-check the evaluation pipeline;
            0 disables the check.
        :param resume_from: directory to restore trainer state from (e.g. a
            sub-folder created by ``CheckpointCallback``); ``None`` for a fresh
            run. Saving state and resuming from it are independent features.
        :param resume_training: when resuming, whether to restore the full
            training progress (exact epoch/batch position, callbacks, ...).
            If ``False`` only model and optimizer states are restored, and the
            remaining trainer state stays as freshly initialized. Only
            meaningful when ``resume_from`` is given.
        :param catch_KeyboardInterrupt: whether ``ctrl+c`` exits training
            quietly (code after ``trainer.run()`` keeps running) instead of
            raising. Defaults to ``True`` for non-distributed drivers; forced
            to ``False`` for distributed ones such as ``TorchDDPDriver``.

        .. warning::

            ``run`` can effectively be called only once per ``Trainer``
            instance: afterwards ``trainer.cur_epoch_idx == trainer.n_epochs``
            and a second call does nothing. Re-initialize a ``Trainer`` to
            train again.
        """

        if catch_KeyboardInterrupt is None:
            # Default: only catch ctrl+c when not running distributed.
            catch_KeyboardInterrupt = not self.driver.is_distributed()
        else:
            if self.driver.is_distributed():
                if catch_KeyboardInterrupt:
                    logger.rank_zero_warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device "
                                             "driver. And we are gonna to set it to False.")
                catch_KeyboardInterrupt = False

        self._set_num_eval_batch_per_dl(num_eval_batch_per_dl)

        if resume_from is not None:
            if os.path.exists(resume_from):
                self.load_checkpoint(resume_from, resume_training=resume_training)
            else:
                raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.")

        # Optional pre-flight run of the evaluator on a few batches.
        if self.evaluator is not None and num_eval_sanity_batch != 0:
            logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.")
            self.on_sanity_check_begin()
            sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch)
            self.on_sanity_check_end(sanity_check_res)

        if num_train_batch_per_epoch != -1:
            self.dataloader = _TruncatedDataLoader(self.dataloader, num_train_batch_per_epoch)

        self.num_batches_per_epoch = len(self.dataloader)
        # Keep n_batches and n_epochs mutually consistent: whichever the user
        # supplied wins, the other is derived from it.
        if self.n_batches == -1:
            self.n_batches = self.num_batches_per_epoch * self.n_epochs
        else:
            self.n_epochs = (self.n_batches+self.num_batches_per_epoch-1)//self.num_batches_per_epoch

        self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch

        try:
            self.on_train_begin()
            self.driver.barrier()
            self.driver.zero_grad()
            while self.cur_epoch_idx < self.n_epochs and self.global_forward_batches < self.n_batches:
                # Guard against saving again before the current epoch finishes
                # right after Trainer.load_checkpoint.
                self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
                self.driver.set_model_mode("train")
                self.on_train_epoch_begin()
                self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx)
                self.train_batch_loop.run(self, self.dataloader)
                if not self.has_checked_train_batch_loop:
                    self._check_train_batch_loop_legality()
                self.cur_epoch_idx += 1
                self.on_train_epoch_end()
                self.driver.barrier()
                self.epoch_evaluate()
                self.driver.barrier()

        except EarlyStopException as e:
            logger.info(f"Catch early stop exception: {e.msg}.")
            self.on_exception(e)
        except KeyboardInterrupt as e:
            self.driver.on_exception()
            self.on_exception(e)
            if not catch_KeyboardInterrupt:
                raise e
        except RuntimeError as e:
            if 'torch' in self.driver_name.lower() and len(e.args) > 0:  # for torch, surface the find_unused_parameters hint
                if 'find_unused_parameters' in e.args[0]:
                    logger.error("You may need to pass `torch_kwargs={'ddp_kwargs':{'find_unused_parameters': True}}` in the "
                                 "Trainer initialization to avoid this error.")
            self.driver.on_exception()
            self.on_exception(e)
            raise e
        except BaseException as e:
            self.driver.on_exception()
            self.on_exception(e)
            raise e
        finally:
            self.on_train_end()

+
+ def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl: int):
+ r"""
+ 用于设定训练过程中 ``Evaluator`` 进行验证时所实际验证的 batch 的数量;
+
+ :param num_eval_batch_per_dl: 等价于 :meth:`~fastNLP.core.controllers.Trainer.run` 中的参数 ``num_eval_batch_per_dl``;
+ """
+ def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
+ trainer.on_evaluate_begin()
+ _evaluate_res: dict = evaluate_fn()
+ trainer.on_evaluate_end(_evaluate_res)
+
+ if self.evaluator is not None:
+ self.run_evaluate = partial(_evaluate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))
+
+ def step_evaluate(self):
+ r"""
+ 在训练过程中的每个 batch 结束后被调用,注意实际的 ``Evaluator.run`` 函数是否在此时被调用取决于用户设置的 **"验证频率"**;
+ """
+ if self.evaluator is not None:
+ if callable(self.evaluate_every):
+ if self.evaluate_every(self):
+ self.run_evaluate()
+ elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0:
+ self.run_evaluate()
+
+ def epoch_evaluate(self):
+ r"""
+ 在训练过程中的每个 epoch 结束后被调用,注意实际的 ``Evaluator.run`` 函数是否在此时被调用取决于用户设置的 **"验证频率"**;
+ """
+ if self.evaluator is not None:
+ if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
+ evaluate_every = -self.evaluate_every
+ if self.cur_epoch_idx % evaluate_every == 0:
+ self.run_evaluate()
+
+ def add_callback_fn(self, event: Event, fn: Callable):
+ r"""
+ 在初始化一个 trainer 实例后,您可以使用这一函数来方便地添加 ``callback`` 函数;
+
+ 注意这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数;
+
+ :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机;具体有哪些时机详见 :class:`~fastNLP.core.callbacks.Event`;
+ :param fn: 具体的 callback 函数;
+
+ .. note::
+
+ 对于训练一个神经网络的整体的流程来说,其可以分为很多个时间点,例如 **"整体的训练前"**,**"训练具体的一个 epoch 前"**,
+ **"反向传播前"**,**"整体的训练结束后"** 等;一个 ``callback`` 时机指的就是这些一个个具体的时间点;
+
+ 该函数的参数 ``event`` 需要是一个 ``Event`` 实例,其使用方式见下方的例子;
+
+ 一个十分需要注意的事情在于您需要保证您添加的 callback 函数 ``fn`` 的参数与对应的 callback 时机所需要的参数保持一致,更准确地说,
+ 是与 :class:`~fastNLP.core.callbacks.Callback` 中的对应的 callback 函数的参数保持一致;例如如果
+ 您想要在 ``on_after_trainer_initialized`` 这个时机添加一个您自己的 callback 函数,您需要保证其参数为 ``trainer, driver``;
+
+ 最后用一句话总结:对于您想要加入的一个 callback 函数,您首先需要确定您想要将该函数加入的 callback 时机,然后通过 ``Event.on_***()``
+ 拿到具体的 event 实例;再去 :class:`~fastNLP.core.callbacks.Callback` 中确定该 callback 时机的 callback 函数的参数应当是怎样的;
+
+ 例如:
+
+ .. code-block::
+
+ from fastNLP import Trainer, Event
+
+ # Trainer 初始化
+ trainer = Trainer(...)
+
+ # 定义您自己的 callback 函数,需要注意的是该函数的参数需要与您要添加的 callback 时机所需要的参数保持一致;因为我们要将该函数加入到
+ # on_after_trainer_initialized 这个 callback 时机,因此我们这里的
+ def my_callback_fn(trainer, driver):
+ # do something
+ # 您可以在函数内部使用 trainer 和 driver,我们会将这两个实例注入进去;
+
+ # 添加到 trainer 中;
+ trainer.add_callback_fn(Event.on_after_trainer_initialized(), my_callback_fn)
+
+ .. note::
+
+ 该函数与 ``Trainer.on`` 函数提供的作用相同,它们所需要的参数也基本相同,区别在于 ``Trainer.on`` 用于 ``Trainer`` 初始化前,而
+ ``Trainer.add_callback_fn`` 用于 ``Trainer`` 初始化之后;
+
+ 更为具体的解释见 :meth:`~fastNLP.core.controllers.Trainer.on`;
+
+ """
+ if not isinstance(event, Event):
+ raise ValueError("parameter event should only be `Event` type.")
+
+ _custom_callback = _CallbackWrapper(event, fn)
+ self.callback_manager.dissect_one_callback(_custom_callback)
+
+ @classmethod
+ def on(cls, event: Event, marker: Optional[str] = None):
+ r"""
+ 函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制;
+
+ 支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如::
+
+ Trainer.__init__():
+ on_after_trainer_initialized(trainer, driver)
+ Trainer.run():
+ # load checkpoint if resume_from is not None
+ if num_eval_sanity_batch>0:
+ on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch
+ on_sanity_check_end(trainer, sanity_check_res)
+ try:
+ on_train_begin(trainer)
+ while cur_epoch_idx < n_epochs:
+ on_train_epoch_begin(trainer)
+ while batch_idx_in_epoch<=num_batches_per_epoch:
+ on_fetch_data_begin(trainer)
+ batch = next(dataloader)
+ on_fetch_data_end(trainer)
+ on_train_batch_begin(trainer, batch, indices)
+ on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。
+ on_after_backward(trainer)
+ on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
+ on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
+ on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
+ on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
+ on_train_batch_end(trainer)
+ on_train_epoch_end(trainer)
+ except BaseException:
+ self.on_exception(trainer, exception)
+ finally:
+ on_train_end(trainer)
+
+ 其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
+ on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中
+ 特定的时间调用。
+
+ .. note::
+
+ 对于 event 的解释,建议先阅读 :meth:`~fastNLP.core.controllers.Trainer.add_callback_fn` 的文档;
+
+ 当生成一个具体的 ``Event`` 实例时,可以指定 ``every、once、filter_fn`` 这三个参数来控制您的 callback 函数的调用频率,例如当您
+ 指定 ``Event.on_train_epoch_begin(every=3)`` 时,其表示每隔三个 epoch 运行一次您的 callback 函数;对于这三个参数的更具体的解释,
+ 请见 :class:`~fastNLP.core.callbacks.Event`;
+
+ Example1::
+
+ from fastNLP import Event
+ @Trainer.on(Event.on_save_model())
+ def do_something_1(trainer):
+ # do something
+ # 以上函数会在 Trainer 保存模型时执行。
+
+ @Trainer.on(Event.on_save_model(once=True))
+ def do_something_2(trainer):
+ # do something
+ # 以上函数会在 Trainer 保存模型时执行,但只执行一次。
+
+ @Trainer.on(Event.on_train_batch_begin(every=2))
+ def do_something_3(trainer, batch, indices):
+ # do something
+ # 以上函数会在 Trainer 每个新的 batch 开始的时候执行,但是是两个 batch 才执行一次。
+
+ Example2::
+
+ @Trainer.on(Event.on_train_begin())
+ def fn1(trainer):
+ ...
+
+ @Trainer.on(Event.on_train_epoch_begin())
+ def fn2(trainer):
+ ...
+
+ trainer1 = Trainer(
+ ...,
+ marker='trainer1'
+ )
+
+ @Trainer.on(Event.on_fetch_data_begin())
+ def fn3(trainer):
+ ...
+
+ trainer2 = Trainer(
+ ...,
+ marker='trainer2'
+ )
+
+ 这段代码意味着 ``fn1`` 和 ``fn2`` 会被加入到 ``trainer1``,``fn3`` 会被加入到 ``trainer2``;
+
+ 注意如果您使用该函数修饰器来为您的训练添加 callback,请务必保证您加入 callback 函数的代码在实例化 `Trainer` 之前;
+
+ 补充性的解释见 :meth:`~fastNLP.core.controllers.Trainer.add_callback_fn`;
+
+ :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机。每个时机运行的函数应该包含
+ 特定的参数,可以通过上述说明查阅。
+ :param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 ``marker`` 为 None(默认情况)时,
+ 表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 ``marker`` 为 'all' 时,该 callback 函数会被所有的 trainer
+ 实例使用;
+ :return: 原函数;
+ """
+
+ def wrapper(fn: Callable) -> Callable:
+ callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
+ _check_valid_parameters_number(fn, callback_fn_args)
+ cls._custom_callbacks[marker].append((event, fn))
+ return fn
+
+ return wrapper
+
+ def _fetch_matched_fn_callbacks(self):
+ r"""
+ 因为对于使用装饰器加入的函数 callback,我们是加在类属性 ``_custom_callbacks`` 中,因此在初始化一个具体的 trainer 实例后,我们需要从 Trainer 的
+ callback 类属性中将属于其的 callback 函数拿到,然后加入到 ``callback_manager`` 中;
+
+ 这里的主要需要注意的地方在于为了支持没有带 ``marker`` 的 callback 函数赋给下方代码距离其最近的 trainer,在每次收集到 self._custom_callbacks[None] 后将其置为 [];
+ """
+ _own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"])
+ _own_callbacks.extend(self._custom_callbacks[None])
+ logger.debug(f"Get {len(_own_callbacks)} callback fns through Trainer.on().")
+ self._custom_callbacks[None] = []
+ if self.marker is not None:
+ if len(self._custom_callbacks[self.marker]) == 0:
+ logger.info(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched "
+ f"`{self.marker}` that is added through function `Trainer.on`")
+ _own_callbacks += self._custom_callbacks[self.marker]
+ for each_callback in _own_callbacks:
+ self.add_callback_fn(*each_callback)
+
+ def _check_callback_called_legality(self, check_mode: bool = True):
+ r"""
+ 这个函数主要的作用在于:
+
+ 如果用户定制了训练流程中的一部分,例如 ``batch_step_fn`` 或者 ``TrainBatchLoop``;并且这些部分流程中可能会包含一些 callback
+ 函数的调用;例如 ``train_batch_loop.batch_step_fn`` 中包含 ``on_before_backward`` 等;
+
+ 用户是十分可能忘记在其自己定制的部分流程中实现对这些 callback 函数的调用的;因此需要我们进行检测和提醒;
+
+ 这种检测也十分简单,即如果我们检测到 callback_manager 的某一 callback 函数在训练一段时间(通常是涉及到允许定制的部分流程的结尾)后,
+ 其被调用的次数是 0,那么我们就会打印 ``warning`` 信息;
+
+ 1. 这个函数的调用时机(这个函数会在以下情况被调用):
+
+ 当检测 'batch_step_fn' 时,这个函数应当在 'train_batch_loop.run' 的 while 循环的最后进行调用;
+ 当检测 'TrainBatchLoop' 时,这个函数应当在每一个 epoch 的最后进行调用;
+
+ 2. 这个函数作用的更细致的解释:
+
+ 这一函数的作用在于检查用户定制的 batch_step_fn / TrainBatchLoop 是否能够正确地调用 callback 函数,更准确地说,当用户实际
+ 定制了 ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+ "on_after_zero_grad") /
+ ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
+ "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+ "on_after_zero_grad")
+ 这些 callabck_fn 后,如果其同样也定制了 batch_step_fn / TrainBatchLoop,那么其有可能忘记了在自己的 batch_step_fn 中
+ 上述的这些 callback 函数,而这个函数的作用就在于检测用户是否产生了这一行为;
+
+ 注意,这一函数只会在 batch_step_fn 不为 None 时或者 TrainBatchLoop 没有被替换时才会被调用;
+
+ :param check_mode: 用来判断该函数是用来检测 'batch_step_fn' 还是用来检测 'TrainBatchLoop' 的参数,为 True 时表示检测
+ 'batch_step_fn',为 False 时表示检测 'TrainBatchLoop';
+ """
+ if check_mode:
+ callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+ "on_before_zero_grad", "on_after_zero_grad")
+ else:
+ callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
+ "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+ "on_before_zero_grad", "on_after_zero_grad")
+ _not_called_callback_fns = []
+ for each_callback_fn in callbacks:
+ if each_callback_fn in self.callback_manager.callback_fns:
+ if self.callback_manager.callback_counter[each_callback_fn] == 0:
+ _not_called_callback_fns.append(each_callback_fn)
+
+ if check_mode:
+ if len(_not_called_callback_fns) != 0:
+ logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these "
+ f"callback_fns: {_not_called_callback_fns}, but it seems that"
+ "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.",
+ once=True)
+ # 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass
+ # 函数;
+ self.check_batch_step_fn = lambda *args, **kwargs: ...
+ elif len(_not_called_callback_fns)!=0:
+ logger.rank_zero_warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: "
+ f"{_not_called_callback_fns}, but it seems that"
+ "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.",
+ once=True)
+
+ def _check_train_batch_loop_legality(self):
+ r"""
+ 该函数用于检测用户定制的 `train_batch_loop` 是否正确地调用了 callback 函数以及是否正确地更新了 `trainer_state` 的状态;
+ 该函数仅当用户通过属性更换用自己的定制的 `train_batch_loop` 替换了默认的 `TrainBatchLoop` 对象后才会被调用;
+ 当被调用时,该函数仅当第一次被调用时被调用;
+ """
+ # 1. 检测用户定制的 `train_batch_loop` 是否正确地调用了 callback 函数;
+ self._check_callback_called_legality(check_mode=False)
+
+ # 2. 检测用户定制的 `train_batch_loop` 是否正确地更新了 `trainer_state` 的状态;
+ # 因为该检测函数只会在第一个 epoch 运行完后调用,因此我们只需要检测这些 `trainer_state` 的值是否正确即可;
+ if self.batch_idx_in_epoch == 0:
+ logger.warning("You have customized your `train_batch_loop`, but it seemed that you forget to update the "
+ "`trainer_state.batch_idx_in_epoch` in your process of training. Look the origin class "
+ "`TrainBatchLoop`.")
+ if self.global_forward_batches == 0:
+ logger.warning("You have customized your `train_batch_loop`, but it seemed that you forget to update the "
+ "`trainer_state.global_forward_batches` in your process of training. Look the origin class "
+ "`TrainBatchLoop`.")
+ self.has_checked_train_batch_loop = True
+
+ """ Trainer 需要的一些 property """
    @property
    def driver(self):
        """
        :return: the ``driver`` instance used by this ``trainer``;
        """
        return self._driver

    @driver.setter
    def driver(self, driver: Driver):
        # Plain delegation; no validation is performed when the driver is replaced.
        self._driver = driver
+
    @property
    def train_batch_loop(self):
        """
        :return: the ``train_batch_loop`` instance used by this ``trainer``;
        """
        return self._train_batch_loop

    @train_batch_loop.setter
    def train_batch_loop(self, loop: Loop):
        # A newly assigned loop has not been validated yet; it will be checked
        # once after the first epoch (see `_check_train_batch_loop_legality`).
        self.has_checked_train_batch_loop = False
        if self.batch_step_fn is not None:
            logger.warning("`batch_step_fn` was customized in the Trainer initialization, it will be ignored "
                           "when the `train_batch_loop` is also customized.")
            # When the whole TrainBatchLoop is replaced there is no need for a
            # dedicated batch_step_fn check — that function is guaranteed to be
            # ignored — so replace the checker with a no-op.
            self.check_batch_step_fn = lambda *args, **kwargs: ...
        self._train_batch_loop = loop
+
+ def save_model(self, folder: Union[str, os.PathLike, BinaryIO, io.BytesIO], only_state_dict: bool = False,
+ model_save_fn: Optional[Callable] = None, **kwargs):
+ r"""
+ 用于帮助您保存模型的辅助函数;
+
+ :param folder: 保存模型的文件夹。如果没有传入 ``model_save_fn`` 参数,则我们会在这个文件夹下保存 ``fastnlp_model.pkl.tar`` 文件;
+ :param only_state_dict: 仅在 ``model_save_fn`` 为空时,有效。是否只保存模型的 ``state_dict``;
+ :param model_save_fn: 您自己定制的用来替换该保存函数本身保存逻辑的函数,当您传入了该参数后,我们会实际调用该函数,而不会去调用 ``driver`` 的 ``save_model`` 函数;
+ :kwargs:
+ * *input_spec* -- 该参数详见 **PaddlePaddle** 框架的保存函数 :meth:`~fastNLP.core.drivers.PaddleDriver.save_model` 中的说明;
+
+ .. note::
+
+ 注意如果您需要在训练的过程中保存模型,如果没有特别复杂的逻辑,强烈您使用我们专门为保存模型以及断点重训功能定制的 ``callback``: ``CheckpointCallback``;
+ ``CheckpointCallback`` 的使用具体见 :class:`~fastNLP.core.callbacks.checkpoint_callback.CheckpointCallback`;
+
+ 这意味着在大多数时刻您并不需要自己主动地调用该函数来保存模型;当然您可以在自己定制的 callback 类中通过直接调用 ``trainer.save_model`` 来保存模型;
+
+ 具体实际的保存模型的操作由具体的 driver 实现,这意味着对于不同的 ``Driver`` 来说,保存模型的操作可能是不尽相同的,
+ 您如果想要了解更多的保存模型的细节,请直接查看各个 ``Driver`` 的 ``save_model`` 函数;
+
+ ``save_model`` 函数和 ``load_model`` 函数是配套使用的;
+ """
+
+ self.on_save_model()
+ self.driver.barrier()
+
+ if not isinstance(folder, (io.BytesIO, BinaryIO)):
+ if model_save_fn is not None:
+ if not callable(model_save_fn):
+ raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
+ rank_zero_call(model_save_fn)(folder)
+ else:
+ if isinstance(folder, str):
+ folder = Path(folder)
+ self.driver.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
+ else:
+ if model_save_fn is not None:
+ raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being "
+ "`io.BytesIO` type.")
+ self.driver.save_model(folder, only_state_dict, **kwargs)
+ self.driver.barrier()
+
+ def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = True,
+ model_load_fn: Optional[Callable] = None, **kwargs):
+ """
+ 用于帮助您加载模型的辅助函数;
+
+ :param folder: 存放着您需要加载的 model 的文件夹,默认会尝试读取该文件夹下的 ``fastnlp_model.pkl.tar`` 文件。在 ``model_load_fn``
+ 不为空时,直接将该 folder 传递到 ``model_load_fn`` 中;
+ :param only_state_dict: 要读取的文件中是否仅包含模型权重。在 ``model_load_fn`` 不为 ``None`` 时,该参数无意义;
+ :param model_load_fn: :class:`Callable` 的函数,接受一个 folder 作为参数,需要注意该函数不需要返回任何内容;
+ :param kwargs: 理论上您不需要使用到该参数;
+
+ .. note::
+
+ 注意您需要在初始化 ``Trainer`` 后再通过 ``trainer`` 实例来调用该函数;这意味着您需要保证在保存和加载时使用的 ``driver`` 是属于同一个
+ 训练框架的,例如都是 ``pytorch`` 或者 ``paddle``;
+
+ 注意在大多数情况下您不需要使用该函数,如果您需要断点重训功能,您可以直接使用 ``trainer.load_checkpoint`` 函数;
+
+ 该函数在通常情况下和 ``save_model`` 函数配套使用;其参数均与 ``save_model`` 函数成对应关系;
+ """
+ self.on_load_model()
+ self.driver.barrier()
+ if not isinstance(folder, (io.BytesIO, BinaryIO)):
+ try:
+ if model_load_fn is not None:
+ if not callable(model_load_fn):
+ raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
+ model_load_fn(folder)
+ else:
+ if isinstance(folder, str):
+ folder = Path(folder)
+ self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
+ except FileNotFoundError as e:
+ if FASTNLP_MODEL_FILENAME not in os.listdir(folder):
+ logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.")
+ raise e
+ else:
+ if model_load_fn is not None:
+ raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being "
+ "`io.BytesIO` type.")
+ self.driver.load_model(folder, only_state_dict, **kwargs)
+ self.driver.barrier()
+
+ def save_checkpoint(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs):
+ r"""
+ 用于帮助您实现断点重训功能的保存函数;保存内容包括:callback 状态、Trainer 的状态、Sampler 的状态【在恢复的时候才能恢复到特定 batch 】、
+ 模型参数、optimizer的状态、fp16 Scaler的状态【如果有】。
+
+ :param folder: 保存在哪个文件夹下,会在该文件下生成两个文件:``fastnlp_checkpoint.pkl.tar`` 与 ``fastnlp_model.pkl.tar`` 。
+ 如果 ``model_save_fn`` 不为空,则没有 ``fastnlp_model.pkl.tar`` 文件;
+ :param only_state_dict: 当 ``model_save_fn`` 为空时有效,表明是否仅保存模型的权重;
+ :param model_save_fn: 如果模型保存比较特殊,可以传入该函数自定义模型的保存过程,输入应该接受一个文件夹(实际上就是接受上面的 folder
+ 参数),不需要返回值;这意味着您可以通过该函数来自己负责模型的保存过程,而我们则会将 ``trainer`` 的状态保存好;
+ :kwargs:
+ * *input_spec* -- 该参数详见 **PaddlePaddle** 框架的保存函数 :meth:`~fastNLP.core.drivers.PaddleDriver.save_model` 中的说明;
+
+ .. note::
+
+ 注意如果您需要在训练的过程中使用断点重训功能,您可以直接使用 ``CheckpointCallback``;
+ ``CheckpointCallback`` 的使用具体见 :class:`~fastNLP.core.callbacks.checkpoint_callback.CheckpointCallback`;
+
+ 这意味着在大多数时刻您并不需要自己主动地调用该函数来保存 ``Trainer`` 的状态;当然您可以在自己定制的 callback 类中通过直接调用 ``trainer.save_checkpoint`` 来保存 ``Trainer`` 的状态;
+
+ 具体实际的保存状态的操作由具体的 driver 实现,这意味着对于不同的 ``Driver`` 来说,保存的操作可能是不尽相同的,
+ 您如果想要了解保存 ``Trainer`` 状态的更多细节,请直接查看各个 ``Driver`` 的 ``save`` 函数;
+
+ ``save_checkpoint`` 函数和 ``load_checkpoint`` 函数是配套使用的;
+
+ .. note::
+
+ 为了支持断点重训功能,我们会在调用该函数时保存以下内容:
+
+ 1. 各个 ``callback`` 的状态,这主要涉及到一些带有运行状态的 ``callback``;
+ 2. 控制训练流程的变量 ``trainer_state``,具体详见 :class:`~fastNLP.core.controllers.utils.state.TrainerState`;
+ 3. 一个特殊的变量 ``num_consumed_batches``,表示在这次训练过程中总共训练了多少个 batch 的数据;您不需要关心这个变量;
+ 4. sampler 的状态,为了支持断点重训功能,我们会在 trainer 初始化的时候,将您的 ``trainer_dataloader`` 的 ``sampler`` 替换为
+ 我们专门用于断点重训功能的 ``ReproducibleSampler``,详见 :class:`~fastNLP.core.samplers.reproducible_sampler.ReproducibleSampler`;
+ 5. model 的状态,即模型参数;
+ 6. optimizers 的状态,即优化器的状态;
+ 7. fp16 的状态;
+
+ .. warning::
+
+ 一个值得注意的问题是 ``Driver`` 在新版 ``fastNLP`` 中的特殊作用,在断点重训时则体现为您应当尽量保证在前后两次训练中使用的 ``Driver``
+ 是一致的,例如您不能在第一次训练时使用 ``pytorch``,而在第二次训练时使用 ``paddle``;或者尽量不要在第一次训练时使用分布式训练,但是
+ 在第二次训练时使用非分布式训练(尽管这一行为的部分情况是支持的,请见下方的说明);
+
+ 但是如果您一定需要在前后使用不同分布式情况的 ``Driver``,那么在简单的默认情况下,我们也还是支持您使用断点重训的,这意味您可以在第一次训练时
+ 使用单卡,但是在第二次训练时使用多卡进行训练;或者反过来;
+
+ 以 ``pytorch`` 为例,这里的简单的默认情况指的是您的 ``train_dataloader`` 所使用的 ``sampler`` 是 ``RandomSampler`` 或者 ``SequentialSampler``;
+ 如果您的 ``sampler`` 是其它类型的 ``sampler``,那么我们仅支持前后两次训练 ``driver`` 严格不变时的断点重训;
+ """
+
+ self.driver.barrier()
+
+ # 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态;
+ # 2. trainer_state;
+ states = {
+ "callback_states": self.on_save_checkpoint(),
+ "trainer_state": self.trainer_state.state_dict(),
+ 'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0)
+ }
+
+ if isinstance(folder, str):
+ folder = Path(folder)
+
+ if model_save_fn is not None:
+ if not callable(model_save_fn):
+ raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
+ rank_zero_call(model_save_fn)(folder)
+ self.driver.save_checkpoint(folder=folder, dataloader=self.dataloader, states=states, should_save_model=False, **kwargs)
+ else:
+ self.driver.save_checkpoint(folder=folder, dataloader=self.dataloader, states=states,
+ only_state_dict=only_state_dict, should_save_model=True, **kwargs)
+
+ self.driver.barrier()
+
+ def load_checkpoint(self, folder: str, resume_training: bool = True, only_state_dict: bool = True,
+ model_load_fn: Optional[Callable] = None, **kwargs):
+ r"""
+ 用于帮助您实现断点重训功能的加载函数;
+
+ :param folder: 保存断点重训时 ``trainer`` 的状态文件的文件夹;
+ :param resume_training: 是否精确到从上次训练时最终截断的那一个 batch 开始训练;如果 ``resume_training=True``,那么我们
+ 只会加载 ``model`` 和 ``optimizers`` 的状态;而其余对象的值则根据用户的 ``Trainer`` 的初始化直接重置;
+ :param only_state_dict: 保存的 ``model`` 是否只保存了权重;
+ :param model_load_fn: 使用的模型加载函数,参数应为一个文件夹,注意该函数不需要返回任何内容;您可以传入该参数来定制自己的加载模型的操作,
+ 当该参数不为 None 时,我们默认加载模型由该函数完成,``trainer.load_checkpoint`` 函数则会把 ``trainer`` 的其余状态加载好;
+
+ .. note::
+
+ 在 fastNLP 中,断点重训的保存和加载的逻辑是完全分离的,这意味着您在第二次训练时可以将 ``CheckpointCallback`` 从 ``trainer`` 中
+ 去除,而直接使用 ``trainer.load_checkpoint`` 函数加载 ``trainer`` 的状态来进行断点重训;
+
+ 该函数在通常情况下和 ``save_checkpoint`` 函数配套使用;其参数与 ``save_checkpoint`` 函数成对应关系;
+
+ 对于在前后两次训练 ``Driver`` 不同的情况时使用断点重训,请参考 :meth:`~fastNLP.core.controllers.trainer.Trainer.load_checkpoint` 函数的 ``warning``;
+
+ Example::
+
+ trainer = Trainer(...)
+
+ trainer.load_checkpoint(folder='/path-to-your-saved_checkpoint_folder/', ...)
+
+ trainer.run()
+
+ """
+
+ self.driver.barrier()
+ if isinstance(folder, str):
+ folder = Path(folder)
+
+ dataloader = self.dataloader
+ if not resume_training:
+ dataloader = None
+ try:
+ if model_load_fn is not None:
+ if not callable(model_load_fn):
+ raise ValueError("Parameter `model_save_fn` should be `Callable`.")
+ model_load_fn(folder)
+ states = self.driver.load_checkpoint(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
+ else:
+ states = self.driver.load_checkpoint(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
+ except FileNotFoundError as e:
+ if FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder) and FASTNLP_MODEL_FILENAME in os.listdir(folder):
+ logger.error("It seems that you are trying to load the trainer checkpoint from a model checkpoint folder.")
+ elif FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder):
+ logger.error(f"fastNLP Trainer checkpoint file:{FASTNLP_CHECKPOINT_FILENAME} is not found in {folder}.")
+ raise e
+
+ if not resume_training:
+ return
+
+ self.dataloader = states.pop('dataloader')
+
+ # 1. 恢复 trainer_state 的状态;
+ self.trainer_state.load_state_dict(states["trainer_state"])
+
+ # 2. 修改 trainer_state.batch_idx_in_epoch
+ # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler;
+ # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于
+ # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立
+ self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
+ # 这个是防止用户在 Trainer.load_checkpoint 之后还没结束当前 epoch 又继续 save_checkpoint
+ self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
+
+ # 5. 恢复所有 callback 的状态;
+ self.on_load_checkpoint(states["callback_states"])
+
+ self.driver.barrier()
+
+ """ 这四个函数是用来方便用户定制自己的 batch_step_fn(用于替换 train_batch_loop 当中的 batch_step_fn 函数) 的 """
+
+ def train_step(self, batch):
+ r"""
+ 实现模型训练过程中的对一个 batch 的数据的前向传播过程;
+
+ .. note::
+
+ 该函数的提供是为了您能够更方便地定制自己的 ``train_batch_step_fn`` 来替换原本的 ``train_batch_loop.batch_step_fn``;更具体的细节
+ 请见 :meth:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop.batch_step_fn`;
+
+ ``trainer.backward / zero_grad / step`` 函数的作用类似;
+
+ :param batch: 一个 batch 的数据;
+ :return: 模型的前向传播函数所返回的结果;
+ """
+ with self.driver.auto_cast():
+ outputs = self.driver.model_call(batch, self._train_step, self._train_step_signature_fn)
+ outputs = match_and_substitute_params(self.output_mapping, outputs)
+ return outputs
+
+ def backward(self, outputs):
+ r"""
+ 实现模型训练过程中神经网络的反向传播过程;
+
+ :param outputs: 模型的输出,应当为一个字典或者 dataclass,里面包含以 ``loss`` 为关键字的值;
+ """
+ self.on_before_backward(outputs)
+ loss = self.extract_loss_from_outputs(outputs)
+ loss = loss / self.accumulation_steps
+ self.driver.backward(loss)
+ self.on_after_backward()
+
+ def zero_grad(self):
+ r"""
+ 实现模型训练过程中对优化器中的梯度的置零操作;
+ """
+ if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
+ self.on_before_zero_grad(self.optimizers)
+ self.driver.zero_grad()
+ self.on_after_zero_grad(self.optimizers)
+
+ def step(self):
+ r"""
+ 实现模型训练过程中的优化器的参数更新操作;
+ """
+
+ if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
+ self.on_before_optimizers_step(self.optimizers)
+ self.driver.step()
+ self.on_after_optimizers_step(self.optimizers)
+
+ def move_data_to_device(self, batch):
+ r"""
+ 将数据迁移到当前进程所使用的设备上;
+
+ :param batch: 一个 batch 的数据;
+ :return: 位置已经被迁移后的数据;
+ """
+ return self.driver.move_data_to_device(batch)
+
+ @staticmethod
+ def extract_loss_from_outputs(outputs):
+ r"""
+ 用来从用户模型的输出对象中抽取 ``loss`` 对象;
+ 目前支持 `outputs` 对象为 ``dict`` 或者 ``dataclass``;
+
+ :return: 被抽取出来的 ``loss`` 对象,例如如果是 ``pytorch``,那么返回的就是一个 tensor;
+ """
+ if isinstance(outputs, Dict):
+ try:
+ loss = outputs["loss"]
+ except:
+ raise KeyError(f"We cannot find `loss` from your model output(with keys:{outputs.keys()}). Please either "
+ f"directly return it from your model or use `output_mapping` to prepare it.")
+ elif is_dataclass(outputs):
+ try:
+ loss = outputs.loss
+ except:
+ raise AttributeError("We cannot find `loss` from your model output. Please either directly return it from"
+ " your model or use `output_mapping` to prepare it.")
+ else:
+ raise ValueError("The `outputs` from your model could only be of `dataclass` or `Dict` type. Or you can use "
+ "the parameter `output_mapping` to prepare loss.")
+
+ return loss
+
+ @contextmanager
+ def get_no_sync_context(self):
+ r"""
+ 用于在使用梯度累积并且进行分布式训练时,由于在前 ``accumulation_steps - 1`` 的时间内不需要进行梯度的同步,因此通过使用该 context 上下文
+ 环境来避免梯度的同步;
+
+ .. note::
+
+ 部分深度学习框架的梯度累积并不需要通过提供上下文环境实现,关于这点需要您深入了解您正在使用的框架的机制;而对于这些框架,fastNLP 会返回一个
+ 空的上下文环境。
+
+ :return: 一个支持 ``no_sync`` 的 ``context``;
+ """
+
+ if (self.global_forward_batches + 1) % self.accumulation_steps != 0:
+ _no_sync_context = self.driver.get_model_no_sync_context()
+ else:
+ _no_sync_context = nullcontext
+
+ with _no_sync_context():
+ yield
+
+ """ trainer state property """
+
    @property
    def n_epochs(self) -> int:
        r"""
        :return: the total number of epochs for the current training run;
        """
        return self.trainer_state.n_epochs

    @n_epochs.setter
    def n_epochs(self, n_epochs: int):
        # Backed by trainer_state.
        self.trainer_state.n_epochs = n_epochs
+
    @property
    def cur_epoch_idx(self) -> int:
        r"""
        :return: the index of the epoch currently running;
        """
        return self.trainer_state.cur_epoch_idx

    @cur_epoch_idx.setter
    def cur_epoch_idx(self, cur_epoch_idx: int):
        # Backed by trainer_state.
        self.trainer_state.cur_epoch_idx = cur_epoch_idx
+
    @property
    def global_forward_batches(self) -> int:
        """
        :return: how many batches have been trained in total since training began;
        """
        return self.trainer_state.global_forward_batches

    @global_forward_batches.setter
    def global_forward_batches(self, global_forward_batches: int):
        # Backed by trainer_state.
        self.trainer_state.global_forward_batches = global_forward_batches
+
    @property
    def batch_idx_in_epoch(self) -> int:
        r"""
        :return: how many batches have been trained since the start of the current epoch;
        """
        return self.trainer_state.batch_idx_in_epoch

    @batch_idx_in_epoch.setter
    def batch_idx_in_epoch(self, batch_idx_in_epoch: int):
        # Backed by trainer_state.
        self.trainer_state.batch_idx_in_epoch = batch_idx_in_epoch
+
    @property
    def num_batches_per_epoch(self) -> int:
        r"""
        :return: how many batches each epoch actually trains;
        """
        return self.trainer_state.num_batches_per_epoch

    @num_batches_per_epoch.setter
    def num_batches_per_epoch(self, num_batches_per_epoch: int):
        # Backed by trainer_state.
        self.trainer_state.num_batches_per_epoch = num_batches_per_epoch
+
    @property
    def n_batches(self) -> int:
        r"""
        :return: how many batches the whole training run will actually train;
        """
        return self.trainer_state.n_batches

    @n_batches.setter
    def n_batches(self, n_batches: int):
        # Backed by trainer_state.
        self.trainer_state.n_batches = n_batches
+
+ """ driver property """
+
    @property
    def model_device(self):
        r"""
        :return: the device the model currently lives on; ``None`` only in rare
            cases — e.g. with ``pytorch``, only when the user initializes
            ``init_progress_group`` themselves is ``model_device`` ``None``;
        """
        return self.driver.model_device
+
    @property
    def data_device(self):
        r"""
        :return: the destination device data will be moved to;
        """
        return self.driver.data_device
+
+ """ dataloader property """
+
    @property
    def train_dataloader(self):
        """
        :return: the ``train_dataloader`` passed in by the user; note this is the
            very same object the user handed to ``Trainer``, while the state of
            the ``dataloader`` actually used during training may have been
            modified;
        """
        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, train_dataloader):
        self._train_dataloader = train_dataloader
+
    @property
    def evaluate_dataloaders(self):
        """
        :return: the ``evaluate_dataloaders`` passed in by the user;
        """
        return self._evaluate_dataloaders

    @evaluate_dataloaders.setter
    def evaluate_dataloaders(self, evaluate_dataloaders):
        self._evaluate_dataloaders = evaluate_dataloaders
+
+
+def _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
+ evaluate_input_mapping, evaluate_output_mapping):
+ """
+ 确定在训练过程中到底要使用哪个 input_mapping 和 output_mapping,之所以要设置该函数是因为在有些时候 evaluate 所需要的 input_mapping 和
+ output_mapping 是与 train 的时候是不一样的,因此需要额外的定制;
+ """
+ if train_input_mapping is not None and input_mapping is not None:
+ raise ValueError("Parameter `input_mapping` and `train_input_mapping` cannot be set simultaneously.")
+
+ if evaluate_input_mapping is not None and input_mapping is not None:
+ raise ValueError("Parameter `input_mapping` and `evaluate_input_mapping` cannot be set simultaneously.")
+
+ if train_output_mapping is not None and output_mapping is not None:
+ raise ValueError("Parameter `output_mapping` and `train_output_mapping` cannot be set simultaneously.")
+
+ if evaluate_output_mapping is not None and output_mapping is not None:
+ raise ValueError("Parameter `output_mapping` and `evaluate_output_mapping` cannot be set simultaneously.")
+
+ if train_input_mapping is None:
+ train_input_mapping = input_mapping
+ if evaluate_input_mapping is None:
+ evaluate_input_mapping = input_mapping
+
+ if train_output_mapping is None:
+ train_output_mapping = output_mapping
+ if evaluate_output_mapping is None:
+ evaluate_output_mapping = output_mapping
+
+ return train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping
+
+
+
+
+
+
diff --git a/fastNLP/core/controllers/utils/__init__.py b/fastNLP/core/controllers/utils/__init__.py
new file mode 100644
index 00000000..22575b12
--- /dev/null
+++ b/fastNLP/core/controllers/utils/__init__.py
@@ -0,0 +1,6 @@
+__all__ = [
+ 'State',
+ 'TrainerState'
+]
+
+from .state import State, TrainerState
diff --git a/fastNLP/core/controllers/utils/state.py b/fastNLP/core/controllers/utils/state.py
new file mode 100644
index 00000000..98d0c601
--- /dev/null
+++ b/fastNLP/core/controllers/utils/state.py
@@ -0,0 +1,83 @@
+from dataclasses import dataclass
+from typing import Optional, Dict
+
+__all__ = [
+ 'State',
+ 'TrainerState'
+]
+
+
class State(dict):
    r"""
    A dict handed to users for recording live ``callback`` data, with
    attribute-style access layered on top through ``__getattr__``.

    To keep checkpoint-resume working, everything stored here should be
    serializable.

    Recommended usage::

        >>> state = State()
        >>> state["best_accuracy"] = 0.9
        >>> print(state["best_accuracy"])
        or
        >>> print(state.best_accuracy)
    """

    # Users must not write attributes directly (state.name = "name"); an empty
    # __slots__ forbids instance attributes while attribute *reads* still go
    # through __getattr__ into the dict.
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super(State, self).__init__(*args, **kwargs)

    def __getattr__(self, item):
        if item in self:
            _value = self[item]
            if isinstance(_value, dict):
                # Wrap nested dicts so chained attribute access keeps working.
                return State(_value)
            else:
                return _value
        # BUGFIX: raise AttributeError instead of ValueError — the data model
        # requires __getattr__ to raise AttributeError so that hasattr(),
        # getattr(obj, name, default), copy and pickle protocol probes behave.
        raise AttributeError(f"key '{item}' is not existed!")
+
@dataclass
class TrainerState:
    r"""
    fastNLP's internal record of training-flow state (also exposed to users);
    most of it is what checkpoint-resume needs to restore.  It belongs
    exclusively to `Trainer`.

    :param n_epochs: total number of epochs for the training run;
    :param cur_epoch_idx: index of the epoch currently running;
    :param global_forward_batches: total steps the model has forwarded so far;
    :param batch_idx_in_epoch: step index within the current epoch;
    :param num_batches_per_epoch: steps forwarded per epoch;
    :param n_batches: steps of a full training run; note
        ``n_batches = num_batches_per_epoch * n_epochs``;
    """
    n_epochs: Optional[int] = None  # always recomputed, never restored

    cur_epoch_idx: Optional[int] = 0  # restored on resume; 0 only when resume=False
    global_forward_batches: Optional[int] = 0  # restored on resume

    batch_idx_in_epoch: Optional[int] = 0  # restored on resume

    num_batches_per_epoch: Optional[int] = None  # always recomputed

    n_batches: Optional[int] = None  # always recomputed

    def state_dict(self) -> Dict:
        r"""
        :return: the state dict persisted for checkpoint-resume;
        """
        return {"cur_epoch_idx": self.cur_epoch_idx, "global_forward_batches": self.global_forward_batches,
                "batch_idx_in_epoch": self.batch_idx_in_epoch}

    def load_state_dict(self, state_dict: Dict):
        r"""
        Restore a state dict previously produced by :meth:`state_dict`.

        :param state_dict: the state dict to load;
        :raises KeyError: if the dict contains an unexpected key;
        """
        for key in state_dict:
            # Explicit raise instead of `assert`: input validation must survive
            # `python -O`, which strips assert statements.
            if key not in {"cur_epoch_idx", "global_forward_batches", "batch_idx_in_epoch"}:
                raise KeyError(f"Wrong state_dict for `TrainerState`, got unexpected key: {key!r}.")
            setattr(self, key, state_dict[key])
+
+
diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py
new file mode 100644
index 00000000..ca1e0b86
--- /dev/null
+++ b/fastNLP/core/controllers/utils/utils.py
@@ -0,0 +1,158 @@
+from typing import Dict
+
+from fastNLP.core.callbacks import CallbackManager
+from .state import TrainerState
+from fastNLP.core.utils.utils import _check_valid_parameters_number
+
+__all__ = []
+
+class TrainerEventTrigger:
+ r"""
+ 为了避免在训练流程中调用 callback 函数中写成类似 `'trainer.callback_manager.on_train_begin'` 的形式,我们选择单独为 ``Trainer``
+ 抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 :meth:`on_validate_end` 来通知所有的 ``CheckpointCallback`` 实例在当前的 step 后保存
+ 模型。
+ """
+ callback_manager: CallbackManager
+ trainer_state: TrainerState
+
+ def on_after_trainer_initialized(self, driver):
+ self.callback_manager.on_after_trainer_initialized(self, driver)
+
+ def on_sanity_check_begin(self):
+ self.callback_manager.on_sanity_check_begin(self)
+
+ def on_sanity_check_end(self, sanity_check_res):
+ self.callback_manager.on_sanity_check_end(self, sanity_check_res)
+
+ def on_train_begin(self):
+ self.callback_manager.on_train_begin(self)
+
+ def on_train_end(self):
+ self.callback_manager.on_train_end(self)
+
+ def on_train_epoch_begin(self):
+ self.callback_manager.on_train_epoch_begin(self)
+
+ def on_train_epoch_end(self):
+ self.callback_manager.on_train_epoch_end(self)
+
+ def on_fetch_data_begin(self):
+ self.callback_manager.on_fetch_data_begin(self)
+
+ def on_fetch_data_end(self):
+ self.callback_manager.on_fetch_data_end(self)
+
+ def on_train_batch_begin(self, batch, indices=None):
+ self.callback_manager.on_train_batch_begin(self, batch, indices)
+
+ def on_train_batch_end(self):
+ self.callback_manager.on_train_batch_end(self)
+
+ def on_exception(self, exception):
+ self.callback_manager.on_exception(self, exception)
+
+ def on_save_model(self):
+ self.callback_manager.on_save_model(self)
+
+ def on_load_model(self):
+ self.callback_manager.on_load_model(self)
+
+ def on_save_checkpoint(self) -> Dict:
+ return self.callback_manager.on_save_checkpoint(self)
+
+ def on_load_checkpoint(self, states):
+ self.callback_manager.on_load_checkpoint(self, states)
+
+ def on_before_backward(self, outputs):
+ self.callback_manager.on_before_backward(self, outputs)
+
+ def on_after_backward(self):
+ self.callback_manager.on_after_backward(self)
+
+ def on_before_optimizers_step(self, optimizers):
+ self.callback_manager.on_before_optimizers_step(self, optimizers)
+
+ def on_after_optimizers_step(self, optimizers):
+ self.callback_manager.on_after_optimizers_step(self, optimizers)
+
+ def on_before_zero_grad(self, optimizers):
+ self.callback_manager.on_before_zero_grad(self, optimizers)
+
+ def on_after_zero_grad(self, optimizers):
+ self.callback_manager.on_after_zero_grad(self, optimizers)
+
+ def on_evaluate_begin(self):
+ self.callback_manager.on_evaluate_begin(self)
+
+ def on_evaluate_end(self, results):
+ self.trainer_state.save_on_this_step = True
+ self.callback_manager.on_evaluate_end(self, results)
+
+
+class _TruncatedDataLoader:
+ r"""
+ ``_TruncatedDataLoader`` 用于实现 ``Trainer`` 和 ``Evaluator`` 中的 '预跑' 和 '假跑' 功能:
+
+ 1. 预跑 是针对 trainer 的验证而言的,即我们在正式的训练前会先使用 trainer 内置的 evaluator(如果不为 None)评测数量非常少的数据,
+ 来检验用户的 metric 和 evaluate_dataloader 以及模型是否能够合作完成正确的评测过程;
+ 2. 假跑 的意思是改变每一个 epoch 中训练或者评测的实际的 batch 的数量,例如改成 10,来让模型快速地迭代整体的训练或者评测流程,来查看
+ 整体的过程的正确性;
+
+ ``_TruncatedDataLoader`` 的实现也非常简单,我们在该类中内置一个计数器,当迭代器的迭代数量达到指定数值后 ``raise StopIteration``;
+
+ :param dataloader: 可迭代的 dataloader 。
+ :param num_batches: 迭代多少个 batch 就停止。
+ """
+ def __init__(self, dataloader, num_batches: int):
+
+ self.dataloader = dataloader
+ self._num_batches = min(num_batches, len(dataloader))
+ self._count = 0
+
+ def __len__(self):
+ r"""
+ 为了在外部调用 `len` 方法时正确地返回当前会迭代的长度;
+ """
+ return self._num_batches
+
+ def __iter__(self):
+ # 将初试的 `dataloader` 转换成一个 `Iterator` 的逻辑应该放在这里,即只有当外界真正的调用 iter(dataloader) 的时候才需要返回一个 Iterator;
+ # TODO 测试一下
+ self._iterator = iter(self.dataloader)
+ self._count = 0
+ return self
+
+ def __next__(self):
+ if self._count >= self._num_batches:
+ raise StopIteration
+ self._count += 1
+ # 注意 dataloader 数据不足时会自己本身触发 `StopIteration`;
+ return next(self._iterator)
+
+ def __getattr__(self, item):
+ return getattr(self.dataloader, item)
+
+ def __setattr__(self, key, value):
+ # 添加该函数使得在进行实验性训练或者评测时,用户对于 trainer.dataloader 的感觉和正常训练完全一样;
+ # 注意这里不能直接 ``setattr(self.dataloader, key, value)``,会导致死循环;
+ if "dataloader" in self.__dict__:
+ if hasattr(self.dataloader, key):
+ setattr(self.dataloader, key, value)
+ else:
+ self.__dict__[key] = value
+ else:
+ self.__dict__[key] = value
+
+
+def check_evaluate_every(evaluate_every):
+ r"""
+ 检验用户传入的 ``evaluate_every`` 参数是否合法;
+
+ ``evaluate_every`` 的使用详见 ``Trainer`` 的 ``evaluate_every`` 参数;
+
+ 主要在于当参数 ``evaluate_every`` 是一个 Callable 的函数时,需要保证其参数的正确性;
+ """
+ if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0):
+        raise ValueError("Parameter 'evaluate_every' should be a Callable function or a non-zero 'int'.")
+ if callable(evaluate_every):
+ _check_valid_parameters_number(evaluate_every, expected_params=['trainer'])
diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py
new file mode 100644
index 00000000..06d3f5a8
--- /dev/null
+++ b/fastNLP/core/dataloaders/__init__.py
@@ -0,0 +1,22 @@
+__all__ = [
+ 'MixDataLoader',
+ 'TorchDataLoader',
+ 'PaddleDataLoader',
+ 'JittorDataLoader',
+ 'OneflowDataLoader',
+ 'prepare_jittor_dataloader',
+ 'prepare_paddle_dataloader',
+ 'prepare_torch_dataloader',
+ 'prepare_oneflow_dataloader',
+
+ "prepare_dataloader",
+
+ "OverfitDataLoader"
+]
+
+from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
+from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader
+from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
+from .oneflow_dataloader import OneflowDataLoader, prepare_oneflow_dataloader
+from .prepare_dataloader import prepare_dataloader
+from .utils import OverfitDataLoader
\ No newline at end of file
diff --git a/fastNLP/core/dataloaders/jittor_dataloader/__init__.py b/fastNLP/core/dataloaders/jittor_dataloader/__init__.py
new file mode 100644
index 00000000..8aba7614
--- /dev/null
+++ b/fastNLP/core/dataloaders/jittor_dataloader/__init__.py
@@ -0,0 +1,7 @@
+__all__ = [
+ "JittorDataLoader",
+ 'prepare_jittor_dataloader'
+
+]
+
+from .fdl import JittorDataLoader, prepare_jittor_dataloader
diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
new file mode 100644
index 00000000..0e0cb443
--- /dev/null
+++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
@@ -0,0 +1,305 @@
+__all__ = [
+ 'JittorDataLoader',
+ 'prepare_jittor_dataloader'
+]
+
+from typing import Callable, Optional, List, Union, Dict, Sequence
+from copy import deepcopy
+
+import numpy as np
+
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+
+if _NEED_IMPORT_JITTOR:
+ from jittor.dataset.utils import collate_batch
+ from jittor.dataset import Dataset
+else:
+ from fastNLP.core.dataset import DataSet as Dataset
+
+from fastNLP.core.collators import Collator
+from fastNLP.core.dataloaders.utils import indice_collate_wrapper
+from fastNLP.core.dataset import DataSet as FDataSet
+from ..utils import HasLenGetitemType
+
+
+class _JittorDataset(Dataset):
+ """
+ 对用户传的 ``dataset`` 进行封装,以便 ``JittorDataLoader`` 能够支持使用自定义的 ``dataset`` 。
+ """
+
+ def __init__(self, dataset) -> None:
+ super(_JittorDataset, self).__init__()
+ self.dataset = dataset
+ self.total_len = len(dataset)
+
+ def __getitem__(self, item):
+ if isinstance(item, np.integer):
+ item = item.tolist()
+ return (item, self.dataset[item])
+
+ def __getstate__(self):
+ return self.__dict__
+
+ def __setstate__(self, state):
+ self.__dict__ = state
+
+
+class JittorDataLoader:
+ """
+ 提供给 ``jittor`` 框架使用的 ``DataLoader`` 函数,``JittorDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad,
+ 若是可 pad 的 field 则自动 pad 到相同长度,否则只会将相同 field 的数据收集组成一个 batch 返回。
+    具体详见 :class:`~fastNLP.core.collators.Collator`;用户通过 collate_fn 来控制是否使用该功能, collate_fn 只能为 ``['auto', None, Callable]`` 三种取值。
+
+ * callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。
+ 此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
+ * callate_fn 为 ``None`` 时, ``JittorDataLoader`` 默认使用 Jittor DataLoader 自带的 collate_fn
+ * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
+ dataset 的一条数据;该 Callable 函数还应当返回一个对象。
+
+ :param dataset: 实现了 __getitem__() 和 __len__() 的对象。
+ :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
+ :param shuffle: 是否打乱数据集, 默认为 ``False``。
+ :param drop_last: 当 ``drop_last=True`` 时,``JittorDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
+ 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
+ :param num_workers: 当 ``num_workers > 0`` 时, ``JittorDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快
+ 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
+ :param buffer_size: 每个进程占用的内存空间,默认为 512M。主要是配合 ``num_workers`` 使用,用户可以自定义每个进程的内存大小。
+ :param stop_grad: 是否不使用梯度, 默认 ``True`` 。
+ :param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False``
+ :param endless: 是否让 ``JittorDataLoader`` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``.
+ :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
+
+ * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
+ ``JittorDataLoader`` 调用默认的 Jittor 框架的 ``DataLoader`` 自带的 ``collate_batch`` 作为 callate_fn 的默认值, 其无法处理
+ :class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。
+ * callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
+ 此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
+ * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
+ dataset 的一条数据;该 Callable 函数还应当返回一个对象。
+ """
+ def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
+ drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
+ stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
+ collate_fn: Union[None, str, Callable] = "auto") -> None:
+
+ # TODO 验证支持replacesampler (以后完成) 增加Sampler
+ # 将内部dataset批次设置为1
+ if isinstance(dataset, Dataset):
+ dataset.set_attrs(batch_size=1, shuffle=False, endless=False)
+
+ # FastNLP Datset, collate_fn not None
+ if isinstance(dataset, FDataSet) and collate_fn is None:
+            raise ValueError("When using a fastNLP DataSet, collate_fn must not be None")
+
+ # 将所有dataset转为jittor类型的dataset
+        self.dataset = dataset if isinstance(dataset, _JittorDataset) \
+            else _JittorDataset(dataset)
+
+ if isinstance(collate_fn, str):
+ if collate_fn == "auto":
+ if isinstance(self.dataset.dataset, FDataSet):
+ self.collate_fn = deepcopy(self.dataset.dataset.collator)
+ # jittor 比较特殊,只需要保证返回 numpy.array, 其Dataloader会转为jt.var
+ self.collate_fn.set_backend(backend="numpy")
+ else:
+ # jittor 比较特殊,只需要保证返回 numpy.array, 其Dataloader会转为jt.var
+ self.collate_fn = Collator(backend="numpy")
+ else:
+ raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
+ elif isinstance(collate_fn, Callable):
+ self.collate_fn = collate_fn
+ else:
+ self.collate_fn = collate_batch
+
+ self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
+ num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad,
+ keep_numpy_array=keep_numpy_array, endless=endless)
+
+ self.cur_batch_indices = None
+
+ def __getattr__(self, attr):
+ if attr in ["batch_size", "shuffle", "drop_last", "num_workers", "buffer_size", "stop_grad",
+ "keep_numpy_array", "endless", "sampler"]:
+ return getattr(self.dataset, attr)
+        raise AttributeError(f"{self} has no attribute '{attr}'")
+
+ def __iter__(self):
+ # TODO 第一次迭代后不能设置collate_fn,设置是无效的
+ if self.cur_batch_indices is None:
+ self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn))
+ for indices, data in self.dataset.__iter__():
+ self.cur_batch_indices = indices
+ yield data
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+ pad_fn: Callable = None) -> Collator:
+ """
+ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
+
+ :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的
+ field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``;
+ 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。
+ 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 ``"_single"`` 。
+ :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
+ field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``,
+ 该值无意义。
+ :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。
+ :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`,
+ :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。
+ 若 ``pad_val`` 为 ``None`` ,该值无意义 。
+ :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的
+ batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。
+ :return: 使用的 collator
+ """
+ collator = self._get_collator()
+ if isinstance(collator, Collator):
+ collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+ return collator
+ else:
+ raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
+
+ def _get_collator(self):
+ """
+ 如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None
+
+ :return:
+ """
+ collator = None
+ if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
+ collator = self.collate_fn.__wrapped__
+ elif isinstance(self.collate_fn, Collator):
+ collator = self.collate_fn
+ return collator
+
+ def set_ignore(self, *field_names) -> Collator:
+ """
+ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略::
+
+ dataloader.set_ignore('field1', 'field2')
+
+ :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的
+ field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``;
+ 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。
+ :return: 使用的 collator
+ """
+ collator = self._get_collator()
+ if isinstance(collator, Collator):
+ collator.set_ignore(*field_names)
+ return collator
+ else:
+ raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
+
+ def get_batch_indices(self) -> List[int]:
+ """
+ 获取当前 ``batch`` 中每条数据对应的索引。
+
+ :return: 当前 ``batch`` 数据的索引;
+ """
+ return self.cur_batch_indices
+
+
+def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = None,
+ drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
+ stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
+ collate_fn: Union[None, str, Callable] = "auto",
+ non_train_batch_size: int = None) \
+ -> Union[Dict[str, JittorDataLoader], JittorDataLoader]:
+ """
+ ``prepare_jittor_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 :class:`JittorDataLoader` 对象, 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。
+ 根据 ``ds_or_db`` 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下:
+
+ * 当 ds_or_db 为 :class:`~fastNLP.io.DataSet` 时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来
+ 帮你实例化一个 :class:`JittorDataLoader` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`;
+ * 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_jittor_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value
+ 来创建不同的 :class:`JittorDataLoader` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_jittor_dataloader`` 默认该 value 为训练数据集,
+ 会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。
+ 最终根据 ``key: JittorDataLoader`` 组成 ``Dict[key, JittorDataLoader]`` 的字典返回;
+ * 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_jittor_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的
+ :class:`JittorDataLoader` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_jittor_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和
+ ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: JittorDataLoader`` 组成
+ ``Dict[key, JittorDataLoader]`` 的字典返回;
+
+ :param ds_or_db: 可以有以下三种取值:
+
+ * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典;
+ * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典;
+ * ds_or_db 为实现了 :meth:`__getitem__` 和 :meth:`__len__` 的对象 ,返回值为 :class:`~fastNLP.core.dataloaders.JittorDataLoader`;
+
+    :param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过该参数
+ 设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。
+ :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
+ :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 True ,
+ 其它的为 False 。
+ :param drop_last: 当 ``drop_last=True`` 时,:class:`JittorDataLoader` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
+ 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
+ :param num_workers: 当 ``num_workers > 0`` 时, :class:`JittorDataLoader` 会开启 num_workers 个子进程来处理数据, 可以加快
+ 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
+ :param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合 ``num_workers`` 使用,用户可以自定义每个进程的内存大小。
+ :param stop_grad: 是否不使用梯度, 默认 ``True`` 。
+    :param keep_numpy_array: 返回的数据是 :class:`np.array` 类型而不是 :class:`jittor.Var` 类型,默认为 ``False``
+ :param endless: 是否让 :class:`JittorDataLoader` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``.
+ :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
+
+ * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
+ :class:`JittorDataLoader` 调用默认的 Jittor 框架的 ``DataLoader`` 自带的 ``collate_batch`` 作为 callate_fn 的默认值, 其无法处理
+ :class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。
+ * callate_fn 为 ``'auto'`` 时,:class:`JittorDataLoader` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
+ 此时可以配套使用 :class:`JittorDataLoader` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
+ * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
+ dataset 的一条数据;该 Callable 函数还应当返回一个对象。
+ """
+ from fastNLP.io.data_bundle import DataBundle
+
+ if isinstance(ds_or_db, DataBundle):
+ dl_bundle = {}
+ for name, ds in ds_or_db.iter_datasets():
+ if 'train' in name:
+ dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle,
+ drop_last=drop_last, num_workers=num_workers,
+ buffer_size=buffer_size,
+ stop_grad=stop_grad, keep_numpy_array=keep_numpy_array,
+ endless=endless,
+ collate_fn=collate_fn)
+ else:
+ dl_bundle[name] = JittorDataLoader(ds,
+ batch_size=non_train_batch_size if non_train_batch_size else batch_size,
+ shuffle=False if shuffle is None else shuffle,
+ drop_last=drop_last, num_workers=num_workers,
+ buffer_size=buffer_size,
+ stop_grad=stop_grad, keep_numpy_array=keep_numpy_array,
+ endless=endless,
+ collate_fn=collate_fn)
+ return dl_bundle
+
+ elif isinstance(ds_or_db, Dict):
+ ds_dict = {}
+ for name, ds in ds_or_db.items():
+ if 'train' in name:
+ dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=True if shuffle is None else shuffle,
+ drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
+ stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
+ collate_fn=collate_fn)
+ else:
+ dl = JittorDataLoader(ds,
+ batch_size=non_train_batch_size if non_train_batch_size else batch_size,
+ shuffle=False if shuffle is None else shuffle,
+ drop_last=drop_last, num_workers=num_workers,
+ buffer_size=buffer_size,
+ stop_grad=stop_grad, keep_numpy_array=keep_numpy_array,
+ endless=endless,
+ collate_fn=collate_fn)
+ ds_dict[name] = dl
+ return ds_dict
+
+ elif isinstance(ds_or_db, HasLenGetitemType):
+ dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=False if shuffle is None else shuffle,
+ drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
+ stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
+ collate_fn=collate_fn)
+ return dl
+
+ else:
+ raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")
diff --git a/fastNLP/core/dataloaders/oneflow_dataloader/__init__.py b/fastNLP/core/dataloaders/oneflow_dataloader/__init__.py
new file mode 100644
index 00000000..d17ce91c
--- /dev/null
+++ b/fastNLP/core/dataloaders/oneflow_dataloader/__init__.py
@@ -0,0 +1,6 @@
+__all__ = [
+ "OneflowDataLoader",
+ "prepare_oneflow_dataloader",
+]
+
+from .fdl import OneflowDataLoader, prepare_oneflow_dataloader
diff --git a/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py b/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py
new file mode 100644
index 00000000..e2882d75
--- /dev/null
+++ b/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py
@@ -0,0 +1,363 @@
+__all__ = [
+ 'OneflowDataLoader',
+ 'prepare_oneflow_dataloader'
+]
+
+from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List, Any
+from abc import ABC
+from copy import deepcopy
+
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.collators import Collator
+from fastNLP.core.dataloaders.utils import indice_collate_wrapper
+from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
+from ..utils import _match_param
+from ..utils import HasLenGetitemType
+
+if _NEED_IMPORT_ONEFLOW:
+ from oneflow.utils.data import DataLoader, Sampler, Dataset
+else:
+ from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
+
+
+class _FDataSet:
+ """
+ 提供给 ``OneflowDataLoader`` 使用的 warp 类,其功能是对 dataset 进行封装,wrap 修改 dataset 的 __getitem__ 函数,增加返回
+ 数据的下标 idx 。
+
+    .. note::
+
+ 需要注意的是传入 ``__init__`` 的 dataset 需要实现 __getattribute__ 方法才能在 _FDataset 实例化对象中调用 dataset 的方法
+
+ """
+
+ def __init__(self, dataset) -> None:
+ self.dataset = dataset
+
+ def __getitem__(self, item: Union[int, list]) -> Tuple:
+ return (item, self.dataset[item])
+
+ def __getattr__(self, item):
+ try:
+ return self.dataset.__getattribute__(item)
+ except AttributeError as e:
+ raise e
+
+ def __len__(self) -> int:
+ return len(self.dataset)
+
+ def __getstate__(self):
+ return self.__dict__
+
+ def __setstate__(self, state):
+ self.__dict__ = state
+
+
+class OneflowDataLoader(DataLoader):
+ """
+ 提供给 ``oneflow`` 框架使用的 ``DataLoader`` 函数,``OneflowDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad,
+ 若是可 pad 的 field 则自动 pad 到相同长度,否则只会将相同 field 的数据收集组成一个 batch 返回。
+    具体详见 :class:`~fastNLP.core.collators.Collator`;用户通过 collate_fn 来控制是否使用该功能, collate_fn 只能为 ``['auto', None, Callable]``
+ 三种取值。
+
+ * callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。
+ 此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
+    * callate_fn 为 ``None`` 时, ``OneflowDataLoader`` 默认使用 :class:`oneflow.utils.data.DataLoader` 自带的 collate_fn
+ * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
+ dataset 的一条数据;该 Callable 函数还应当返回一个对象。
+
+ :param dataset: 实现了 __getitem__() 和 __len__() 的对象。
+ :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
+ :param non_train_batch_size: 非训练数据集的 ``OneflowDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。
+ :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` ,
+ 其它的为 False 。
+ :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
+ 默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。
+ :param non_train_sampler: 非训练数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
+ 默认为None, 当其不为 None 时, shuffle 参数无效。
+ :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为
+ dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``bacth_size``, ``sampler``, ``shuffle`` 参数均失效。
+ :param num_workers: 当 ``num_workers > 0`` 时, ``OneflowDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快
+ 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
+ :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
+
+ * callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
+ ``OneflowDataLoader`` 调用默认的 oneflow 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理
+ :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。
+ * callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
+ 此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
+ * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
+ dataset 的一条数据;该 Callable 函数还应当返回一个对象。
+
+ :param pin_memory: 如果其为 ``True``, 那么 ``OneflowDataLoader`` 会在返回数据张量之前将其 copy 到 cuda 的 pin memory 中。
+ :param drop_last: 当 ``drop_last=True`` 时,``OneflowDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
+ 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
+ :param timeout: 子进程的输出队列获取数据的超时值
+ :param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。
+ :param multiprocessing_context: 多进程的上下文环境
+ :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed``
+ :param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` .
+ :param persistent_workers: 如果其为 ``True``, ``OneflowDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False``
+ """
+
+ def __init__(self, dataset, batch_size: int = 16,
+ shuffle: bool = False, sampler = None, batch_sampler = None,
+ num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
+ pin_memory: bool = False, drop_last: bool = False,
+ timeout: float = 0, worker_init_fn: Optional[Callable] = None,
+ multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
+ persistent_workers: bool = False, **kwargs) -> None:
+
+ if isinstance(dataset, DataSet) and collate_fn is None:
+            raise ValueError("When using a fastNLP DataSet, collate_fn must not be None")
+
+ if not isinstance(dataset, _FDataSet):
+ dataset = _FDataSet(dataset)
+
+ if num_workers>0 and multiprocessing_context is None:
+ multiprocessing_context = 'fork' # 这里默认使用fork的方式来启动多进程
+
+ if batch_sampler is not None:
+ batch_size = 1
+ shuffle = False
+ sampler = None
+ elif sampler is None:
+ sampler = RandomSampler(dataset, shuffle=shuffle)
+ shuffle = False
+
+ if isinstance(collate_fn, str):
+ if collate_fn == 'auto':
+ if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset
+ collate_fn = deepcopy(dataset.dataset.collator)
+ collate_fn.set_backend(backend="oneflow")
+ else:
+ collate_fn = Collator(backend="oneflow")
+ else:
+ raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
+
+ dl_kwargs = _match_param(OneflowDataLoader.__init__, DataLoader.__init__, fn_name=DataLoader.__name__)
+ if dl_kwargs is None:
+ super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
+ batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
+ pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+ multiprocessing_context=multiprocessing_context, generator=generator,
+ prefetch_factor=prefetch_factor,
+ persistent_workers=persistent_workers)
+ else:
+ super().__init__(**dl_kwargs)
+
+ self.cur_batch_indices = None
+
+ def __iter__(self):
+ self.collate_fn = indice_collate_wrapper(self.collate_fn)
+ for indices, data in super().__iter__():
+ self.cur_batch_indices = indices
+ yield data
+
+ def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+ pad_fn: Callable = None) -> Collator:
+ """
+ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
+
+ :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的
+ field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``;
+ 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。
+ 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。
+ :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
+ field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``,
+ 该值无意义。
+ :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。
+ :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`,
+ :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。
+ 若 ``pad_val`` 为 ``None`` ,该值无意义 。
+ :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的
+ batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。
+ :return: 使用的 collator
+ """
+ collator = self._get_collator()
+ if isinstance(collator, Collator):
+ collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+ return collator
+ else:
+ raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
+
+ def _get_collator(self):
+ """
+ 如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None
+
+ :return:
+ """
+ collator = None
+ if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
+ collator = self.collate_fn.__wrapped__
+ elif isinstance(self.collate_fn, Collator):
+ collator = self.collate_fn
+ return collator
+
+ def set_ignore(self, *field_names) -> Collator:
+ """
+ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略::
+
+ dataloader.set_ignore('field1', 'field2')
+
+ :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的
+ field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``;
+ 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。
+ :return: 使用的 collator
+ """
+ collator = self._get_collator()
+ if isinstance(collator, Collator):
+ collator.set_ignore(*field_names)
+ return collator
+ else:
+ raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
+
+ def get_batch_indices(self) -> List[int]:
+ """
+ 获取当前 ``batch`` 中每条数据对应的索引。
+
+ :return: 当前 ``batch`` 数据的索引;
+ """
+ return self.cur_batch_indices
+
+
def prepare_oneflow_dataloader(ds_or_db,
                               batch_size: int = 16,
                               shuffle: bool = None,
                               sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                               batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                               num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
                               pin_memory: bool = False, drop_last: bool = False,
                               timeout: float = 0, worker_init_fn: Optional[Callable] = None,
                               multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
                               persistent_workers: bool = False,
                               non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                               non_train_batch_size: int = None) \
        -> Union[OneflowDataLoader, Dict[str, OneflowDataLoader]]:
    """
    Convert one dataset, or several datasets at once, into ``OneflowDataLoader`` objects
    (see :class:`~fastNLP.core.dataloaders.OneflowDataLoader`). The return type depends on
    the type of ``ds_or_db``:

    * a single dataset (any object implementing ``__getitem__`` and ``__len__``) — returns
      one ``OneflowDataLoader`` built from ``batch_size`` / ``sampler``; the ``non_train_*``
      arguments are ignored;
    * a :class:`~fastNLP.io.DataBundle` or a ``Dict[str, DataSet]`` — returns a dict
      ``{name: OneflowDataLoader}``. Splits whose name contains ``'train'`` use
      ``batch_size`` / ``sampler`` and default to ``shuffle=True``; all other splits use
      ``non_train_batch_size`` / ``non_train_sampler`` (falling back to the train values)
      and default to ``shuffle=False``.

    :param ds_or_db: a :class:`~fastNLP.io.DataBundle`, a ``Dict[str, DataSet]``, or a
        single object implementing ``__getitem__`` and ``__len__``.
    :param batch_size: batch size for training splits, default ``16``; only effective when
        ``batch_sampler`` is ``None``.
    :param shuffle: whether to shuffle. When ``None`` (default), training splits are
        shuffled and the others are not.
    :param sampler: object implementing ``__len__`` and ``__iter__``, yielding one dataset
        index at a time; when not ``None``, ``shuffle`` is ignored.
    :param batch_sampler: object implementing ``__len__`` and ``__iter__``, yielding a list
        of dataset indices per batch; when not ``None``, ``batch_size``, ``sampler`` and
        ``shuffle`` are all ignored.
    :param num_workers: number of worker subprocesses used to fetch data; ``0`` (default)
        keeps everything in the main process.
    :param collate_fn: one of ``[None, 'auto', Callable]``:

        * ``None`` — use oneflow's built-in ``default_collate_fn`` (cannot handle a fastNLP
          :class:`~fastNLP.core.dataset.DataSet`);
        * ``'auto'`` — use a fastNLP :class:`~fastNLP.core.collators.Collator`, configurable
          through the dataloader's ``set_pad`` / ``set_ignore`` methods;
        * a :class:`Callable` — receives the list of samples of one batch and returns the
          collated batch object.

    :param pin_memory: copy tensors into CUDA pinned memory before returning them.
    :param drop_last: drop the last batch if it has fewer than ``batch_size`` samples.
    :param timeout: timeout for collecting a batch from the worker output queue.
    :param worker_init_fn: if not ``None``, called in every worker subprocess at startup.
    :param multiprocessing_context: multiprocessing context used by the workers.
    :param generator: if not ``None``, used by the RandomSampler to draw indices and to
        derive a ``base_seed`` for each worker.
    :param prefetch_factor: samples prefetched per worker (``2`` means ``2 * num_workers``
        samples are prefetched overall). Default ``2``.
    :param persistent_workers: keep worker processes alive after a full pass over the
        dataset. Default ``False``.
    :param non_train_sampler: sampler used for non-training splits; falls back to
        ``sampler`` when ``None``.
    :param non_train_batch_size: batch size for non-training splits; falls back to
        ``batch_size`` when ``None``.
    :return: a single ``OneflowDataLoader`` or a ``Dict[str, OneflowDataLoader]``.
    """

    from fastNLP.io import DataBundle

    def _make_dl(ds, is_train: bool) -> OneflowDataLoader:
        # Training splits default to shuffle=True and use batch_size / sampler;
        # other splits default to shuffle=False and prefer the non_train_* variants.
        if is_train:
            cur_batch_size = batch_size
            cur_sampler = sampler
            cur_shuffle = True if shuffle is None else shuffle
        else:
            cur_batch_size = non_train_batch_size if non_train_batch_size else batch_size
            cur_sampler = non_train_sampler if non_train_sampler else sampler
            cur_shuffle = False if shuffle is None else shuffle
        return OneflowDataLoader(dataset=ds, batch_size=cur_batch_size, shuffle=cur_shuffle,
                                 sampler=cur_sampler, batch_sampler=batch_sampler,
                                 num_workers=num_workers, collate_fn=collate_fn,
                                 pin_memory=pin_memory, drop_last=drop_last, timeout=timeout,
                                 worker_init_fn=worker_init_fn,
                                 multiprocessing_context=multiprocessing_context,
                                 generator=generator, prefetch_factor=prefetch_factor,
                                 persistent_workers=persistent_workers)

    if isinstance(ds_or_db, DataBundle):
        assert sampler is None, "sampler can only be None when multiple datasets are presented."
        assert batch_sampler is None, "batch_sampler can only be None when multiple datasets are presented."
        return {name: _make_dl(ds, 'train' in name) for name, ds in ds_or_db.iter_datasets()}

    elif isinstance(ds_or_db, Mapping):
        assert sampler is None, "sampler can only be None when multiple datasets are presented."
        assert batch_sampler is None, "batch_sampler can only be None when multiple datasets are presented."
        return {name: _make_dl(ds, 'train' in name) for name, ds in ds_or_db.items()}

    elif isinstance(ds_or_db, HasLenGetitemType):
        # Single dataset: use the train-side batch_size / sampler, but do not shuffle
        # by default since we cannot tell whether it is a training split.
        return OneflowDataLoader(dataset=ds_or_db, batch_size=batch_size,
                                 shuffle=False if shuffle is None else shuffle,
                                 sampler=sampler, batch_sampler=batch_sampler,
                                 num_workers=num_workers, collate_fn=collate_fn,
                                 pin_memory=pin_memory, drop_last=drop_last, timeout=timeout,
                                 worker_init_fn=worker_init_fn,
                                 multiprocessing_context=multiprocessing_context,
                                 generator=generator, prefetch_factor=prefetch_factor,
                                 persistent_workers=persistent_workers)

    else:
        raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")
diff --git a/fastNLP/core/dataloaders/paddle_dataloader/__init__.py b/fastNLP/core/dataloaders/paddle_dataloader/__init__.py
new file mode 100644
index 00000000..a5ae3a68
--- /dev/null
+++ b/fastNLP/core/dataloaders/paddle_dataloader/__init__.py
@@ -0,0 +1,6 @@
+__all__ = [
+ 'PaddleDataLoader',
+ 'prepare_paddle_dataloader',
+]
+
+from .fdl import PaddleDataLoader, prepare_paddle_dataloader
\ No newline at end of file
diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
new file mode 100644
index 00000000..12f00534
--- /dev/null
+++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
@@ -0,0 +1,391 @@
+__all__ = [
+ 'PaddleDataLoader',
+ 'prepare_paddle_dataloader'
+]
+
+from typing import Callable, List, Optional, Union, Dict, Sequence
+from copy import deepcopy
+
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+
+if _NEED_IMPORT_PADDLE:
+ from paddle.io import DataLoader, Dataset, Sampler
+else:
+ from fastNLP.core.utils.dummy_class import DummyClass as Dataset
+ from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
+ from fastNLP.core.utils.dummy_class import DummyClass as Sampler
+
+from fastNLP.core.collators.collator import Collator
+from fastNLP.core.dataloaders.utils import indice_collate_wrapper
+from fastNLP.core.dataset import DataSet as FDataSet
+from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler
+from ..utils import _match_param, HasLenGetitemType
+
+
class _PaddleDataset(Dataset):
    """
    Wrapper around a user-supplied dataset so that ``PaddleDataLoader`` can work with any
    custom dataset. ``__getitem__`` returns ``(index, sample)`` so the dataloader can track
    which dataset indices make up each batch.
    """

    def __init__(self, dataset) -> None:
        super(_PaddleDataset, self).__init__()
        self.dataset = dataset

    def __getitem__(self, item):
        # Pair every sample with its index; unpacked again in PaddleDataLoader.__iter__.
        return (item, self.dataset[item])

    # NOTE(review): the original defined __len__ twice (identical bodies); the duplicate
    # silently shadowed the first definition and has been removed.
    def __len__(self) -> int:
        return len(self.dataset)

    def __getattr__(self, item):
        # Delegate unknown attributes to the wrapped dataset. __getattr__ is only invoked
        # after normal lookup fails, so accessing self.dataset here cannot recurse.
        return self.dataset.__getattribute__(item)

    # __getstate__/__setstate__ are defined explicitly because of pickling (e.g. workers
    # started with "spawn"): without them, pickle would call __setstate__ before
    # self.dataset exists and __getattr__ would recurse forever.
    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__ = state
+
+
class PaddleDataLoader(DataLoader):
    """
    ``DataLoader`` tailored for the ``paddle`` framework. It integrates fastNLP's
    :class:`~fastNLP.core.collators.Collator` and wraps paddle's ``DataLoader`` so that:

    1. The input dataset does not need to belong to any framework — any object implementing
       ``__getitem__`` and ``__len__`` works. With ``collate_fn='auto'`` the data types are
       detected automatically and paddable fields are padded; ``set_pad`` and ``set_ignore``
       can then be used to tune the pad value of a field or skip it altogether.

       Example::

           from fastNLP import PaddleDataLoader
           class MyDataset:
               def __init__(self, data_lens=100):
                   self.data_lens = 100
               def __getitem__(self, item):
                   if item % 2 == 0:
                       return {'x': [101, 256, 453], 'y': 0}
                   else:
                       return {'x': [101, 200], 'y': 1}
               def __len__(self):
                   return self.data_lens
           dataset = MyDataset()
           paddle_dl = PaddleDataLoader(dataset, collate_fn='auto')
           for batch in paddle_dl:
               ...

    2. When ``collate_fn`` is ``None``, paddle's built-in ``default_collate_fn`` is used.

       .. note::
           When the dataset is a fastNLP ``DataSet``, ``collate_fn`` must not be ``None``;
           use ``"auto"`` or a custom callable instead.

    3. When ``collate_fn`` is a :class:`Callable`, it receives a list containing the
       samples of one batch and must return the collated batch object.

    :param dataset: object implementing ``__getitem__`` and ``__len__``.
    :param feed_list: feed Tensor list, creatable via ``paddle.static.data``. Must be set
        when ``return_list`` is ``False``. Default ``None``.
    :param places: list of places to put the data on. ``None`` (default) means CPUPlace or
        CUDAPlace(0). Strings such as ``'cpu'``, ``'gpu:x'`` or ``'gpu_pinned'`` (``x`` a
        gpu index) are also accepted.
    :param return_list: whether each device returns its data as a list. If ``False``, the
        return value is a ``str -> Tensor`` dict keyed by the fed tensors' names; if
        ``True``, a ``list(Tensor)``. ``True`` is only valid in dynamic-graph mode and is
        the default.
    :param batch_sampler: object implementing ``__len__`` and ``__iter__``, yielding a list
        of dataset indices per batch; when not ``None``, ``batch_size`` and ``shuffle`` are
        ignored.
    :param batch_size: batch size, default ``16``; effective only when ``batch_sampler`` is
        ``None``.
    :param shuffle: whether to shuffle the dataset. Default ``False``.
    :param drop_last: drop the last batch if it has fewer than ``batch_size`` samples.
        Default ``False``.
    :param collate_fn: one of ``[None, 'auto', Callable]``:

        * ``None`` — use paddle's ``default_collate_fn``; incompatible with a fastNLP
          :class:`~fastNLP.core.dataset.DataSet`;
        * ``'auto'`` — use a fastNLP :class:`~fastNLP.core.collators.Collator`
          (configurable through ``set_pad`` / ``set_ignore``);
        * a :class:`Callable` — receives the list of samples of one batch and returns the
          collated batch object.

    :param num_workers: number of worker subprocesses used to fetch data; ``0`` (default)
        keeps everything in the main process.
    :param use_buffer_reader: if ``True`` (default), the next batch is prefetched
        asynchronously — faster, but uses more (device) memory.
    :param use_shared_memory: whether to use shared memory to speed up inter-process data
        transfer; recommended only when shared memory is plentiful (e.g. ``/dev/shm`` on
        Linux). Only effective when ``num_workers > 0``.
    :param timeout: timeout for collecting a batch from the worker output queue.
    :param worker_init_fn: if not ``None``, called in every worker subprocess at startup.
    :param persistent_workers: keep worker processes alive after a full pass over the
        dataset. Default ``False``.
    """

    def __init__(self, dataset, feed_list=None, places=None,
                 return_list: bool = True, batch_sampler=None,
                 batch_size: int = 16, shuffle: bool = False,
                 drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
                 num_workers: int = 0, use_buffer_reader: bool = True,
                 use_shared_memory: bool = True, timeout: int = 0,
                 worker_init_fn: Callable = None, persistent_workers=False) -> None:

        # A fastNLP DataSet cannot be collated by paddle's default_collate_fn.
        if isinstance(dataset, FDataSet) and collate_fn is None:
            raise ValueError("When use FastNLP DataSet, collate_fn must be not None")

        if not isinstance(dataset, _PaddleDataset):
            dataset = _PaddleDataset(dataset)

        if batch_sampler is None:
            batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle,
                                               drop_last=drop_last)
            # The batch_sampler handed to DataLoader is now never None, so reset these
            # to their defaults to keep paddle from rejecting conflicting arguments.
            batch_size = 1
            shuffle = False
            drop_last = False

        if isinstance(collate_fn, str):
            if collate_fn == 'auto':
                # Reuse (a copy of) the DataSet's own collator when available so that
                # pad/ignore settings made on the DataSet carry over.
                if isinstance(dataset.dataset, FDataSet):
                    collate_fn = deepcopy(dataset.dataset.collator)
                    collate_fn.set_backend(backend="paddle")
                else:
                    collate_fn = Collator(backend="paddle")

            else:
                raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")

        # _match_param keeps only the kwargs the installed paddle DataLoader accepts,
        # so we stay compatible across paddle versions.
        dl_kwargs = _match_param(PaddleDataLoader.__init__, DataLoader.__init__, DataLoader.__name__)
        if dl_kwargs is None:
            super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
                                                   return_list=return_list, batch_sampler=batch_sampler,
                                                   batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                                                   collate_fn=collate_fn, num_workers=num_workers,
                                                   use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
                                                   timeout=timeout, worker_init_fn=worker_init_fn,
                                                   persistent_workers=persistent_workers)
        else:
            super().__init__(**dl_kwargs)
        # Indices of the samples in the most recent batch; see get_batch_indices().
        self.cur_batch_indices = None

    def __iter__(self):
        # Wrap collate_fn so each batch also yields the dataset indices of its samples
        # (the wrapped dataset returns (index, sample) pairs).
        self.collate_fn = indice_collate_wrapper(self.collate_fn)
        for indices, data in super().__iter__():
            self.cur_batch_indices = indices
            yield data

    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
                pad_fn: Callable = None) -> Collator:
        """
        Adjust how a specific field is padded.

        :param field_name: name of the field to adjust. If :meth:`Dataset.__getitem__`
            returns a dict, use the corresponding key; for nested dicts a tuple of keys can
            be used, e.g. ``('a', 'b')`` for ``{'a': {'b': 1}}``. If it returns a Sequence,
            ``'_0'``, ``'_1'``, ... refer to the element at that position. An error is
            raised if the field is not found; if :meth:`Dataset.__getitem__` returns the
            content directly, use ``"_single"``.
        :param pad_val: default pad value for this field. ``None`` means the field should
            not be padded. fastNLP only pads fields that are paddable, so non-paddable
            fields need not be set to ``None`` explicitly. Meaningless when ``backend`` is
            ``None``.
        :param dtype: dtype of the padded data for this field.
        :param backend: one of ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow',
            'auto']``, producing :class:`list`, :class:`numpy.ndarray`,
            :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var` or
            :class:`oneflow.Tensor` respectively. Meaningless when ``pad_val`` is ``None``.
        :param pad_fn: custom pad function for this field; when given, ``pad_val``,
            ``dtype`` and ``backend`` are ignored. It receives the batched form of the
            field; the collator un-batches the data and re-batches each field automatically.
        :return: the :class:`Collator` in use
        :raises ValueError: if the current ``collate_fn`` is not a fastNLP ``Collator``
        """
        collator = self._get_collator()
        if isinstance(collator, Collator):
            collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
            return collator
        else:
            raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")

    def _get_collator(self):
        """
        Return the ``Collator`` behind ``collate_fn`` (possibly unwrapping the indice
        wrapper), or ``None`` if ``collate_fn`` is not a ``Collator``.

        :return: the :class:`Collator` or ``None``
        """
        collator = None
        if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
            collator = self.collate_fn.__wrapped__
        elif isinstance(self.collate_fn, Collator):
            collator = self.collate_fn
        return collator

    def set_ignore(self, *field_names) -> Collator:
        """
        Exclude fields from the batches produced by this dataloader::

            dataloader.set_ignore('field1', 'field2')

        :param field_names: names of the fields to ignore. If :meth:`Dataset.__getitem__`
            returns a dict, use the corresponding key; for nested dicts a tuple of keys can
            be used, e.g. ``('a', 'b')`` for ``{'a': {'b': 1}}``. If it returns a Sequence,
            ``'_0'``, ``'_1'``, ... refer to the element at that position.
        :return: the :class:`Collator` in use
        :raises ValueError: if the current ``collate_fn`` is not a fastNLP ``Collator``
        """
        collator = self._get_collator()
        if isinstance(collator, Collator):
            collator.set_ignore(*field_names)
            return collator
        else:
            raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")

    def get_batch_indices(self) -> List[int]:
        """
        Return the dataset indices of the samples in the current ``batch``.

        :return: indices of the current ``batch``; ``None`` before the first batch is drawn
        """
        return self.cur_batch_indices
+
+
def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                              return_list: bool = True,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                              batch_size: int = 16, shuffle: bool = False,
                              drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto',
                              num_workers: int = 0, use_buffer_reader: bool = True,
                              use_shared_memory: bool = True, timeout: int = 0,
                              worker_init_fn: Callable = None, persistent_workers=False,
                              non_train_batch_size: int = None) \
        -> Union[Dict[str, PaddleDataLoader], PaddleDataLoader]:
    """
    Convert one dataset, or several datasets at once, into ``PaddleDataLoader`` objects
    (see :class:`~fastNLP.PaddleDataLoader`). The return type depends on the type of
    ``ds_or_db``:

    * a single dataset (any object implementing ``__getitem__`` and ``__len__``) — returns
      one ``PaddleDataLoader`` built from ``batch_size``; ``non_train_batch_size`` is
      ignored;
    * a :class:`~fastNLP.io.DataBundle` or a ``Dict[str, DataSet]`` — returns a dict
      ``{name: PaddleDataLoader}``. Splits whose name contains ``'train'`` use
      ``batch_size`` and default to ``shuffle=True``; all other splits use
      ``non_train_batch_size`` (falling back to ``batch_size``) and default to
      ``shuffle=False``.

    :param ds_or_db: a :class:`~fastNLP.io.DataBundle`, a ``Dict[str, DataSet]``, or a
        single object implementing ``__getitem__`` and ``__len__``.
    :param feed_list: feed Tensor list, creatable via ``paddle.static.data``. Must be set
        when ``return_list`` is ``False``. Default ``None``.
    :param places: list of places to put the data on. ``None`` (default) means CPUPlace or
        CUDAPlace(0). Strings such as ``'cpu'``, ``'gpu:x'`` or ``'gpu_pinned'`` (``x`` a
        gpu index) are also accepted.
    :param return_list: whether each device returns its data as a list. If ``False``, the
        return value is a ``str -> Tensor`` dict keyed by the fed tensors' names; if
        ``True``, a ``list(Tensor)``. ``True`` is only valid in dynamic-graph mode and is
        the default.
    :param batch_sampler: object implementing ``__len__`` and ``__iter__``, yielding a list
        of dataset indices per batch; when not ``None``, ``batch_size`` and ``shuffle`` are
        ignored. Must be ``None`` when multiple datasets are given.
    :param batch_size: batch size for training splits, default ``16``; effective only when
        ``batch_sampler`` is ``None``.
    :param shuffle: whether to shuffle. When ``None``, training splits are shuffled and the
        others are not.
    :param drop_last: drop the last batch if it has fewer than ``batch_size`` samples.
        Default ``False``.
    :param collate_fn: one of ``[None, 'auto', Callable]``:

        * ``None`` — use paddle's ``default_collate_fn``; incompatible with a fastNLP
          :class:`~fastNLP.core.dataset.DataSet`;
        * ``'auto'`` — use a fastNLP :class:`~fastNLP.core.collators.Collator`
          (configurable through ``set_pad`` / ``set_ignore``);
        * a :class:`Callable` — receives the list of samples of one batch and returns the
          collated batch object.

    :param num_workers: number of worker subprocesses used to fetch data; ``0`` (default)
        keeps everything in the main process.
    :param use_buffer_reader: if ``True`` (default), the next batch is prefetched
        asynchronously — faster, but uses more (device) memory.
    :param use_shared_memory: whether to use shared memory to speed up inter-process data
        transfer; recommended only when shared memory is plentiful (e.g. ``/dev/shm`` on
        Linux). Only effective when ``num_workers > 0``.
    :param timeout: timeout for collecting a batch from the worker output queue.
    :param worker_init_fn: if not ``None``, called in every worker subprocess at startup.
    :param persistent_workers: keep worker processes alive after a full pass over the
        dataset. Default ``False``.
    :param non_train_batch_size: batch size for non-training splits; falls back to
        ``batch_size`` when ``None``.
    :return: a single ``PaddleDataLoader`` or a ``Dict[str, PaddleDataLoader]``.
    """
    from fastNLP.io.data_bundle import DataBundle

    def _make_dl(ds, is_train: bool) -> PaddleDataLoader:
        # BUG FIX: the 'train' branch of the Dict case used to pass
        # shuffle=False if shuffle is None else shuffle, silently disabling shuffling
        # on the training split; training splits now default to shuffle=True,
        # matching the DataBundle branch and the documented behavior.
        if is_train:
            cur_batch_size = batch_size
            cur_shuffle = True if shuffle is None else shuffle
        else:
            cur_batch_size = non_train_batch_size if non_train_batch_size else batch_size
            cur_shuffle = False if shuffle is None else shuffle
        return PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
                                batch_sampler=batch_sampler, batch_size=cur_batch_size,
                                shuffle=cur_shuffle, drop_last=drop_last, collate_fn=collate_fn,
                                num_workers=num_workers, use_shared_memory=use_shared_memory,
                                use_buffer_reader=use_buffer_reader, timeout=timeout,
                                worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)

    if isinstance(ds_or_db, DataBundle):
        assert batch_sampler is None, "batch_sampler can only be None when multiple datasets are presented."
        return {name: _make_dl(ds, 'train' in name) for name, ds in ds_or_db.iter_datasets()}

    elif isinstance(ds_or_db, Dict):
        assert batch_sampler is None, "batch_sampler can only be None when multiple datasets are presented."
        return {name: _make_dl(ds, 'train' in name) for name, ds in ds_or_db.items()}

    elif isinstance(ds_or_db, HasLenGetitemType):
        # Single dataset: use the train-side batch_size, but do not shuffle by default
        # since we cannot tell whether it is a training split.
        return PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
                                batch_sampler=batch_sampler, batch_size=batch_size,
                                shuffle=False if shuffle is None else shuffle,
                                drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                                timeout=timeout, worker_init_fn=worker_init_fn,
                                persistent_workers=persistent_workers)
    else:
        raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")
diff --git a/fastNLP/core/dataloaders/prepare_dataloader.py b/fastNLP/core/dataloaders/prepare_dataloader.py
new file mode 100644
index 00000000..65b739aa
--- /dev/null
+++ b/fastNLP/core/dataloaders/prepare_dataloader.py
@@ -0,0 +1,111 @@
+__all__ = [
+ 'prepare_dataloader'
+]
+
+from typing import Union, Callable
+import os
+import sys
+
+from .torch_dataloader import prepare_torch_dataloader
+from .paddle_dataloader import prepare_paddle_dataloader
+from .jittor_dataloader import prepare_jittor_dataloader
+from .oneflow_dataloader import prepare_oneflow_dataloader
+from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
+from ..log import logger
+
+
def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = None, drop_last: bool = False,
                       collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
                       backend: str = 'auto'):
    """
    Create the appropriate ``DataLoader`` automatically. For example, if the current
    environment is ``torch``, a ``TorchDataLoader`` is returned; for ``paddle``, a
    ``PaddleDataLoader``. If more parameters need to be customised, call the corresponding
    ``prepare`` function directly, e.g.
    :func:`~fastNLP.core.dataloaders.prepare_torch_dataloader` or
    :func:`~fastNLP.core.dataloaders.prepare_paddle_dataloader`.

    :param dataset: an object implementing ``__getitem__`` and ``__len__``; or a sequence
        of such objects; or a dict.

        * a single dataset object — returns one DataLoader.
        * a sequence of dataset objects — returns a sequence of DataLoaders.
        * a dict or :class:`~fastNLP.io.DataBundle` — returns a :class:`Dict` of
          DataLoaders.

    :param batch_size: batch size.
    :param shuffle: whether to shuffle. When ``None``, splits recognised as ``'train'``
        are shuffled and the others are not.
    :param drop_last: whether to drop the last batch when it has fewer than ``batch_size``
        samples.
    :param collate_fn: function collating one batch (typically padding and converting to
        tensors). One of:

        * ``'auto'`` — use :class:`~fastNLP.Collator` for padding and tensor conversion.
        * a :class:`Callable` — receives the data of one ``batch`` and returns the
          collated object.
        * ``None`` — use the default ``collate_fn`` of the framework's DataLoader.
    :param num_workers: number of worker processes used to fetch data.
    :param backend: one of ``["auto", "torch", "paddle", "jittor", "oneflow"]``.

        * ``'auto'`` — first consult the ``"FASTNLP_BACKEND"`` environment variable; if it
          is not set, infer the backend from the modules already imported in
          ``sys.modules``. If neither works, an error is raised. Once found, the backend
          is handled as below.
        * ``'torch'`` — use :func:`~fastNLP.core.dataloaders.prepare_torch_dataloader`.
        * ``'paddle'`` — use :func:`~fastNLP.core.dataloaders.prepare_paddle_dataloader`.
        * ``'jittor'`` — use :func:`~fastNLP.core.dataloaders.prepare_jittor_dataloader`.
        * ``'oneflow'`` — use :func:`~fastNLP.core.dataloaders.prepare_oneflow_dataloader`.

    :return: the DataLoader(s) built by the selected backend's ``prepare_*`` function.
    :raises ValueError: if ``backend`` is not one of the supported values.
    """
    if backend == 'auto':
        backend = _get_backend()
    if backend == 'torch':
        return prepare_torch_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn,
                                        num_workers=num_workers, shuffle=shuffle, sampler=None,
                                        batch_size=batch_size)
    elif backend == 'paddle':
        return prepare_paddle_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn,
                                         num_workers=num_workers, batch_size=batch_size, shuffle=shuffle)
    elif backend == 'jittor':
        # BUG FIX: the jittor branch was missing its `return`, so callers always got None.
        return prepare_jittor_dataloader(ds_or_db=dataset, sampler=None, collate_fn=collate_fn,
                                         num_workers=num_workers, batch_size=batch_size, shuffle=shuffle,
                                         drop_last=drop_last)
    elif backend == 'oneflow':
        return prepare_oneflow_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn,
                                          num_workers=num_workers, shuffle=shuffle, sampler=None,
                                          batch_size=batch_size)
    else:
        raise ValueError(f"Currently we do not support backend:{backend}.")
+
+
def _check_module(module):
    """
    Inspect a module's ``__file__`` to see whether it belongs to one of the supported
    backends (i.e. lives under that backend's ``site-packages``/``dist-packages`` folder).

    :param module: a module object (anything from ``sys.modules``)
    :return: the backend name, or ``None`` if the module matches no supported backend
    """
    # Builtins/namespace packages may lack __file__ or have it set to None;
    # treat both as "not a backend module" instead of swallowing all errors.
    file = getattr(module, '__file__', None)
    if not file:
        return None
    for backend in SUPPORT_BACKENDS:
        if f'{os.sep}site-packages{os.sep}{backend}' in file:
            return backend
        if f'{os.sep}dist-packages{os.sep}{backend}' in file:
            return backend
    return None
+
+
def _get_backend() -> str:
    """
    Determine the dataloader backend: prefer the ``FASTNLP_BACKEND`` environment variable;
    otherwise infer it from the backend packages already imported in ``sys.modules``.

    :return: the detected backend name
    :raises RuntimeError: if zero or more than one backend can be inferred from the
        imported modules
    """
    backend = os.environ.get(FASTNLP_BACKEND)  # single lookup; `is not None` idiom below
    if backend is not None:
        logger.debug(f"Get Dataloader backend:{backend} from os.environ")
        return backend

    available_backends = set()
    # Snapshot sys.modules: it may grow while we iterate (e.g. imports on another thread),
    # which would raise "dictionary changed size during iteration".
    for module in list(sys.modules.values()):
        _backend = _check_module(module)
        if _backend:
            available_backends.add(_backend)
    if len(available_backends) == 1:
        backend = available_backends.pop()
        logger.debug(f"Get Dataloader backend:{backend} from sys.modules.")
    elif len(available_backends) > 1:
        raise RuntimeError("Fail to detect dataloader backend automatically, because multiple backends:"
                           f"{available_backends} has been imported.")
    else:
        raise RuntimeError("Fail to detect dataloader backend automatically, please set it manually.")
    return backend
\ No newline at end of file
diff --git a/fastNLP/core/dataloaders/torch_dataloader/__init__.py b/fastNLP/core/dataloaders/torch_dataloader/__init__.py
new file mode 100644
index 00000000..a55d3d0d
--- /dev/null
+++ b/fastNLP/core/dataloaders/torch_dataloader/__init__.py
@@ -0,0 +1,8 @@
+__all__ = [
+ "TorchDataLoader",
+ "prepare_torch_dataloader",
+ "MixDataLoader"
+]
+
+from .fdl import TorchDataLoader, prepare_torch_dataloader
+from .mix_dataloader import MixDataLoader
diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
new file mode 100644
index 00000000..8aa48382
--- /dev/null
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -0,0 +1,367 @@
+__all__ = [
+ 'TorchDataLoader',
+ 'prepare_torch_dataloader'
+]
+
+from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List, Any
+from abc import ABC
+from copy import deepcopy
+
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.collators import Collator
+from fastNLP.core.dataloaders.utils import indice_collate_wrapper
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
+from ..utils import _match_param
+from ..utils import HasLenGetitemType
+
+if _NEED_IMPORT_TORCH:
+ from torch.utils.data import DataLoader, Sampler, Dataset
+else:
+ from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
+
+
+class _FDataSet:
+ """
+    提供给 ``TorchDataLoader`` 使用的 wrap 类,其功能是对 dataset 进行封装,wrap 修改 dataset 的 __getitem__ 函数,增加返回
+ 数据的下标 idx 。
+
+    .. note::
+
+ 需要注意的是传入 ``__init__`` 的 dataset 需要实现 __getattribute__ 方法才能在 _FDataset 实例化对象中调用 dataset 的方法
+
+ """
+
+ def __init__(self, dataset) -> None:
+ self.dataset = dataset
+
+ def __getitem__(self, item: Union[int, list]) -> Tuple:
+ return (item, self.dataset[item])
+
+ def __getattr__(self, item):
+ try:
+ return self.dataset.__getattribute__(item)
+ except AttributeError as e:
+ raise e
+
+ def __len__(self) -> int:
+ return len(self.dataset)
+
+ # 这里需要显示地带上这两个方法,因为可能会涉及到 pickle 的 dumps 和 loads;否则会导致 pickle 在 loads 时调用 __setstate__ 方法
+ # 进入到 __getattr__ 内部,引发死循环;
+ # https://docs.python.org/3/library/pickle.html#pickling-class-instances
+ # https://stackoverflow.com/questions/73662315/when-using-multiprocessing-and-spawn-in-python-use-self-a-in-getattr-cause?noredirect=1
+ def __getstate__(self):
+ return self.__dict__
+
+ def __setstate__(self, state):
+ self.__dict__ = state
+
+
+class TorchDataLoader(DataLoader):
+ """
+ 提供给 ``torch`` 框架使用的 ``DataLoader`` 函数,``TorchDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad,
+ 若是可 pad 的 field 则自动 pad 到相同长度,否则只会将相同 field 的数据收集组成一个 batch 返回。
+    具体详见 :class:`~fastNLP.core.collators.Collator`;用户通过 collate_fn 来控制是否使用该功能, collate_fn 只能为 ``['auto', None, Callable]``
+ 三种取值。
+
+    * collate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。
+ 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
+    * collate_fn 为 ``None`` 时, ``TorchDataLoader`` 默认使用 torch DataLoader 自带的 collate_fn
+ * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
+ dataset 的一条数据;该 Callable 函数还应当返回一个对象。
+
+ :param dataset: 实现了 __getitem__() 和 __len__() 的对象。
+ :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
+ :param non_train_batch_size: 非训练数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。
+ :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` ,
+ 其它的为 False 。
+ :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
+ 默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。
+    :param non_train_sampler: 非训练数据集的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
+ 默认为None, 当其不为 None 时, shuffle 参数无效。
+    :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回一个 List 对象, List中的值为
+        dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``batch_size``, ``sampler``, ``shuffle`` 参数均失效。
+ :param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快
+ 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
+ :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
+
+        * collate_fn 为 ``None`` 时,需要注意的是此时传进来的 dataset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
+          ``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 collate_fn 的默认值, 其无法处理
+          :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。
+        * collate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
+ 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
+ * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
+ dataset 的一条数据;该 Callable 函数还应当返回一个对象。
+
+    :param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cuda 的 pin memory 中。
+ :param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
+ 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
+ :param timeout: 子进程的输出队列获取数据的超时值
+ :param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。
+ :param multiprocessing_context: 多进程的上下文环境
+ :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed``
+ :param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` .
+ :param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False``
+ """
+
+ def __init__(self, dataset, batch_size: int = 16,
+ shuffle: bool = False, sampler=None, batch_sampler=None,
+ num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
+ pin_memory: bool = False, drop_last: bool = False,
+ timeout: float = 0, worker_init_fn: Optional[Callable] = None,
+ multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
+ persistent_workers: bool = False, **kwargs) -> None:
+
+ if isinstance(dataset, DataSet) and collate_fn is None:
+ raise ValueError("When use FastNLP DataSet, collate_fn must be not None")
+
+ if not isinstance(dataset, _FDataSet):
+ dataset = _FDataSet(dataset)
+
+ if num_workers>0 and multiprocessing_context is None:
+ multiprocessing_context = 'fork' # 这里默认使用fork的方式来启动多进程
+
+ if batch_sampler is not None:
+ batch_size = 1
+ shuffle = False
+ sampler = None
+ elif sampler is None:
+ sampler = RandomSampler(dataset, shuffle=shuffle)
+ shuffle = False
+
+ if isinstance(collate_fn, str):
+ if collate_fn == 'auto':
+ if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset
+ collate_fn = deepcopy(dataset.dataset.collator)
+ collate_fn.set_backend(backend="torch")
+ else:
+ collate_fn = Collator(backend="torch")
+ else:
+ raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
+
+ dl_kwargs = _match_param(TorchDataLoader.__init__, DataLoader.__init__, fn_name=DataLoader.__name__)
+ if dl_kwargs is None:
+ super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
+ batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
+ pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+ multiprocessing_context=multiprocessing_context, generator=generator,
+ prefetch_factor=prefetch_factor,
+ persistent_workers=persistent_workers)
+ else:
+ super().__init__(**dl_kwargs)
+
+ self.cur_batch_indices = None
+
+ def __iter__(self):
+ self.collate_fn = indice_collate_wrapper(self.collate_fn)
+ for indices, data in super().__iter__():
+ self.cur_batch_indices = indices
+ yield data
+
+ def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+ pad_fn: Callable = None) -> Collator:
+ """
+ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
+
+ :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的
+ field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``;
+ 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。
+ 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。
+ :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
+ field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``,
+ 该值无意义。
+ :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。
+ :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`,
+ :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。
+ 若 ``pad_val`` 为 ``None`` ,该值无意义 。
+ :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的
+ batch 形式。 collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。
+ :return: 使用的 collator
+ """
+ collator = self._get_collator()
+ if isinstance(collator, Collator):
+ collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+ return collator
+ else:
+ raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
+
+ def _get_collator(self):
+ """
+ 如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None
+
+ :return:
+ """
+ collator = None
+ if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
+ collator = self.collate_fn.__wrapped__
+ elif isinstance(self.collate_fn, Collator):
+ collator = self.collate_fn
+ return collator
+
+ def set_ignore(self, *field_names) -> Collator:
+ """
+ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略::
+
+ dataloader.set_ignore('field1', 'field2')
+
+ :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的
+ field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``;
+ 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。
+ :return: 使用的 collator
+ """
+ collator = self._get_collator()
+ if isinstance(collator, Collator):
+ collator.set_ignore(*field_names)
+ return collator
+ else:
+ raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
+
+ def get_batch_indices(self) -> List[int]:
+ """
+ 获取当前 ``batch`` 中每条数据对应的索引。
+
+ :return: 当前 ``batch`` 数据的索引;
+ """
+ return self.cur_batch_indices
+
+
+def prepare_torch_dataloader(ds_or_db,
+ batch_size: int = 16,
+ shuffle: bool = None,
+ sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
+ batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
+ num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
+ pin_memory: bool = False, drop_last: bool = False,
+ timeout: float = 0, worker_init_fn: Optional[Callable] = None,
+ multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
+ persistent_workers: bool = False,
+ non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
+ non_train_batch_size: int = None) \
+ -> Union[TorchDataLoader, Dict[str, TorchDataLoader]]:
+ """
+ ``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader`` 对象, 详见 :class:`~fastNLP.TorchDataLoader`。
+ 根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下:
+
+ * 当 ds_or_db 为 ``DataSet`` 时,``prepare_torch_dataloader`` 会将使用的除了 ``non_train_batch_size`` 和 ``non_train_sampler`` 以外的参数来
+ 帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.TorchDataLoader`。
+ * 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_torch_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value
+ 来创建不同的 ``TorchDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_torch_dataloader`` 默认该 value 为训练数据集,
+      会将 ``batch_size`` 和 ``sampler`` 作为参数,其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_batch_size`` 和 ``non_train_sampler`` 作为参数。
+ 最终根据 ``key: TorchDataLoader`` 组成 ``Dict[key, TorchDataLoader]`` 的字典返回。
+    * 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_torch_dataloader`` 会遍历 该 dict 的 key-value 来创建不同的
+      ``TorchDataLoader`` 对象;当 key 中包含 ``'train'`` 字符串时,``prepare_torch_dataloader`` 默认该 value 为训练数据集,会将 ``batch_size`` 和 ``sampler`` 作为参数,
+      其他 key 不包含 ``'train'`` 字符串的数据集则使用 ``non_train_batch_size`` 和 ``non_train_sampler`` 作为参数。最终根据 ``key: TorchDataLoader`` 组成
+ ``Dict[key, TorchDataLoader]`` 的字典返回。
+
+ :param ds_or_db: 可以有以下三种取值,
+
+ * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典;
+ * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典;
+ * ds_or_db 为实现了 __getitem__() 和 __len__() 的对象 ,返回值为 :class:`~fastNLP.TorchDataLoader`;
+
+ :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
+ :param non_train_batch_size: 非训练数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 ``batch_sampler`` 为 ``None`` 有效。
+ :param shuffle: 是否打乱数据集, 默认为 ``None``, 如果传入的 ``ds_or_db`` 可以判断出哪个是 ``'train'`` 则设置其 shuffle 为 ``True`` ,
+ 其它的为 False 。
+ :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
+ 默认为 ``None``, 当其不为 ``None`` 时, shuffle 参数无效。
+    :param non_train_sampler: 非训练数据集的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
+ 默认为None, 当其不为 None 时, shuffle 参数无效。
+    :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回一个 List 对象, List中的值为
+        dataset 的下标 index ;默认为 ``None``,当其不为 ``None`` 时,``batch_size``, ``sampler``, ``shuffle`` 参数均失效。
+ :param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快
+ 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
+ :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
+
+        * collate_fn 为 ``None`` 时,需要注意的是此时传进来的 dataset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
+          ``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 collate_fn 的默认值, 其无法处理
+          :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。
+        * collate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
+ 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
+ * collate_fn 为 :class:`Callable` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
+ dataset 的一条数据;该 Callable 函数还应当返回一个对象。
+
+    :param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cuda 的 pin memory 中。
+ :param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
+ 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
+ :param timeout: 子进程的输出队列获取数据的超时值
+ :param worker_init_fn: init 函数,如果不设置为 ``None``,则将会在每个子进程初始化时调用该函数。
+ :param multiprocessing_context: 多进程的上下文环境
+ :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed``
+ :param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2`` 意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` .
+ :param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False``
+
+ """
+
+ from fastNLP.io import DataBundle
+
+ if isinstance(ds_or_db, DataBundle):
+ assert sampler is None, "sampler can only be None when multiple datasets are presented."
+ assert batch_sampler is None, "batch_sampler can only be None when multiple datasets are presented."
+ dl_bundle = {}
+ for name, ds in ds_or_db.iter_datasets():
+ if 'train' in name:
+ dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
+ shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler,
+ num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
+ drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+ multiprocessing_context=multiprocessing_context, generator=generator,
+ prefetch_factor=prefetch_factor,
+ persistent_workers=persistent_workers,
+ )
+ else:
+ dl_bundle[name] = TorchDataLoader(dataset=ds,
+ batch_size=non_train_batch_size if non_train_batch_size else batch_size,
+ shuffle=False if shuffle is None else shuffle,
+ sampler=non_train_sampler if non_train_sampler else sampler,
+ batch_sampler=batch_sampler,
+ num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
+ drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+ multiprocessing_context=multiprocessing_context, generator=generator,
+ prefetch_factor=prefetch_factor,
+ persistent_workers=persistent_workers,
+ )
+ return dl_bundle
+
+ elif isinstance(ds_or_db, Mapping):
+ assert sampler is None, "sampler can only be None when multiple datasets are presented."
+ assert batch_sampler is None, "batch_sampler can only be None when multiple datasets are presented."
+ dl_bundle = {}
+ for name, ds in ds_or_db.items():
+ if 'train' in name:
+ dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
+ shuffle=True if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler,
+ num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
+ drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+ multiprocessing_context=multiprocessing_context, generator=generator,
+ prefetch_factor=prefetch_factor,
+ persistent_workers=persistent_workers,
+ )
+ else:
+ dl_bundle[name] = TorchDataLoader(dataset=ds,
+ batch_size=non_train_batch_size if non_train_batch_size else batch_size,
+ shuffle=False if shuffle is None else shuffle,
+ sampler=non_train_sampler if non_train_sampler else sampler,
+ batch_sampler=batch_sampler,
+ num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
+ drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+ multiprocessing_context=multiprocessing_context, generator=generator,
+ prefetch_factor=prefetch_factor,
+ persistent_workers=persistent_workers,
+ )
+
+ return dl_bundle
+
+ elif isinstance(ds_or_db, HasLenGetitemType):
+ dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
+ shuffle=False if shuffle is None else shuffle, sampler=sampler, batch_sampler=batch_sampler,
+ num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
+ drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+ multiprocessing_context=multiprocessing_context, generator=generator,
+ prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
+ )
+ return dl
+
+ else:
+ raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")
diff --git a/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py
new file mode 100644
index 00000000..7dd4bba5
--- /dev/null
+++ b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py
@@ -0,0 +1,222 @@
+__all__ = [
+ 'MixDataLoader'
+]
+
+from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping
+
+import numpy as np
+from pkg_resources import parse_version
+
+from fastNLP.core.dataset import DataSet, Instance
+from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.core.collators import Collator
+
+if _NEED_IMPORT_TORCH:
+ from torch import __version__ as torchversion
+ from torch.utils.data import DataLoader, Sampler
+else:
+ from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
+
+
+class _MixDataset:
+ """
+ 将所有数据集当成一个混合大数据集来对待, 在 __getitem__() 能根据输入的 idx 来判断属于哪个小数据并返回其 ds_index
+
+ """
+
+ def __init__(self, datasets: list = None) -> None:
+ """
+ :param datasets: 实现了 __getitem__() 和 __len__() 的对象的序列
+ """
+ self.datasets = datasets
+ # 记录每个数据集的长度索引, 以便根据idx定位数据集的位置
+ self.lens = []
+ index = 0
+ for item in self.datasets:
+ index += len(item)
+ self.lens.append(index)
+
+ def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]:
+ """
+ 根据index索引获取数据, 能够跟 idx 的范围定位属于哪个小数据并返回
+
+ :param idx: 整数类型的index或者列表
+ :return:
+ """
+ if isinstance(idx, int):
+ if idx >= self.lens[-1]:
+ raise ValueError(f"idx: {idx} out of range")
+ # 找到其属于哪个数据集,返回下标
+ ds_index = np.searchsorted(self.lens, idx, side='right')
+ if ds_index > 0:
+ idx -= self.lens[ds_index - 1]
+ return self.datasets[ds_index][idx], ds_index
+ elif isinstance(idx, list):
+ # 一般一个list列表只能是属于一种数据的,否则会报错
+ dataset = DataSet()
+ ds_index = 0
+ for i in idx:
+ assert isinstance(i, int), "Only int index allowed."
+ instance, ds_index = self[i]
+ dataset.append(instance)
+ return dataset, ds_index
+ else:
+ raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
+
+ def __len__(self) -> int:
+ return self.lens[-1]
+
+
+class _MixCollateFn:
+ """
+ 存在多个auto_collate和多个collate_fn时候,对一个批次数据集应用哪个auto_collate和collate_fn的问题
+
+ """
+
+ def __init__(self, collate_fns: Union[List[Callable], Callable]) -> None:
+
+ if isinstance(collate_fns, Sequence):
+ self.collate_fns = lambda idx, lst: collate_fns[idx](lst)
+ elif callable(collate_fns):
+ self.collate_fns = lambda idx, lst: collate_fns(lst)
+ else:
+ self.collate_fns = lambda idx, lst: lst
+
+ def __call__(self, ins_list: List) -> Dict:
+ """
+ 调用一次该方法,我们将ins_list视为同一个数据集采样出来的,故ds_index只能为一种
+
+ :param ins_list:
+ :return:
+ """
+ _ins_list, _ds_index = [], 0
+ for ins, _ds_index in ins_list:
+ _ins_list.append(ins)
+ _ins_list = self.collate_fns(_ds_index, _ins_list)
+ return _ins_list
+
+
+class MixDataLoader(DataLoader):
+ """
+ 针对以下四种情况提供的 ``MixDataLoader``, 目前只支持 **pytorch** 框架的版本, 其中 mode 的取值范围为 ``['sequential', 'mix', 'polling', 'Sampler']``:
+
+ * 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 ``datasets`` 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个
+ 接一个的 sample 完所有数据。
+ * 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 ``datasets`` 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机 sample
+ 混合数据集 datasets 的数据组成一个 batch 序列返回。
+ * 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 ``datasets`` 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回,
+ 再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询地的采样完。
+ * 当 mode 为 ``'Sampler'`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int];
+ 且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0 0`` 时, ``MixDataLoader`` 会开启 ``num_workers`` 个子进程来处理数据, 可以加快数据处理速度,但同时
+ 也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
+ :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 ``None`` 有效。 且 datasets 上所有 dataset 的 batch_size 一致。
+ :param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据;
+ 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
+ :param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``:
+
+ * ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。
+        * ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最短长度 ``mix_len``, 其他数据集会被切断
+ 到最短长度 ``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度 ``mix_len``。
+        * ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他数据集会扩充
+          到最大长度 ``max_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度 ``max_len``。
+ * ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数,
+ 代表着 datasets 的 value 数据进行扩充或者缩减的倍数。
+
+ """
+
+ def __init__(self, datasets: Dict = None, mode: str = 'sequential',
+ collate_fn: Union[str, Callable, Dict[str, Callable]] = 'auto',
+ sampler: Union[str, None] = None,
+ num_workers: int = 0, batch_size: int = 16, drop_last=False,
+ ds_ratio: Union[None, str, Dict[str, float]] = None,
+ pin_memory: bool = False) -> None:
+ # sampler 为 dict,则判断是否与 datasets 的 key 相同
+ if isinstance(sampler, Dict):
+ for key in datasets.keys():
+ if not sampler[key]:
+ raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!")
+ # collate_fn 为 dict,则判断是否与 datasets 的 key 相同
+ if isinstance(collate_fn, Dict):
+ if mode == 'mix':
+ raise ValueError(f"mode: {mode} do not support collate_fn is Dict, please use callate_fn=Callable or 'auto'")
+ for key in datasets.keys():
+ if not collate_fn[key]:
+ raise ValueError(f"the key:{key} of datasets is not in collate_fn, where collate_fn is a dict!")
+
+ if isinstance(collate_fn, str) and collate_fn == 'auto':
+ date_type = None
+ for idx, ds in enumerate(datasets.values()):
+ if idx == 0:
+ date_type = type(ds[0])
+ if type(ds[0]) != date_type or not (isinstance(ds[0], List) or isinstance(ds[0], Mapping)):
+ raise ValueError(f"when you use callate_fn={collate_fn}, all dataset must be list or dict。"
+ f"But dataset {idx - 1} data type is {date_type}, dataset {idx} data type is {type(ds[0])}")
+
+ collate_fn = Collator(backend='torch')
+
+ # Dict 类型的 collate_fn 转化为 List,以便于 _MixCollateFn 里面根据 idx 定位 dataset
+ if isinstance(collate_fn, Dict):
+ collate_fn = [fn for _, fn in collate_fn.items()]
+
+ dataset = [ds for _, ds in datasets.items()]
+
+ # 对 collate_fn 进行包裹, 统一处理 collate_fn 不同情况下使用的问题
+ collate_fn = _MixCollateFn(collate_fn)
+
+ if mode == 'sequential':
+ batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler,
+ drop_last=drop_last, ds_ratio=ds_ratio)
+ elif mode == 'polling':
+ batch_sampler = PollingSampler(datasets, batch_size=batch_size, sampler=sampler,
+ drop_last=drop_last, ds_ratio=ds_ratio)
+ elif mode == 'mix':
+ batch_sampler = DopedSampler(datasets, batch_size=batch_size, sampler=sampler,
+ drop_last=drop_last, ds_ratio=ds_ratio)
+ elif isinstance(mode, Sampler):
+ batch_sampler = mode
+ else:
+ raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler")
+
+ if parse_version(torchversion) >= parse_version('1.7'):
+ super(MixDataLoader, self).__init__(
+ _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
+ batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
+ pin_memory=pin_memory, drop_last=False, timeout=0,
+ worker_init_fn=None, multiprocessing_context=None, generator=None,
+ prefetch_factor=2, persistent_workers=False
+ )
+ else:
+ super(MixDataLoader, self).__init__(
+ _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
+ batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
+ pin_memory=pin_memory, drop_last=False, timeout=0,
+ worker_init_fn=None, multiprocessing_context=None, generator=None,
+ )
+
+ def __iter__(self):
+ return super().__iter__()
diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py
new file mode 100644
index 00000000..4a648b99
--- /dev/null
+++ b/fastNLP/core/dataloaders/utils.py
@@ -0,0 +1,157 @@
+import os
+from typing import Callable, Any, Union, Sequence
+from abc import ABC
+import inspect
+import ast
+
+from ..log import logger
+from ..utils.cache_results import get_func_calls, truncate_start_blanks
+__all__ = [
+ "indice_collate_wrapper",
+ "OverfitDataLoader"
+]
+
+
+def indice_collate_wrapper(func:Callable):
+ """
+ 其功能是封装一层 collate_fn,将 dataset 取到的 tuple 数据分离开,将 idx 打包为 indices。
+
+ :param func: 需要修饰的函数
+ :return:
+ """
+ if hasattr(func, '__name__') and func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了
+ return func
+
+ def _indice_collate_wrapper(tuple_data): # 这里不能使用 functools.wraps ,否则会检测不到
+ indice, ins_list = [], []
+ for idx, ins in tuple_data:
+ indice.append(idx)
+ ins_list.append(ins)
+ return indice, func(ins_list)
+ _indice_collate_wrapper.__wrapped__ = func # 记录对应的
+
+ return _indice_collate_wrapper
+
+
+def _match_param(fun, call_fn:Callable, fn_name:str=None):
+ """
+ 在调用 _match_param 的函数(就是 fun )中会调用 call_fn 这个函数。由于 fun 中支持的函数比 call_fn 更多,例如低版本的
+ :class:`~.fastNLP.TorchDataLoader` 中支持的参数,在torch 1.6 版本的 DataLoader 就不支持,但在高版本的 torch 中是支持的
+ 因此,这里需要根据当前版本的 DataLoader 判定出适合传入 DataLoader 进行初始化的参数,并且在不支持但又被设置的参数上进行
+ warning 。
+
+ :param fun: 调用函数本身
+ :param call_fn:
+ :param fn_name: 方便报错的用的函数
+ :return:
+ """
+ try:
+ if fn_name is None:
+ try:
+ fn_name = call_fn.__name__
+ except:
+ fn_name = str(call_fn)
+
+ last_frame = inspect.currentframe().f_back
+
+ # 调用 _match_param 的函数名称,获取默认的参数值
+ fun_default_params = {}
+ fun_parameters = inspect.signature(fun)
+ for name, fun_param in fun_parameters.parameters.items():
+ if fun_param.default is not fun_param.empty:
+ fun_default_params[name] = fun_param.default
+
+ # 获取实际传入的参数值
+ param_names, args_name, kwargs_name, values = inspect.getargvalues(last_frame)
+ if args_name is not None:
+ raise RuntimeError("Function does not support positional arguments, such as: fun(*args).")
+ kwargs = values.get(kwargs_name, {})
+ for param in param_names:
+ if param not in values:
+ value = fun_default_params.get(param)
+ else:
+ value = values[param]
+ kwargs[param] = value
+
+ # 根据需要实际需要调用的 call_fn 的参数进行匹配
+ call_fn_parameters = inspect.signature(call_fn)
+ call_fn_kwargs = {}
+ has_kwargs = False
+ for name, param in call_fn_parameters.parameters.items():
+ if name == 'self':
+ continue
+ if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY): # 最前面的 args
+ call_fn_kwargs[name] = param.default
+ if param.kind == param.VAR_KEYWORD:
+ has_kwargs = True
+
+ # 组装得到最终的参数
+ call_kwargs = {}
+ for name, value in kwargs.items():
+ if name in call_fn_kwargs or has_kwargs: # 如果存在在里面,或者包含了 kwargs 就直接运行
+ call_kwargs[name] = value
+ # 如果不在需要调用的函数里面,同时又是非默认值
+ elif name not in call_fn_kwargs and name in fun_default_params and fun_default_params[name]!=value:
+ logger.rank_zero_warning(f"Parameter:{name} is not supported for {fn_name}.")
+
+ return call_kwargs
+ except BaseException as e:
+ logger.debug(f"Exception happens when match parameters for {fn_name}: {e}")
+ return None
+
+
+class HasLenGetitemType(ABC):
+ """
+ 判断是否实现了 __len__ 和 __getitem__ 方法的类
+
+ """
+ @classmethod
+ def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
+ if cls is HasLenGetitemType:
+ flag = callable(getattr(subclass, '__getitem__', None)) and callable(getattr(subclass, '__len__', None))
+ return flag
+ return NotImplemented
+
+
+class OverfitDataLoader:
+ """
+ 实现一个简单的迭代器来模拟实际的 dataloader,从给定的 ``dataloader`` 中取出部分数据,来让 Trainer 实现 overfit 的功能;
+ """
+
+ def __init__(self, dataloader, overfit_batches: int, batches=None):
+ # batches 参数是给重新初始化dataloader使用的
+ self.dataloader = dataloader # 需要将实际的 dataloader 挂载到该对象上,从而应付一些对于实际的 dataloader 的操作;
+ if batches is None:
+ self.batches = []
+ self.overfit_batches = int(overfit_batches)
+
+ if self.overfit_batches > len(dataloader):
+ logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")
+
+ for idx, batch in enumerate(dataloader):
+ if idx < self.overfit_batches or self.overfit_batches <= -1:
+ self.batches.append(batch)
+ else:
+ assert isinstance(batches, list)
+ self.batches = batches
+
+ def __len__(self):
+ return len(self.batches)
+
+ def __iter__(self):
+ for batch in self.batches:
+ yield batch
+
+ def __getattr__(self, item):
+ return getattr(self.dataloader, item)
+
+
+
+if __name__ == '__main__':
+ def demo(*args, **kwargs):
+ pass
+
+ d = indice_collate_wrapper(demo)
+
+ print(d.__name__)
+ print(d.__wrapped__)
\ No newline at end of file
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
deleted file mode 100644
index 45a488d9..00000000
--- a/fastNLP/core/dataset.py
+++ /dev/null
@@ -1,1213 +0,0 @@
-r"""
-:class:`~fastNLP.core.dataset.DataSet` 是fastNLP中用于承载数据的容器。可以将DataSet看做是一个表格,
-每一行是一个sample (在fastNLP中被称为 :mod:`~fastNLP.core.instance` ),
-每一列是一个feature (在fastNLP中称为 :mod:`~fastNLP.core.field` )。
-
-.. csv-table:: Following is a demo layout of DataSet
- :header: "sentence", "words", "seq_len"
-
- "This is the first instance .", "[This, is, the, first, instance, .]", 6
- "Second instance .", "[Second, instance, .]", 3
- "Third instance .", "[Third, instance, .]", 3
- "...", "[...]", "..."
-
-在fastNLP内部每一行是一个 :class:`~fastNLP.Instance` 对象; 每一列是一个 :class:`~fastNLP.FieldArray` 对象。
-
-----------------------------
-1.DataSet的创建
-----------------------------
-
-创建DataSet主要有以下的3种方式
-
-1.1 传入dict
-----------------------------
-
- .. code-block::
-
- from fastNLP import DataSet
- data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."],
- 'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.'],
- 'seq_len': [6, 3, 3]}
- dataset = DataSet(data)
- # 传入的dict的每个key的value应该为具有相同长度的list
-
-1.2 通过 Instance 构建
-----------------------------
-
- .. code-block::
-
- from fastNLP import DataSet
- from fastNLP import Instance
- dataset = DataSet()
- instance = Instance(sentence="This is the first instance",
- words=['this', 'is', 'the', 'first', 'instance', '.'],
- seq_len=6)
- dataset.append(instance)
- # 可以继续append更多内容,但是append的instance应该和第一个instance拥有完全相同的field
-
-1.3 通过 List[Instance] 构建
---------------------------------------
-
- .. code-block::
-
- from fastNLP import DataSet
- from fastNLP import Instance
- instances = []
- instances.append(Instance(sentence="This is the first instance",
- ords=['this', 'is', 'the', 'first', 'instance', '.'],
- seq_len=6))
- instances.append(Instance(sentence="Second instance .",
- words=['Second', 'instance', '.'],
- seq_len=3))
- dataset = DataSet(instances)
-
---------------------------------------
-2.DataSet与预处理
---------------------------------------
-
-常见的预处理有如下几种
-
-2.1 从某个文本文件读取内容
---------------------------------------
-
- .. code-block::
-
- from fastNLP import DataSet
- from fastNLP import Instance
- dataset = DataSet()
- filepath='some/text/file'
- # 假设文件中每行内容如下(sentence label):
- # This is a fantastic day positive
- # The bad weather negative
- # .....
- with open(filepath, 'r') as f:
- for line in f:
- sent, label = line.strip().split('\t')
- dataset.append(Instance(sentence=sent, label=label))
-
- .. note::
- 直接读取特定数据集的数据请参考 :doc:`/tutorials/tutorial_4_load_dataset`
-
-2.2 对DataSet中的内容处理
---------------------------------------
-
- .. code-block::
-
- from fastNLP import DataSet
- data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."]}
- dataset = DataSet(data)
- # 将句子分成单词形式, 详见DataSet.apply()方法
- dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words')
- # 或使用DataSet.apply_field()
- dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words')
- # 除了匿名函数,也可以定义函数传递进去
- def get_words(instance):
- sentence = instance['sentence']
- words = sentence.split()
- return words
- dataset.apply(get_words, new_field_name='words')
-
-2.3 删除DataSet的内容
---------------------------------------
-
- .. code-block::
-
- from fastNLP import DataSet
- dataset = DataSet({'a': list(range(-5, 5))})
- # 返回满足条件的instance,并放入DataSet中
- dropped_dataset = dataset.drop(lambda ins:ins['a']<0, inplace=False)
- # 在dataset中删除满足条件的instance
- dataset.drop(lambda ins:ins['a']<0) # dataset的instance数量减少
- # 删除第3个instance
- dataset.delete_instance(2)
- # 删除名为'a'的field
- dataset.delete_field('a')
-
-
-2.4 遍历DataSet的内容
---------------------------------------
-
- .. code-block::
-
- for instance in dataset:
- # do something
-
-2.5 一些其它操作
---------------------------------------
-
- .. code-block::
-
- # 检查是否存在名为'a'的field
- dataset.has_field('a') # 或 ('a' in dataset)
- # 将名为'a'的field改名为'b'
- dataset.rename_field('a', 'b')
- # DataSet的长度
- len(dataset)
-
---------------------------------------
-3.DataSet与自然语言处理(NLP)
---------------------------------------
-
-在目前深度学习的模型中,大都依赖于随机梯度下降法(SGD)进行模型的优化。随机梯度下降需要将数据切分成一个个的 batch,
-一个batch进行一次前向计算(forward)与梯度后向传播(backward)。在自然语言处理的场景下,往往还需要对数据进行pad。这是
-由于句子的长度一般是不同的,但是一次batch中的每个field都必须是一个tensor,所以需要将所有句子都补齐到相同的长度。
-
-3.1 DataSet与DataSetIter
---------------------------------------
-
- 我们先看fastNLP中如何将数据分成一个一个的batch的例子, 这里我们使用随机生成的数据来模拟一个二分类文本分类任务,
- words和characters是输入,labels是文本类别
-
- .. code-block::
-
- from fastNLP import DataSet
- from fastNLP import DataSetIter
- from fastNLP import SequentialSampler
- from fastNLP import EngChar2DPadder
-
- num_instances = 100
- # 假设每句话最少2个词,最多5个词; 词表的大小是100个; 一共26个字母,每个单词最短1个字母,最长5个字母
- lengths = [random.randint(2, 5) for _ in range(num_instances)]
- data = {'words': [[random.randint(1, 100) for _ in range(lengths[idx]) ] for idx in range(num_instances)],
- 'chars': [
- [[random.randint(1, 27) for _ in range(random.randint(1, 5))]
- for _ in range(lengths[idx])]
- for idx in range(num_instances)],
- 'label': [random.randint(0, 1) for _ in range(num_instances)]}
-
- d = DataSet(data)
- d.set_padder('chars', EngChar2DPadder()) # 因为英文character的pad方式与word的pad方式不一样
-
- d.set_target('label')
- d.set_input('words', 'chars')
-
- for batch_x, batch_y in DataSetIter(d, sampler=SequentialSampler(), batch_size=2):
- print("batch_x:", batch_x)
- print("batch_y:", batch_y)
- break
- # 输出为
- # {'words': tensor([[49, 27, 20, 36, 63],
- # [53, 82, 23, 11, 0]]), 'chars': tensor([[[13, 3, 14, 25, 1],
- # [ 8, 20, 12, 0, 0],
- # [27, 8, 0, 0, 0],
- # [ 1, 15, 26, 0, 0],
- # [11, 24, 17, 0, 0]],
- #
- # [[ 6, 14, 11, 27, 22],
- # [18, 6, 4, 19, 0],
- # [19, 22, 9, 0, 0],
- # [10, 25, 0, 0, 0],
- # [ 0, 0, 0, 0, 0]]])}
- # {'label': tensor([0, 0])}
-
- 其中 :class:`~fastNLP.DataSetIter` 是用于从DataSet中按照batch_size为大小取出batch的迭代器,
- :class:`~fastNLP.SequentialSampler` 用于指示 :class:`~fastNLP.DataSetIter` 以怎样的
- 顺序从DataSet中取出instance以组成一个batch,
- 更详细的说明请参照 :class:`~fastNLP.DataSetIter` 和 :class:`~fastNLP.SequentialSampler` 文档。
-
- 通过 ``DataSet.set_input('words', 'chars')`` , fastNLP将认为 `words` 和 `chars` 这两个field都是input,并将它们都放入迭代器
- 生成的第一个dict中; ``DataSet.set_target('labels')`` , fastNLP将认为 `labels` 这个field是target,并将其放入到迭代器的第
- 二个dict中。如上例中所打印结果。分为input和target的原因是由于它们在被 :class:`~fastNLP.Trainer` 所使用时会有所差异,
- 详见 :class:`~fastNLP.Trainer`
-
- 当把某个field设置为 `target` 或者 `input` 的时候(两者不是互斥的,可以同时设为两种),fastNLP不仅仅只是将其放
- 置到不同的dict中,而还会对被设置为 `input` 或 `target` 的 field 进行类型检查。类型检查的目的是为了看能否把该 field 转为
- pytorch的 :class:`torch.LongTensor` 或 :class:`torch.FloatTensor` 类型
- (也可以在 :class:`~fastNLP.DataSetIter` 中设置输出numpy类型,参考 :class:`~fastNLP.DataSetIter` )。
-
- 如上例所示,fastNLP已将 `words` ,`chars` 和 `label` 转为了 :class:`Tensor` 类型。
- 如果 field 在每个 `instance` 都拥有相同的维度(不能超过两维),且最内层的元素都为相同的 type(int, float, np.int*, np.float*),
- 则fastNLP默认将对该 field 进行pad。也支持全为str的field作为target和input,这种情况下,fastNLP默认不进行pad。
- 另外,当某个 field 已经被设置为了 target 或者 input 后,之后 `append` 的
- `instance` 对应的 field 必须要和前面已有的内容一致,否则会报错。
-
- 可以查看field的dtype::
-
- from fastNLP import DataSet
-
- d = DataSet({'a': [0, 1, 3], 'b':[[1.0, 2.0], [0.1, 0.2], [3]]})
- d.set_input('a', 'b')
- d.a.dtype
- >> numpy.int64
- d.b.dtype
- >> numpy.float64
- # 默认情况下'a'这个field将被转换为torch.LongTensor,但如果需要其为torch.FloatTensor可以手动修改dtype
- d.a.dtype = float # 请确保该field的确可以全部转换为float。
-
- 如果某个field中出现了多种类型混合(比如一部分为str,一部分为int)的情况,fastNLP无法判断该field的类型,会报如下的
- 错误::
-
- from fastNLP import DataSet
-
- d = DataSet({'data': [1, 'a']})
- d.set_input('data')
- >> RuntimeError: Mixed data types in Field data: [, ]
-
- 可以通过设置以忽略对该field进行类型检查::
-
- from fastNLP import DataSet
- d = DataSet({'data': [1, 'a']})
- d.set_ignore_type('data')
- d.set_input('data')
-
- 当某个field被设置为忽略type之后,fastNLP将不对其进行pad。
-
-3.2 DataSet与pad
---------------------------------------
-
- 在fastNLP里,pad是与一个field绑定的。即不同的field可以使用不同的pad方式,比如在英文任务中word需要的pad和
- character的pad方式往往是不同的。fastNLP是通过一个叫做 :class:`~fastNLP.Padder` 的子类来完成的。
- 默认情况下,所有field使用 :class:`~fastNLP.AutoPadder`
- 。可以通过使用以下方式设置Padder(如果将padder设置为None,则该field不会进行pad操作)。
- 大多数情况下直接使用 :class:`~fastNLP.AutoPadder` 就可以了。
- 如果 :class:`~fastNLP.AutoPadder` 或 :class:`~fastNLP.EngChar2DPadder` 无法满足需求,
- 也可以自己写一个 :class:`~fastNLP.Padder` 。
-
- .. code-block::
-
- from fastNLP import DataSet
- from fastNLP import EngChar2DPadder
- import random
- dataset = DataSet()
- max_chars, max_words, sent_num = 5, 10, 20
- contents = [[
- [random.randint(1, 27) for _ in range(random.randint(1, max_chars))]
- for _ in range(random.randint(1, max_words))
- ] for _ in range(sent_num)]
- # 初始化时传入
- dataset.add_field('chars', contents, padder=EngChar2DPadder())
- # 直接设置
- dataset.set_padder('chars', EngChar2DPadder())
- # 也可以设置pad的value
- dataset.set_pad_val('chars', -1)
-
-3.3 根据DataSet中多个field合成新的field
-------------------------------------------------------------
-
- DataSet支持在进行batch时,默认只能看到当前的field的值,但在某些训练中可能存在以下的情况: (1)需要两个field拼接成为一个field;
- (2)需要在batch中进行负采样。这时候就需要能够同时利用多个field进行batch的操作,DataSet中的add_collate_fn()函数支持添加
- 自定义涉及多个field的collate_fn函数。例如下例中将两个field拼接成一个field的场景
-
- .. code-block::
-
- from fastNLP import DataSet, DataSetIter
- import torch
-
- data = DataSet({
- 'x1': [[0, 1],
- [2]],
- 'x2': [[3],
- [2, 4, 5]],
- 'y': [0, 1]
- })
- data.set_target('y')
-
- # 所有的collate_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中
- # 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field.
- def concat_collate_fn(ins_list):
- x1 = [ins['x1'] for ind,ins in ins_list]
- x2 = [ins['x2'] for ind,ins in ins_list]
- xs = []
- for i in range(len(ins_list)):
- xs.append(torch.LongTensor(x1[i] + x2[i]))
- # 需要自行pad并转换为tensor,但不需要移动到gpu
- arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
- b_x = {'x': arr}
- b_y = {}
- # 返回值一定是两个dict,第一个dict的值会认为是input,第二个dict的值会认为是target. 若名称与已有input或target重复,则
- # 采用返回值。
- return b_x, b_y
-
- data.add_collate_fn(concat_collate_fn)
-
- for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
- print("batch_x:", batch_x)
- print("batch_y:", batch_y)
- # batch_x: {'x': tensor([[0, 1, 3, 0],
- # [2, 2, 4, 5]])}
- # batch_y: {'y': array([0, 1])}
-
- # 如果取batch过程含有一些参数,可以通过类来实现
- class ConCollateFn:
- def __init__(self, max_len=3):
- self.max_len = max_len
-
- def __call__(self, ins_list): # 实现该类的__call__函数
- x1 = [ins['x1'] for ind, ins in ins_list]
- x2 = [ins['x2'] for ind, ins in ins_list]
- xs = []
- for i in range(len(ins_list)):
- xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len])
- arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
- b_x = {'x': arr}
- b_y = {}
- return b_x, b_y
- data.delete_collate_fn() # 删除之前的collate_fn
- data.add_collate_fn(ConCollateFn(max_len=3))
- for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
- print("batch_x:", batch_x)
- print("batch_y:", batch_y)
- # batch_x: {'x': tensor([[0, 1, 3],
- # [2, 2, 4]])}
- # batch_y: {'y': array([0, 1])}
-
-"""
-__all__ = [
- "DataSet",
-]
-
-import _pickle as pickle
-from copy import deepcopy
-
-import numpy as np
-from prettytable import PrettyTable
-
-from ._logger import logger
-from .const import Const
-from .field import AppendToTargetOrInputException
-from .field import AutoPadder
-from .field import FieldArray
-from .field import SetInputOrTargetException
-from .instance import Instance
-from .utils import pretty_table_printer
-from .collate_fn import Collater
-try:
- from tqdm.auto import tqdm
-except:
- from .utils import _pseudo_tqdm as tqdm
-
-
-class ApplyResultException(Exception):
- def __init__(self, msg, index=None):
- super().__init__(msg)
- self.msg = msg
- self.index = index # 标示在哪个数据遭遇到问题了
-
-class DataSet(object):
- r"""
- fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset`
- """
-
- def __init__(self, data=None):
- r"""
-
- :param data: 如果为dict类型,则每个key的value应该为等长的list; 如果为list,
- 每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。
- """
- self.field_arrays = {}
- if data is not None:
- if isinstance(data, dict):
- length_set = set()
- for key, value in data.items():
- length_set.add(len(value))
- assert len(length_set) == 1, "Arrays must all be same length."
- for key, value in data.items():
- self.add_field(field_name=key, fields=value)
- elif isinstance(data, list):
- for ins in data:
- assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins))
- self.append(ins)
-
- else:
- raise ValueError("data only be dict or list type.")
- self._collater = Collater()
-
- @property
- def collater(self):
- if self._collater is None:
- self._collater = Collater()
- return self._collater
-
- @collater.setter
- def collater(self, value):
- assert isinstance(value, Collater)
- self._collater = value
-
- def __contains__(self, item):
- return item in self.field_arrays
-
- def __iter__(self):
- def iter_func():
- for idx in range(len(self)):
- yield self[idx]
-
- return iter_func()
-
- def _inner_iter(self):
- class Iter_ptr:
- def __init__(self, dataset, idx):
- self.dataset = dataset
- self.idx = idx
-
- def __getitem__(self, item):
- assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[
- self.idx])
- assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
- return self.dataset.field_arrays[item][self.idx]
-
- def __setitem__(self, key, value):
- raise TypeError("You cannot modify value directly.")
-
- def items(self):
- ins = self.dataset[self.idx]
- return ins.items()
-
- def __repr__(self):
- return self.dataset[self.idx].__repr__()
-
- def inner_iter_func():
- for idx in range(len(self)):
- yield Iter_ptr(self, idx)
-
- return inner_iter_func()
-
- def __getitem__(self, idx):
- r"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。
-
- :param idx: can be int or slice.
- :return: If `idx` is int, return an Instance object.
- If `idx` is slice, return a DataSet object.
- """
- if isinstance(idx, int):
- return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays})
- elif isinstance(idx, slice):
- if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)):
- raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}")
- data_set = DataSet()
- for field_name, field in self.field_arrays.items():
- data_set.add_field(field_name=field_name, fields=field.content[idx], padder=field.padder,
- is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type)
- data_set.collater = self.collater.copy_from(self.collater)
- return data_set
- elif isinstance(idx, str):
- if idx not in self:
- raise KeyError("No such field called {} in DataSet.".format(idx))
- return self.field_arrays[idx]
- elif isinstance(idx, list):
- dataset = DataSet()
- for i in idx:
- assert isinstance(i, int), "Only int index allowed."
- instance = self[i]
- dataset.append(instance)
- for field_name, field in self.field_arrays.items():
- dataset.field_arrays[field_name].to(field)
- dataset.collater = self.collater.copy_from(self.collater)
- return dataset
- else:
- raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
-
- def __getattr__(self, item):
- # Not tested. Don't use !!
- if item == "field_arrays":
- raise AttributeError
- if isinstance(item, str) and item in self.field_arrays:
- return self.field_arrays[item]
-
- def __setstate__(self, state):
- self.__dict__ = state
-
- def __getstate__(self):
- return self.__dict__
-
- def __len__(self):
- r"""Fetch the length of the dataset.
-
- :return length:
- """
- if len(self.field_arrays) == 0:
- return 0
- field = iter(self.field_arrays.values()).__next__()
- return len(field)
-
- def __repr__(self):
- return str(pretty_table_printer(self))
-
- def print_field_meta(self):
- r"""
- 输出当前field的meta信息, 形似下列的输出::
-
- +-------------+-------+-------+
- | field_names | x | y |
- +=============+=======+=======+
- | is_input | True | False |
- | is_target | False | False |
- | ignore_type | False | |
- | pad_value | 0 | |
- +-------------+-------+-------+
-
- str field_names: DataSet中field的名称
- bool is_input: field是否为input
- bool is_target: field是否为target
- bool ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义
- int pad_value: 该field的pad的值,仅在该field为input或target时有意义
- :return:
- """
- if len(self.field_arrays)>0:
- field_names = ['field_names']
- is_inputs = ['is_input']
- is_targets = ['is_target']
- pad_values = ['pad_value']
- ignore_types = ['ignore_type']
-
- for name, field_array in self.field_arrays.items():
- field_names.append(name)
- if field_array.is_input:
- is_inputs.append(True)
- else:
- is_inputs.append(False)
- if field_array.is_target:
- is_targets.append(True)
- else:
- is_targets.append(False)
-
- if (field_array.is_input or field_array.is_target) and field_array.padder is not None:
- pad_values.append(field_array.padder.get_pad_val())
- else:
- pad_values.append(' ')
-
- if field_array._ignore_type:
- ignore_types.append(True)
- elif field_array.is_input or field_array.is_target:
- ignore_types.append(False)
- else:
- ignore_types.append(' ')
- table = PrettyTable(field_names=field_names)
- fields = [is_inputs, is_targets, ignore_types, pad_values]
- for field in fields:
- table.add_row(field)
- logger.info(table)
- return table
-
- def append(self, instance):
- r"""
- 将一个instance对象append到DataSet后面。
-
- :param ~fastNLP.Instance instance: 若DataSet不为空,则instance应该拥有和DataSet完全一样的field。
-
- """
- if len(self.field_arrays) == 0:
- # DataSet has no field yet
- for name, field in instance.fields.items():
- # field = field.tolist() if isinstance(field, np.ndarray) else field
- self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来
- else:
- if len(self.field_arrays) != len(instance.fields):
- raise ValueError(
- "DataSet object has {} fields, but attempt to append an Instance object with {} fields."
- .format(len(self.field_arrays), len(instance.fields)))
- for name, field in instance.fields.items():
- assert name in self.field_arrays
- try:
- self.field_arrays[name].append(field)
- except AppendToTargetOrInputException as e:
- logger.error(f"Cannot append to field:{name}.")
- raise e
-
- def add_fieldarray(self, field_name, fieldarray):
- r"""
- 将fieldarray添加到DataSet中.
-
- :param str field_name: 新加入的field的名称
- :param ~fastNLP.core.FieldArray fieldarray: 需要加入DataSet的field的内容
- :return:
- """
- if not isinstance(fieldarray, FieldArray):
- raise TypeError("Only fastNLP.FieldArray supported.")
- if len(self) != len(fieldarray):
- raise RuntimeError(f"The field to add must have the same size as dataset. "
- f"Dataset size {len(self)} != field size {len(fieldarray)}")
- fieldarray.name = field_name
- self.field_arrays[field_name] = fieldarray
-
- def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False):
- r"""
- 新增一个field
-
- :param str field_name: 新增的field的名称
- :param list fields: 需要新增的field的内容
- :param None,~fastNLP.Padder padder: 如果为None,则不进行pad,默认使用 :class:`~fastNLP.AutoPadder` 自动判断是否需要做pad。
- :param bool is_input: 新加入的field是否是input
- :param bool is_target: 新加入的field是否是target
- :param bool ignore_type: 是否忽略对新加入的field的类型检查
- """
-
- if len(self.field_arrays) != 0:
- if len(self) != len(fields):
- raise RuntimeError(f"The field to add must have the same size as dataset. "
- f"Dataset size {len(self)} != field size {len(fields)}")
- self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input,
- padder=padder, ignore_type=ignore_type)
-
- def delete_instance(self, index):
- r"""
- 删除第index个instance
-
- :param int index: 需要删除的instance的index,序号从0开始。
- """
- assert isinstance(index, int), "Only integer supported."
- if len(self) <= index:
- raise IndexError("{} is too large for as DataSet with {} instances.".format(index, len(self)))
- if len(self) == 1:
- self.field_arrays.clear()
- else:
- for field in self.field_arrays.values():
- field.pop(index)
- return self
-
- def delete_field(self, field_name):
- r"""
- 删除名为field_name的field
-
- :param str field_name: 需要删除的field的名称.
- """
- self.field_arrays.pop(field_name)
- return self
-
- def copy_field(self, field_name, new_field_name):
- r"""
- 深度copy名为field_name的field到new_field_name
-
- :param str field_name: 需要copy的field。
- :param str new_field_name: copy生成的field名称
- :return: self
- """
- if not self.has_field(field_name):
- raise KeyError(f"Field:{field_name} not found in DataSet.")
- fieldarray = deepcopy(self.get_field(field_name))
- fieldarray.name = new_field_name
- self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray)
- return self
-
- def has_field(self, field_name):
- r"""
- 判断DataSet中是否有名为field_name这个field
-
- :param str field_name: field的名称
- :return bool: 表示是否有名为field_name这个field
- """
- if isinstance(field_name, str):
- return field_name in self.field_arrays
- return False
-
- def get_field(self, field_name):
- r"""
- 获取field_name这个field
-
- :param str field_name: field的名称
- :return: :class:`~fastNLP.FieldArray`
- """
- if field_name not in self.field_arrays:
- raise KeyError("Field name {} not found in DataSet".format(field_name))
- return self.field_arrays[field_name]
-
- def get_all_fields(self):
- r"""
- 返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray`
-
- :return dict: 返回如上所述的字典
- """
- return self.field_arrays
-
- def get_field_names(self) -> list:
- r"""
- 返回一个list,包含所有 field 的名字
-
- :return list: 返回如上所述的列表
- """
- return sorted(self.field_arrays.keys())
-
- def get_length(self):
- r"""
- 获取DataSet的元素数量
-
- :return: int: DataSet中Instance的个数。
- """
- return len(self)
-
- def rename_field(self, field_name, new_field_name):
- r"""
- 将某个field重新命名.
-
- :param str field_name: 原来的field名称。
- :param str new_field_name: 修改为new_name。
- """
- if field_name in self.field_arrays:
- self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
- self.field_arrays[new_field_name].name = new_field_name
- else:
- raise KeyError("DataSet has no field named {}.".format(field_name))
- return self
-
- def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True):
- r"""
- 将field_names的field设置为target
-
- Example::
-
- dataset.set_target('labels', 'seq_len') # 将labels和seq_len这两个field的target属性设置为True
- dataset.set_target('labels', 'seq_lens', flag=False) # 将labels和seq_len的target属性设置为False
-
- :param str field_names: field的名称
- :param bool flag: 将field_name的target状态设置为flag
- :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一
- 行的数据进行类型和维度推断本列的数据的类型和维度。
- """
- assert isinstance(flag, bool), "Only bool type supported."
- for name in field_names:
- if name in self.field_arrays:
- try:
- self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
- self.field_arrays[name].is_target = flag
- except SetInputOrTargetException as e:
- logger.error(f"Cannot set field:{name} as target.")
- raise e
- else:
- raise KeyError("{} is not a valid field name.".format(name))
- return self
-
- def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True):
- r"""
- 将field_names的field设置为input::
-
- dataset.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True
- dataset.set_input('words', flag=False) # 将words这个field的input属性设置为False
-
- :param str field_names: field的名称
- :param bool flag: 将field_name的input状态设置为flag
- :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一
- 行的数据进行类型和维度推断本列的数据的类型和维度。
- """
- for name in field_names:
- if name in self.field_arrays:
- try:
- self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
- self.field_arrays[name].is_input = flag
- except SetInputOrTargetException as e:
- logger.error(f"Cannot set field:{name} as input, exception happens at the {e.index} value.")
- raise e
- else:
- raise KeyError("{} is not a valid field name.".format(name))
- return self
-
- def set_ignore_type(self, *field_names, flag=True):
- r"""
- 将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查,
- 默认情况下也不进行pad。如果仍需要pad该field,可通过自定义Padder实现,若该field需要转换为tensor,需要在padder
- 中转换,但不需要在padder中移动到gpu。
-
- :param str field_names: field的名称
- :param bool flag: 将field_name的ignore_type状态设置为flag
- :return:
- """
- assert isinstance(flag, bool), "Only bool type supported."
- for name in field_names:
- if name in self.field_arrays:
- self.field_arrays[name].ignore_type = flag
- else:
- raise KeyError("{} is not a valid field name.".format(name))
- return self
-
- def set_padder(self, field_name, padder):
- r"""
- 为field_name设置padder::
-
- from fastNLP import EngChar2DPadder
- padder = EngChar2DPadder()
- dataset.set_padder('chars', padder) # 则chars这个field会使用EngChar2DPadder进行pad操作
-
- :param str field_name: 设置field的padding方式为padder
- :param None,~fastNLP.Padder padder: 设置为None即删除padder, 即对该field不进行pad操作。
- """
- if field_name not in self.field_arrays:
- raise KeyError("There is no field named {}.".format(field_name))
- self.field_arrays[field_name].set_padder(padder)
- return self
-
- def set_pad_val(self, field_name, pad_val):
- r"""
- 为某个field设置对应的pad_val.
-
- :param str field_name: 修改该field的pad_val
- :param int pad_val: 该field的padder会以pad_val作为padding index
- """
- if field_name not in self.field_arrays:
- raise KeyError("There is no field named {}.".format(field_name))
- self.field_arrays[field_name].set_pad_val(pad_val)
- return self
-
- def get_input_name(self):
- r"""
- 返回所有is_input被设置为True的field名称
-
- :return list: 里面的元素为被设置为input的field名称
- """
- return [name for name, field in self.field_arrays.items() if field.is_input]
-
- def get_target_name(self):
- r"""
- 返回所有is_target被设置为True的field名称
-
- :return list: 里面的元素为被设置为target的field名称
- """
- return [name for name, field in self.field_arrays.items() if field.is_target]
-
- def apply_field(self, func, field_name, new_field_name=None, **kwargs):
- r"""
- 将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。
-
- :param callable func: input是instance中名为 `field_name` 的field的内容。
- :param str field_name: 传入func的是哪个field。
- :param None,str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
- 盖之前的field。如果为None则不创建新的field。
- :param optional kwargs: 支持输入is_input,is_target,ignore_type
-
- 1. is_input: bool, 如果为True则将名为 `new_field_name` 的field设置为input
-
- 2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
-
- 3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
-
- 4. use_tqdm: bool, 是否使用tqdm显示预处理进度
-
- 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
-
- :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
- """
- assert len(self) != 0, "Null DataSet cannot use apply_field()."
- if not self.has_field(field_name=field_name):
- raise KeyError("DataSet has no field named `{}`.".format(field_name))
- return self.apply(func, new_field_name, _apply_field=field_name, **kwargs)
-
- def apply_field_more(self, func, field_name, modify_fields=True, **kwargs):
- r"""
- 将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。
- func 可以返回一个或多个 field 上的结果。
-
- .. note::
- ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
- ``apply`` 区别的介绍。
-
- :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
- :param str field_name: 传入func的是哪个field。
- :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
- :param optional kwargs: 支持输入is_input,is_target,ignore_type
-
- 1. is_input: bool, 如果为True则将被修改的field设置为input
-
- 2. is_target: bool, 如果为True则将被修改的field设置为target
-
- 3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
-
- 4. use_tqdm: bool, 是否使用tqdm显示预处理进度
-
- 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
-
- :return Dict[str:Field]: 返回一个字典
- """
- assert len(self) != 0, "Null DataSet cannot use apply_field()."
- if not self.has_field(field_name=field_name):
- raise KeyError("DataSet has no field named `{}`.".format(field_name))
- return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)
-
- def _add_apply_field(self, results, new_field_name, kwargs):
- r"""
- 将results作为加入到新的field中,field名称为new_field_name
-
- :param List[str] results: 一般是apply*()之后的结果
- :param str new_field_name: 新加入的field的名称
- :param dict kwargs: 用户apply*()时传入的自定义参数
- :return:
- """
- extra_param = {}
- if 'is_input' in kwargs:
- extra_param['is_input'] = kwargs['is_input']
- if 'is_target' in kwargs:
- extra_param['is_target'] = kwargs['is_target']
- if 'ignore_type' in kwargs:
- extra_param['ignore_type'] = kwargs['ignore_type']
- if new_field_name in self.field_arrays:
- # overwrite the field, keep same attributes
- old_field = self.field_arrays[new_field_name]
- if 'is_input' not in extra_param:
- extra_param['is_input'] = old_field.is_input
- if 'is_target' not in extra_param:
- extra_param['is_target'] = old_field.is_target
- if 'ignore_type' not in extra_param:
- extra_param['ignore_type'] = old_field.ignore_type
- self.add_field(field_name=new_field_name, fields=results, is_input=extra_param["is_input"],
- is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type'],
- padder=self.get_field(new_field_name).padder)
- else:
- self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None),
- is_target=extra_param.get("is_target", None),
- ignore_type=extra_param.get("ignore_type", False))
-
- def apply_more(self, func, modify_fields=True, **kwargs):
- r"""
- 将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。
-
- .. note::
- ``apply_more`` 与 ``apply`` 的区别:
-
- 1. ``apply_more`` 可以返回多个 field 的结果, ``apply`` 只可以返回一个field 的结果;
-
- 2. ``apply_more`` 的返回值是一个字典,每个 key-value 对中的 key 表示 field 的名字,value 表示计算结果;
-
- 3. ``apply_more`` 默认修改 ``DataSet`` 中的 field ,``apply`` 默认不修改。
-
- :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
- :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
- :param optional kwargs: 支持输入is_input,is_target,ignore_type
-
- 1. is_input: bool, 如果为True则将被修改的的field设置为input
-
- 2. is_target: bool, 如果为True则将被修改的的field设置为target
-
- 3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
-
- 4. use_tqdm: bool, 是否使用tqdm显示预处理进度
-
- 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
-
- :return Dict[str:Field]: 返回一个字典
- """
- # 返回 dict , 检查是否一直相同
- assert callable(func), "The func you provide is not callable."
- assert len(self) != 0, "Null DataSet cannot use apply()."
- idx = -1
- try:
- results = {}
- for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True,
- desc=kwargs.get('tqdm_desc', ''),
- leave=False, disable=not kwargs.get('use_tqdm', False)):
- if "_apply_field" in kwargs:
- res = func(ins[kwargs["_apply_field"]])
- else:
- res = func(ins)
- if not isinstance(res, dict):
- raise ApplyResultException("The result of func is not a dict", idx)
- if idx == 0:
- for key, value in res.items():
- results[key] = [value]
- else:
- for key, value in res.items():
- if key not in results:
- raise ApplyResultException("apply results have different fields", idx)
- results[key].append(value)
- if len(res) != len(results):
- raise ApplyResultException("apply results have different fields", idx)
- except Exception as e:
- if idx != -1:
- if isinstance(e, ApplyResultException):
- logger.error(e.msg)
- logger.error("Exception happens at the `{}`th instance.".format(idx))
- raise e
-
- if modify_fields is True:
- for field, result in results.items():
- self._add_apply_field(result, field, kwargs)
-
- return results
-
- def apply(self, func, new_field_name=None, **kwargs):
- r"""
- 将DataSet中每个instance传入到func中,并获取它的返回值.
-
- :param callable func: 参数是 ``DataSet`` 中的 ``Instance``
- :param None,str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
- 盖之前的field。如果为None则不创建新的field。
- :param optional kwargs: 支持输入is_input,is_target,ignore_type
-
- 1. is_input: bool, 如果为True则将 `new_field_name` 的field设置为input
-
- 2. is_target: bool, 如果为True则将 `new_field_name` 的field设置为target
-
- 3. ignore_type: bool, 如果为True则将 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
-
- 4. use_tqdm: bool, 是否使用tqdm显示预处理进度
-
- 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
-
- :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
- """
- assert callable(func), "The func you provide is not callable."
- assert len(self) != 0, "Null DataSet cannot use apply()."
- idx = -1
- try:
- results = []
- for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True, leave=False,
- desc=kwargs.get('tqdm_desc', ''),
- disable=not kwargs.get('use_tqdm', False)):
- if "_apply_field" in kwargs:
- results.append(func(ins[kwargs["_apply_field"]]))
- else:
- results.append(func(ins))
- except BaseException as e:
- if idx != -1:
- logger.error("Exception happens at the `{}`th instance.".format(idx))
- raise e
-
- if new_field_name is not None:
- self._add_apply_field(results, new_field_name, kwargs)
-
- return results
-
- def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN):
- r"""
- 将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入seq_len这个field。
-
- :param field_name: str.
- :param new_field_name: str. 新的field_name
- :return:
- """
- if self.has_field(field_name=field_name):
- self.apply_field(len, field_name, new_field_name=new_field_name)
- else:
- raise KeyError(f"Field:{field_name} not found.")
- return self
-
- def drop(self, func, inplace=True):
- r"""
- func接受一个Instance,返回bool值。返回值为True时,该Instance会被移除或者不会包含在返回的DataSet中。
-
- :param callable func: 接受一个Instance作为参数,返回bool值。为True时删除该instance
- :param bool inplace: 是否在当前DataSet中直接删除instance;如果为False,将返回一个新的DataSet。
-
- :return: DataSet
- """
- if inplace:
- results = [ins for ins in self._inner_iter() if not func(ins)]
- for name, old_field in self.field_arrays.items():
- self.field_arrays[name].content = [ins[name] for ins in results]
- return self
- else:
- results = [ins for ins in self if not func(ins)]
- if len(results) != 0:
- dataset = DataSet(results)
- for field_name, field in self.field_arrays.items():
- dataset.field_arrays[field_name].to(field)
- return dataset
- else:
- return DataSet()
-
- def split(self, ratio, shuffle=True):
- r"""
- 将DataSet按照ratio的比例拆分,返回两个DataSet
-
- :param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `(1-ratio)` 这么多数据,第二个DataSet拥有 `ratio` 这么多数据
- :param bool shuffle: 在split前是否shuffle一下
- :return: [ :class:`~fastNLP.DataSet` , :class:`~fastNLP.DataSet` ]
- """
- assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.'
- assert isinstance(ratio, float)
- assert 0 < ratio < 1
- all_indices = [_ for _ in range(len(self))]
- if shuffle:
- np.random.shuffle(all_indices)
- split = int(ratio * len(self))
- if split == 0:
- error_msg = f'Dev DataSet has {split} instance after split.'
- logger.error(error_msg)
- raise IndexError(error_msg)
- dev_indices = all_indices[:split]
- train_indices = all_indices[split:]
- dev_set = DataSet()
- train_set = DataSet()
- for idx in dev_indices:
- dev_set.append(self[idx])
- for idx in train_indices:
- train_set.append(self[idx])
- for field_name in self.field_arrays:
- train_set.field_arrays[field_name].to(self.field_arrays[field_name])
- dev_set.field_arrays[field_name].to(self.field_arrays[field_name])
-
- train_set.collater.copy_from(self.collater)
- dev_set.collater.copy_from(self.collater)
- return train_set, dev_set
-
- def save(self, path):
- r"""
- 保存DataSet.
-
- :param str path: 将DataSet存在哪个路径
- """
- with open(path, 'wb') as f:
- pickle.dump(self, f)
-
- @staticmethod
- def load(path):
- r"""
- 从保存的DataSet pickle文件的路径中读取DataSet
-
- :param str path: 从哪里读取DataSet
- :return: 读取后的 :class:`~fastNLP.DataSet`。
- """
- with open(path, 'rb') as f:
- d = pickle.load(f)
- assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
- return d
-
- def add_collate_fn(self, fn, name=None):
- r"""
- 添加 CollateFn,collate_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的
- 这个)。支持依次添加多个collate_fn, 如果相同的key,后面的collate_fn的结果覆盖前面的collate_fn的结果。
-
- :param callable fn: 传入一个可调用的function, 该function可接受的参数为List[(ind1, instance1), (ind2, instance2)]
- (某个batch被选中的所有的indice以及instance),其中ind1/ind2是该instance在dataset中的index,instance1/instance2是
- 这次batch取出来的数据,包含了所有的field。返回值需要为两个dict,第一个dict的值将被认为是input,第二个dict的值被认为是
- target,返回的值至多允许一个空dict。若返回的dict中包含了被设置为input或target的field的名称,将覆盖dataset中的field。
- fastNLP不会将collate_fn的返回结果pad和转换为tensor,需要在collate_fn中完成pad和转换为tensor(不需要将tensor移动到
- gpu中,fastNLP会自动将其移动到特定gpu)。不要修改传入collate_fn中的数据,否则可能导致未知问题。
- :param str,int name: collate_fn的名称,如果不传入,默认使用自增长的数字作为key。相同的name会覆盖之前的collate_fn。
- """
- assert callable(fn), "You must pass in a callable object."
- self.collater.add_fn(fn, name=name)
-
- def delete_collate_fn(self, name=None):
- r"""
- 删除某个collate_fn
-
- :param str,int name: 如果为None,则删除最近加入的collate_fn
- :return:
- """
- self.collater.delete_fn(name)
-
- def _collate_batch(self, ins_list):
- return self.collater.collate_batch(ins_list)
-
- def concat(self, dataset, inplace=True, field_mapping=None):
- """
- 将当前dataset与输入的dataset结合成一个更大的dataset,需要保证两个dataset都包含了相同的field。结合后的dataset的input,target
- 以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有
- 当前dataset含有field,则会报错。
-
- :param DataSet, dataset: 需要和当前dataset concat的dataset
- :param bool, inplace: 是否直接将dataset组合到当前dataset中
- :param dict, field_mapping: 当dataset中的field名称和当前dataset不一致时,需要通过field_mapping把输入的dataset中的field
- 名称映射到当前field. field_mapping为dict类型,key为dataset中的field名称,value是需要映射成的名称
-
- :return: DataSet
- """
- assert isinstance(dataset, DataSet), "Can only concat two datasets."
-
- fns_in_this_dataset = set(self.get_field_names())
- fns_in_other_dataset = dataset.get_field_names()
- reverse_field_mapping = {}
- if field_mapping is not None:
- fns_in_other_dataset = [field_mapping.get(fn, fn) for fn in fns_in_other_dataset]
- reverse_field_mapping = {v:k for k, v in field_mapping.items()}
- fns_in_other_dataset = set(fns_in_other_dataset)
- fn_not_seen = list(fns_in_this_dataset - fns_in_other_dataset)
-
- if fn_not_seen:
- raise RuntimeError(f"The following fields are not provided in the dataset:{fn_not_seen}")
-
- if inplace:
- ds = self
- else:
- ds = deepcopy(self)
-
- for fn in fns_in_this_dataset:
- ds.get_field(fn).content.extend(deepcopy(dataset.get_field(reverse_field_mapping.get(fn, fn)).content))
-
- return ds
diff --git a/fastNLP/core/dataset/__init__.py b/fastNLP/core/dataset/__init__.py
new file mode 100644
index 00000000..7a4dd8ed
--- /dev/null
+++ b/fastNLP/core/dataset/__init__.py
@@ -0,0 +1,10 @@
+__all__ = [
+ 'DataSet',
+ 'FieldArray',
+ 'Instance',
+ 'ApplyResultException'
+]
+
+from .dataset import DataSet, ApplyResultException
+from .field import FieldArray
+from .instance import Instance
diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py
new file mode 100644
index 00000000..7b0621af
--- /dev/null
+++ b/fastNLP/core/dataset/dataset.py
@@ -0,0 +1,1059 @@
+r"""
+:class:`~fastNLP.core.dataset.DataSet` 是 fastNLP 中用于承载数据的容器。可以将 DataSet 看做是一个表格,
+每一行是一个 sample (在 fastNLP 中被称为 :mod:`~fastNLP.core.dataset.instance` ),
+每一列是一个 feature (在 fastNLP 中称为 :mod:`~fastNLP.core.dataset.field` )。
+
+.. csv-table:: Following is a demo layout of DataSet
+ :header: "sentence", "words", "seq_len"
+
+ "This is the first instance .", "[This, is, the, first, instance, .]", 6
+ "Second instance .", "[Second, instance, .]", 3
+ "Third instance .", "[Third, instance, .]", 3
+ "...", "[...]", "..."
+
+在 fastNLP 内部每一行是一个 :class:`~fastNLP.core.dataset.Instance` 对象; 每一列是一个 :class:`~fastNLP.core.dataset.FieldArray` 对象。
+
+----------------------------
+1.DataSet的创建
+----------------------------
+
+创建DataSet主要有以下的3种方式
+
+1.1 传入dict
+----------------------------
+
+ .. code-block::
+
+ from fastNLP import DataSet
+ data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."],
+ 'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.']],
+ 'seq_len': [6, 3, 3]}
+ dataset = DataSet(data)
+ # 传入的 dict 的每个 key 的 value 应该为具有相同长度的 list
+
+1.2 通过 Instance 构建
+----------------------------
+
+ .. code-block::
+
+ from fastNLP import DataSet
+ from fastNLP import Instance
+ dataset = DataSet()
+ instance = Instance(sentence="This is the first instance",
+ words=['this', 'is', 'the', 'first', 'instance', '.'],
+ seq_len=6)
+ dataset.append(instance)
+ # 可以继续 append 更多内容,但是 append 的 instance 应该和第一个 instance 拥有完全相同的 field
+
+1.3 通过 List[Instance] 构建
+--------------------------------------
+
+ .. code-block::
+
+ from fastNLP import DataSet
+ from fastNLP import Instance
+ instances = []
+ instances.append(Instance(sentence="This is the first instance",
+ words=['this', 'is', 'the', 'first', 'instance', '.'],
+ seq_len=6))
+ instances.append(Instance(sentence="Second instance .",
+ words=['Second', 'instance', '.'],
+ seq_len=3))
+ dataset = DataSet(instances)
+
+--------------------------------------
+2.DataSet 与预处理
+--------------------------------------
+
+常见的预处理有如下几种:
+
+2.1 从某个文本文件读取内容
+--------------------------------------
+
+ .. code-block::
+
+ from fastNLP import DataSet
+ from fastNLP import Instance
+ dataset = DataSet()
+ filepath = 'some/text/file'
+ # 假设文件中每行内容如下(sentence label):
+ # This is a fantastic day positive
+ # The bad weather negative
+ # .....
+ with open(filepath, 'r') as f:
+ for line in f:
+ sent, label = line.strip().split('\t')
+ dataset.append(Instance(sentence=sent, label=label))
+
+
+2.2 对 DataSet 中的内容处理
+--------------------------------------
+
+ .. code-block::
+
+ from fastNLP import DataSet
+ data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."]}
+ dataset = DataSet(data)
+ # 将句子分成单词形式, 详见DataSet.apply()方法, 可以开启多进程来加快处理, 也可以更改展示的bar,目前支持 ``['rich', 'tqdm', None]``,
+ # 详细内容可以见 :class:`~fastNLP.core.dataset.DataSet`, 需要注意的时匿名函数不支持多进程
+ dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words',
+ progress_des='Main',progress_bar='rich')
+ # 或使用DataSet.apply_field()
+ dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words',
+ progress_des='Main',progress_bar='rich')
+ # 除了匿名函数,也可以定义函数传递进去
+ def get_words(instance):
+ sentence = instance['sentence']
+ words = sentence.split()
+ return words
+ dataset.apply(get_words, new_field_name='words', num_proc=2, progress_des='Main',progress_bar='rich')
+
+2.3 删除DataSet的内容
+--------------------------------------
+
+ .. code-block::
+
+ from fastNLP import DataSet
+ dataset = DataSet({'a': list(range(-5, 5))})
+ # 返回满足条件的 instance,并放入 DataSet 中
+ dropped_dataset = dataset.drop(lambda ins:ins['a']<0, inplace=False)
+ # 在 dataset 中删除满足条件的 instance
+ dataset.drop(lambda ins:ins['a']<0) # dataset 的 instance数量减少
+ # 删除第 3 个 instance
+ dataset.delete_instance(2)
+ # 删除名为 'a' 的 field
+ dataset.delete_field('a')
+
+
+2.4 遍历DataSet的内容
+--------------------------------------
+
+ .. code-block::
+
+ for instance in dataset:
+ # do something
+
+2.5 一些其它操作
+--------------------------------------
+
+ .. code-block::
+
+ # 检查是否存在名为 'a' 的 field
+ dataset.has_field('a') # 或 ('a' in dataset)
+ # 将名为 'a' 的 field 改名为 'b'
+ dataset.rename_field('a', 'b')
+ # DataSet 的长度
+ len(dataset)
+
+"""
+
+__all__ = [
+ "DataSet",
+ "ApplyResultException"
+]
+
+import _pickle as pickle
+from copy import deepcopy
+from typing import Optional, List, Callable, Union, Dict, Any, Mapping
+from types import LambdaType
+import sys
+import time
+
+import numpy as np
+
+from .field import FieldArray
+from .instance import Instance
+from fastNLP.core.utils.utils import pretty_table_printer, deprecated
+from fastNLP.core.collators import Collator
+from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress
+from fastNLP.core.utils.tqdm_progress import f_tqdm_progress
+from ..log import logger
+from fastNLP.core.utils.dummy_class import DummyClass
+from ..utils.utils import _get_fun_msg
+
+
+progress_bars = {
+ 'rich': f_rich_progress,
+ 'tqdm': f_tqdm_progress
+}
+
+
+class ApplyResultException(Exception):
+ def __init__(self, msg, index=None):
+ super().__init__(msg)
+ self.msg = msg
+ self.index = index # 标示在哪个数据遭遇到问题了
+
+
+def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, progress_bar: str = 'rich',
+ desc: str = None) -> list:
+ """
+ 对数据集进行处理封装函数,以便多进程使用
+
+ :param ds: 实现了 __getitem__() 和 __len__() 的对象
+ :param _apply_field: 需要处理数据集的 field_name
+ :param func: 用户自定义的 func
+ :param desc: 进度条的描述字符
+ :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
+ :return:
+ """
+ progress_bar = progress_bars.get(progress_bar, DummyFRichProgress())
+ desc = desc if desc else "Processing"
+ task_id = progress_bar.add_task(description=desc, total=len(ds))
+ results = []
+ idx = -1
+
+ try:
+ for idx, ins in enumerate(ds):
+ if _apply_field is not None:
+ results.append(func(ins[_apply_field]))
+ else:
+ results.append(func(ins))
+ progress_bar.update(task_id, advance=1)
+
+ except BaseException as e:
+ if idx != -1:
+ logger.error("Exception happens at the `{}`th instance.".format(idx))
+ raise e
+ finally:
+ progress_bar.destroy_task(task_id)
+ return results
+
+
+def _multi_proc(ds, _apply_field, func, counter, queue):
+ """
+ 对数据集进行处理封装函数,以便多进程使用
+
+ :param ds: 实现了 __getitem__() 和 __len__() 的对象
+ :param _apply_field: 需要处理数据集的 field_name
+ :param func: 用户自定义的 func
+ :param counter: 计数器
+ :param queue: 多进程时,将结果输入到这个 queue 中
+ :return:
+ """
+ idx = -1
+ import contextlib
+ null = DummyClass()
+ with contextlib.redirect_stdout(null): # 避免打印触发 rich 的锁
+ logger.set_stdout(stdout='raw')
+ results = []
+ try:
+ for idx, ins in enumerate(ds):
+ if _apply_field is not None:
+ res = func(ins[_apply_field])
+ else:
+ res = func(ins)
+ results.append(res)
+ with counter.get_lock():
+ counter.value += 1
+ except BaseException as e:
+ if idx != -1:
+ logger.error("Exception happens at the `{}`th instance.".format(idx))
+ raise e
+ queue.put(pickle.dumps(results))
+
+
+class DataSet:
+ r"""
+ fastNLP的数据容器。
+
+ Example::
+
+ from fastNLP.core.dataset import DataSet, Instance
+ data = {'x': [[1, 0, 1], [0, 1, 1]], 'y': [0, 1]}
+ data1 = [Instance(x=[1,0,1],y=0), Instance(x=[0,1,1],y=1)]
+ ds = DataSet(data)
+ ds = DataSet(data1)
+
+ fastNLP的 DataSet 是 key-value 存储形式, 目前支持两种初始化方式,输入 data 分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和
+ ``Dict[str, List[Any]]``。
+
+ * 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。
+ Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。
+ * 当 data 为 ``Dict[str, List[Any]]`` 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。
+
+ :param data: 初始化的内容,其只能为两种类型,分别为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 和
+ ``Dict[str, List[Any]]``。
+
+ * 当 data 为 ``List[:class:`~fastNLP.core.dataset.Instance`]`` 时, 每个 ``Instance`` 的 field_name 需要保持一致。
+ Instance 详见 :class:`~fastNLP.core.dataset.Instance` 。
+ * 当 data 为 ``Dict[str, List[Any]] 时, 则每个 key 的 value 应该为等长的 list, 否则不同 field 的长度不一致。
+ """
+ def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None):
+ self.field_arrays = {}
+ self._collator = Collator()
+ if data is not None:
+ if isinstance(data, Dict):
+ length_set = {}
+ for key, value in data.items():
+ length_set[key] = len(value)
+ assert len(set(length_set.values())) == 1, f"Fields must all be of same length, instead of {length_set}."
+ for key, value in data.items():
+ self.add_field(field_name=key, fields=value)
+ elif isinstance(data, List):
+ for ins in data:
+ assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins))
+ self.append(ins)
+ else:
+ raise ValueError("data only be dict or list type.")
+
+ def __contains__(self, item):
+ return item in self.field_arrays
+
+ def __iter__(self):
+ for idx in range(len(self)):
+ yield self[idx]
+
+ def _inner_iter(self):
+ class Iter_ptr:
+ def __init__(self, dataset, idx):
+ self.dataset = dataset
+ self.idx = idx
+
+ def __getitem__(self, item):
+ assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[
+ self.idx])
+ assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
+ return self.dataset.field_arrays[item][self.idx]
+
+ def __setitem__(self, key, value):
+ raise TypeError("You cannot modify value directly.")
+
+ def items(self):
+ ins = self.dataset[self.idx]
+ return ins.items()
+
+ def __repr__(self):
+ return self.dataset[self.idx].__repr__()
+
+ def inner_iter_func():
+ for idx in range(len(self)):
+ yield Iter_ptr(self, idx)
+
+ return inner_iter_func()
+
+ def __getitem__(self, idx: Union[int, slice, str, list]):
+ r"""
+ 去 DataSet 的内容, 根据 idx 类型不同有不同的返回值。 包括四种类型 ``[int, slice, str, list]``
+
+ * 当 idx 为 ``int`` 时, idx 的值不能超过 ``DataSet`` 的长度, 会返回一个 ``Instance``, 详见
+ :class:`~fastNLP.core.dataset.Instance`
+ * 当 idx 为 ``slice`` 时, 会根据 slice 的内容创建一个新的 DataSet,其包含 slice 所有内容并返回。
+ * 当 idx 为 ``str`` 时, 该 idx 为 DataSet 的 field_name, 其会返回该 field_name 的所有内容, 为 list 类型。
+ * 当 idx 为 ``list`` 时, 该 idx 的 list 内全为 int 数字, 其会取出所有内容组成一个新的 DataSet 返回。
+
+ Example::
+
+ from fastNLP.core.dataset import DataSet
+
+ ds = DataSet({'x': [[1, 0, 1], [0, 1, 1]] * 100, 'y': [0, 1] * 100})
+ ins = ds[0]
+ sub_ds = ds[0:100]
+ sub_ds= ds[[1, 0, 3, 2, 1, 4]]
+ field = ds['x']
+
+ :param idx: 用户传入参数
+ :return:
+ """
+ if isinstance(idx, int):
+ return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays})
+ elif isinstance(idx, slice):
+ if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)):
+ raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}")
+ dataset = DataSet()
+ for field_name, field in self.field_arrays.items():
+ dataset.add_field(field_name=field_name, fields=field.content[idx])
+ dataset._collator = deepcopy(self.collator)
+ return dataset
+ elif isinstance(idx, str):
+ if idx not in self:
+ raise KeyError("No such field called {} in DataSet.".format(idx))
+ return self.field_arrays[idx]
+ elif isinstance(idx, list):
+ dataset = DataSet()
+ for i in idx:
+ assert isinstance(i, int), "Only int index allowed."
+ instance = self[i]
+ dataset.append(instance)
+ dataset._collator = deepcopy(self.collator)
+ return dataset
+ else:
+ raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
+
+ def __setitem__(self, key, value):
+ assert isinstance(key, int) and key < len(self), "Index out of range."
+ for name, field in self.field_arrays.items():
+ field[key] = value[name]
+
+ def append(self, instance: Instance) -> None:
+ r"""
+ 将一个 ``instance`` 对象 append 到 DataSet 后面。详见 :class:`~fastNLP.core.dataset.Instance`
+
+ :param instance: 若 DataSet 不为空,则 instance 应该拥有和 DataSet 完全一样的 field;
+ """
+ if len(self.field_arrays) == 0:
+ # DataSet has no field yet
+ for name, field in instance.items():
+ # field = field.tolist() if isinstance(field, np.ndarray) else field
+ self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来
+ else:
+ if len(self.field_arrays) != len(instance.fields):
+ raise ValueError(
+ "DataSet object has {} fields, but attempt to append an Instance object with {} fields."
+ .format(len(self.field_arrays), len(instance.fields)))
+ for name, field in instance.items():
+ assert name in self.field_arrays, f'Field:`{name}` is not found in {self.field_arrays.keys()}'
+ try:
+ self.field_arrays[name].append(field)
+ except Exception as e:
+ logger.error(f"Cannot append to field:{name}.")
+ raise e
+
+ def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None:
+ r"""
+ 将 ``fieldarray`` 添加到 DataSet 中.
+
+ :param field_name: 新加入的 field 的名称;
+ :param fieldarray: 需要加入 DataSet 的 field 的内容, 详见 :class:`~fastNLP.core.dataset.FieldArray` ;
+ :return:
+ """
+ if not isinstance(fieldarray, FieldArray):
+ raise TypeError("Only fastNLP.FieldArray supported.")
+ if len(self) != len(fieldarray):
+ raise RuntimeError(f"The field to add must have the same size as dataset. "
+ f"Dataset size {len(self)} != field size {len(fieldarray)}")
+ fieldarray.name = field_name
+ self.field_arrays[field_name] = fieldarray
+
+ def add_field(self, field_name: str, fields: list) -> None:
+ r"""
+ 新增一个 field, 需要注意的是 fields 的长度跟 DataSet 长度一致
+
+ :param field_name: 新增的 field 的名称;
+ :param fields: 需要新增的 field 的内容;
+ """
+
+ if len(self.field_arrays) != 0:
+ if len(self) != len(fields):
+ raise RuntimeError(f"The field to add must have the same size as dataset. "
+ f"Dataset size {len(self)} != field size {len(fields)}")
+ self.field_arrays[field_name] = FieldArray(field_name, fields)
+
+ def delete_instance(self, index: int):
+ r"""
+ 删除第 ``index`` 个 Instance
+
+ :param index: 需要删除的 instance 的 index,序号从 `0` 开始。
+ """
+ assert isinstance(index, int), "Only integer supported."
+ if len(self) <= index:
+ raise IndexError("{} is too large for as DataSet with {} instances.".format(index, len(self)))
+ if len(self) == 1:
+ self.field_arrays.clear()
+ else:
+ for field in self.field_arrays.values():
+ field.pop(index)
+ return self
+
+ def delete_field(self, field_name: str):
+ r"""
+ 删除名为 ``field_name`` 的 field
+
+ :param field_name: 需要删除的 field 的名称;
+ """
+ if self.has_field(field_name):
+ self.field_arrays.pop(field_name)
+ else:
+ raise KeyError(f"Field:{field_name} not found in DataSet.")
+ return self
+
+ def copy_field(self, field_name: str, new_field_name: str):
+ r"""
+ 深度 copy 名为 ``field_name`` 的 field 到 ``new_field_name``
+
+ :param field_name: 需要 copy 的 field;
+ :param new_field_name: copy 生成的 field 名称;
+ :return: 数据集自身;
+ """
+ if not self.has_field(field_name):
+ raise KeyError(f"Field:{field_name} not found in DataSet.")
+ fieldarray = deepcopy(self.get_field(field_name))
+ fieldarray.name = new_field_name
+ self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray)
+ return self
+
+ def has_field(self, field_name: str) -> bool:
+ r"""
+ 判断 DataSet 中是否有名为 ``field_name`` 这个 field
+
+ :param field_name: field 的名称;
+ :return: 表示是否有名为 ``field_name`` 这个 field;
+ """
+ if isinstance(field_name, str):
+ return field_name in self.field_arrays
+ return False
+
+ def get_field(self, field_name: str) -> FieldArray:
+ r"""
+ 获取名为 ``field_name`` 的 field
+
+ :param field_name: field 的名称;
+ :return: 一个 :class:`~fastNLP.core.dataset.FieldArray` 对象;
+ """
+ if field_name not in self.field_arrays:
+ raise KeyError("Field name {} not found in DataSet".format(field_name))
+ return self.field_arrays[field_name]
+
+ def get_all_fields(self) -> dict:
+ r"""
+ :return: 一个 dict,key 为 field_name, value为对应的 :class:`~fastNLP.core.dataset.FieldArray` 对象。
+ """
+ return self.field_arrays
+
+ def get_field_names(self) -> list:
+ r"""
+ :return: 一个 list,包含所有 field 的名字
+ """
+ return sorted(self.field_arrays.keys())
+
+ def get_length(self) -> int:
+ r"""
+ 获取 DataSet 的元素数量
+
+ :return: DataSet 中 Instance 的个数。
+ """
+ return len(self)
+
+ def rename_field(self, field_name: str, new_field_name: str):
+ r"""
+ 将某个 field 重新命名.
+
+ :param field_name: 原来的 field 名称;
+ :param new_field_name: 修改为 new_name;
+ """
+ if field_name in self.field_arrays:
+ self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
+ self.field_arrays[new_field_name].name = new_field_name
+ else:
+ raise KeyError("DataSet has no field named {}.".format(field_name))
+ return self
+
+ def apply_field(self, func: Callable, field_name: str = None,
+ new_field_name: str = None, num_proc: int = 0,
+ progress_desc: str = None, progress_bar: str = 'rich'):
+ r"""
+ 将 :class:`DataSet` 每个 ``instance`` 中为 ``field_name`` 的 field 传给函数 ``func``,并写入到 ``new_field_name``
+ 中。
+
+ :param func: 对指定 field 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容,返回值将被
+ 写入至 ``new_field_name`` 中。
+ :param field_name: 传入 ``func`` 的 field 名称;
+ :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对
+ 应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ;
+ :param num_proc: 使用进程的数量。
+
+ .. note::
+
+ 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
+ ``func`` 函数中的打印将不会输出。
+
+ :param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称;
+ :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。
+ :return: 从函数 ``func`` 中得到的返回值;
+ """
+ assert len(self) != 0, "Null DataSet cannot use apply_field()."
+ if not self.has_field(field_name=field_name):
+ raise KeyError("DataSet has no field named `{}`.".format(field_name))
+
+ try:
+ results = self._apply_process(num_proc=num_proc, func=func, progress_bar=progress_bar,
+ progress_desc=progress_desc, _apply_field=field_name)
+ except BaseException as e:
+ raise e
+
+ if new_field_name is not None:
+ self.add_field(field_name=new_field_name, fields=results)
+ return results
+
+ def apply_field_more(self, func: Callable = None, field_name: str = None,
+ modify_fields: bool = True, num_proc: int = 0,
+ progress_desc: str = None, progress_bar: str = 'rich'):
+ r"""
+ 将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的 field 传给 ``func``,并获取它的返回值。
+ ``func`` 可以返回一个或多个 field 上的结果。
+
+ .. note::
+ ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.core.dataset.DataSet.apply_more` 中关于 ``apply_more`` 与
+ ``apply`` 区别的介绍。
+
+ :param func: 对指定 field 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容;返回值为一个字典,
+ key 是field 的名字,value 是对应的结果
+ :param field_name: 传入 ``func`` 的 field 名称;
+ :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 ``True``
+ :param num_proc: 使用进程的数量。
+
+ .. note::
+
+ 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
+ ``func`` 函数中的打印将不会输出。
+
+ :param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称;
+ :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。
+ :return: 一个字典
+ """
+ assert len(self) != 0, "Null DataSet cannot use apply_field()."
+ if not self.has_field(field_name=field_name):
+ raise KeyError("DataSet has no field named `{}`.".format(field_name))
+ idx = -1
+ results = {}
+ apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc,
+ progress_bar=progress_bar, _apply_field=field_name)
+ # 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。
+ if not isinstance(apply_out[0], Mapping):
+ raise Exception(f"The result of func is not a Mapping, but a {type(apply_out[0])}")
+
+ for key, value in apply_out[0].items():
+ results[key] = [value]
+ # 尝试合并所有dict数据, idx+1 的原因是第一条数据不可能出现错误,默认第一条数据为准
+ try:
+ for idx, per_out in enumerate(apply_out[1:]):
+ if len(set(results.keys()) - set(per_out.keys())):
+ raise ApplyResultException("apply results have different fields", idx + 1)
+ for key, value in per_out.items():
+ results[key].append(value)
+
+ except Exception as e:
+ if idx != -1:
+ logger.error("Exception happens at the `{}`th instance.".format(idx + 1))
+ raise e
+
+ if modify_fields is True:
+ for field, result in results.items():
+ self.add_field(field_name=field, fields=result)
+
+ return results
+
+ def _apply_process(self, num_proc: int = 0, func: Callable = None,
+ progress_bar: str = 'rich', _apply_field: str = None,
+ progress_desc: str = 'Main') -> list:
+ """
+ :param num_proc: 使用进程的数量。
+
+ .. note::
+
+ 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
+ ``func`` 函数中的打印将不会输出。
+
+ :param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance``
+ :param _apply_field: 需要传进去func的数据集的field_name
+ :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
+ :param progress_desc: 进度条的描述字符,默认为 'Main'
+ """
+ if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "<lambda>":
+ raise TypeError("Lambda function does not support multiple processes, please set `num_proc=0`.")
+ if num_proc>1 and sys.platform in ('win32', 'msys', 'cygwin'):
+ raise RuntimeError("Your platform does not support multiprocessing with fork, please set `num_proc=0`")
+
+ if num_proc < 2:
+ results = _apply_single(ds=self, _apply_field=_apply_field, func=func,
+ desc=progress_desc, progress_bar=progress_bar)
+ else:
+ # TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2
+ import multiprocessing as mp
+ ctx = mp.get_context('fork')
+ num_proc = min(num_proc, len(self))
+ # 划分数据集
+ shard_len = len(self) // num_proc
+ num_left_sample = len(self) % num_proc
+ start = 0
+ shard_data = []
+ for _i in range(num_proc):
+ end = shard_len + int(_i < num_left_sample) + start
+ shard_data.append(self[start:end])
+ start = end
+ # 启动多个子进程,并通过共享计数器在主进程中实时刷新进度条
+ counter = mp.Value('i', 0, lock=True)
+ pool = []
+ queues = []
+ results = []
+ for i in range(num_proc):
+ queue = ctx.SimpleQueue()
+ proc = ctx.Process(target=_multi_proc, args=(shard_data[i], _apply_field, func, counter, queue))
+ proc.start()
+ pool.append(proc)
+ queues.append(queue)
+
+ progress_bar = progress_bars.get(progress_bar, DummyFRichProgress())
+ task_id = progress_bar.add_task(description=progress_desc, total=len(self))
+ last_count = -1
+ while counter.value < len(self) or last_count == -1:
+ while counter.value == last_count:
+ time.sleep(0.1)
+ advance = counter.value - last_count
+ last_count = counter.value
+ progress_bar.update(task_id, advance=advance, refresh=True)
+
+ for queue in queues:
+ res = pickle.loads(queue.get())
+ results.extend(res)
+ progress_bar.destroy_task(task_id)
+ for proc in pool:
+ proc.join()
+ return results
+
+ def apply_more(self, func: Callable = None, modify_fields: bool = True,
+ num_proc: int = 0, progress_desc: str = '', progress_bar: str = 'rich'):
+ r"""
+ 将 ``DataSet`` 中每个 ``Instance`` 传入到 ``func`` 中,并获取它的返回值。``func`` 可以返回一个或多个 field 上的结果。
+
+ :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 为 field 的名字,value 为对应的结果
+ :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` ,默认为 ``True``
+ :param num_proc: 使用进程的数量。
+ :param progress_desc: 当显示 progress_bar 时,可显示当前正在处理的名称
+ :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。
+ :return: 一个字典
+ """
+ assert callable(func), "The func is not callable."
+ assert len(self) != 0, "Null DataSet cannot use apply_more()."
+ assert num_proc >= 0, "num_proc must >= 0"
+ idx = -1
+
+ results = {}
+ apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc,
+ progress_bar=progress_bar)
+ # 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。
+ if not isinstance(apply_out[0], Mapping):
+ raise Exception(f"The result of func:{_get_fun_msg(func)} is not a dict, but of type {type(apply_out[0])}")
+
+ for key, value in apply_out[0].items():
+ results[key] = [value]
+ # 尝试合并所有dict数据, idx+1 的原因是第一条数据不可能出现错误,已经将第一条数据取出来
+ try:
+ for idx, per_out in enumerate(apply_out[1:]):
+ if len(set(results.keys()) - set(per_out.keys())):
+ raise ApplyResultException(f"Apply results have different fields:{set(results.keys())} and "
+ f"{set(per_out.keys())}", idx + 1)
+ for key, value in per_out.items():
+ results[key].append(value)
+
+ except Exception as e:
+ if idx != -1:
+ logger.error("Exception happens at the `{}`th instance.".format(idx + 1))
+ raise e
+
+ if modify_fields is True:
+ for field, result in results.items():
+ self.add_field(field_name=field, fields=result)
+
+ return results
+
+ def apply(self, func: Callable = None, new_field_name: str = None,
+ num_proc: int = 0, progress_bar: str = 'rich', progress_desc: str = ''):
+ """
+ 将 ``DataSet`` 中每个 ``Instance`` 传入到 ``func`` 中,并获取它的返回值。``func`` 仅能返回一个结果。
+
+ :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值将被写入 ``new_field_name`` 中。
+ :param new_field_name: 将 ``func`` 返回的内容放入到 ``new_field_name`` 这个 field中 ,如果名称与已有的 field 相同,则覆
+ 盖之前的 field。如果为 ``None`` 则不创建新的 field。
+ :param num_proc: 使用进程的数量。
+
+ .. note::
+
+ 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
+ ``func`` 函数中的打印将不会输出。
+
+ :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。
+ :param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称。
+ """
+ assert callable(func), "The func you provide is not callable."
+ assert len(self) != 0, "Null DataSet cannot use apply()."
+ assert num_proc >= 0, "num_proc must be an integer >= 0."
+ try:
+ results = self._apply_process(num_proc=num_proc, func=func, progress_bar=progress_bar,
+ progress_desc=progress_desc)
+ except BaseException as e:
+ raise e
+
+ if new_field_name is not None:
+ self.add_field(field_name=new_field_name, fields=results)
+
+ return results
+
+ def add_seq_len(self, field_name: str, new_field_name='seq_len'):
+ r"""
+ 将使用 :func:`len` 直接对 ``field_name`` 中每个元素作用,将其结果作为 sequence length, 并放入 ``new_field_name`` 这个 field。
+
+ :param field_name: 需要处理的 field_name
+ :param new_field_name: 新的 field_name
+ :return:
+ """
+ if self.has_field(field_name=field_name):
+ self.apply_field(len, field_name, new_field_name=new_field_name)
+ else:
+ raise KeyError(f"Field:{field_name} not found.")
+ return self
+
+ def drop(self, func: Callable, inplace=True):
+ r"""
+ 删除某些 Instance。 需要注意的是 ``func`` 接受一个 Instance ,返回 bool 值。返回值为 ``True`` 时,
+ 该 Instance 会被移除或者不会包含在返回的 DataSet 中。
+
+ :param func: 接受一个 Instance 作为参数,返回 bool 值。为 ``True`` 时删除该 instance
+ :param inplace: 是否在当前 DataSet 中直接删除 instance;如果为 False,将返回一个新的 DataSet。
+
+ :return: DataSet
+ """
+ if inplace:
+ results = [ins for ins in self if not func(ins)]
+ for name, old_field in self.field_arrays.items():
+ self.field_arrays[name].content = [ins[name] for ins in results]
+ return self
+ else:
+ results = [ins for ins in self if not func(ins)]
+ if len(results) != 0:
+ dataset = DataSet(results)
+ else:
+ dataset = DataSet()
+ for name in self.field_arrays.keys():
+ empty_field = FieldArray(name, [None])
+ empty_field.content = []
+ dataset.field_arrays[name] = empty_field
+ return dataset
+
+ def split(self, ratio: float, shuffle=True):
+ r"""
+ 将 DataSet 按照 ``ratio`` 的比例拆分,返回两个 DataSet
+
+ :param ratio: 0<ratio<1, 返回的第一个 DataSet 拥有 ``ratio`` 这么多数据,第二个 DataSet 拥有 ``1-ratio`` 这么多数据
+ :param shuffle: 在 split 前是否 shuffle 一下。
+ :return: 拆分后的两个 DataSet。
+ """
+ assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.'
+ assert isinstance(ratio, float)
+ assert 0 < ratio < 1
+ all_indices = [_ for _ in range(len(self))]
+ if shuffle:
+ np.random.shuffle(all_indices)
+ split = int(ratio * len(self))
+ if split == 0:
+ error_msg = f'Dev DataSet has `{split}` instance after split.'
+ raise IndexError(error_msg)
+ dev_indices = all_indices[:split]
+ train_indices = all_indices[split:]
+ dev_set = DataSet()
+ train_set = DataSet()
+ for idx in dev_indices:
+ dev_set.append(self[idx])
+ for idx in train_indices:
+ train_set.append(self[idx])
+ dev_set._collator = deepcopy(self.collator)
+ train_set._collator = deepcopy(self.collator)
+
+ return dev_set, train_set
+
+    def save(self, path: str) -> None:
+        r"""
+        Serialize this DataSet to ``path`` with :mod:`pickle`.
+
+        :param path: destination file path
+        """
+        with open(path, 'wb') as f:
+            pickle.dump(self, f)
+
+    @staticmethod
+    def load(path: str):
+        r"""
+        Read back a DataSet previously written by :meth:`save`.
+
+        NOTE(review): this unpickles arbitrary data — only load files from a
+        trusted source, since ``pickle`` can execute code during deserialization.
+
+        :param path: path of the pickled DataSet file
+        :return: the restored DataSet
+        """
+        with open(path, 'rb') as f:
+            d = pickle.load(f)
+        assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
+        return d
+
+    def concat(self, dataset: 'DataSet', inplace: bool = True, field_mapping: Dict = None) -> 'DataSet':
+        """
+        Concatenate this DataSet with ``dataset`` into one larger dataset. Both must
+        contain the same fields; the field names and the collator of the result
+        follow the current DataSet. Fields present only in ``dataset`` are ignored;
+        a field of this DataSet missing from ``dataset`` raises an error.
+
+        :param dataset: the DataSet to append onto this one
+        :param inplace: if ``True``, extend this DataSet in place; otherwise work on
+            a deep copy and return that copy
+        :param field_mapping: when the two datasets use different field names, a
+            dict mapping field names of ``dataset`` (keys) to the corresponding
+            names in this DataSet (values)
+        :raises RuntimeError: if ``dataset`` lacks any of this DataSet's fields
+        :return: :class:`~fastNLP.core.dataset.DataSet`
+        """
+        assert isinstance(dataset, DataSet), "Can only concat two datasets."
+
+        fns_in_this_dataset = set(self.get_field_names())
+        fns_in_other_dataset = dataset.get_field_names()
+        reverse_field_mapping = {}
+        if field_mapping is not None:
+            # Rename the other dataset's fields into this dataset's namespace, and
+            # keep the inverse mapping for the copy loop below.
+            fns_in_other_dataset = [field_mapping.get(fn, fn) for fn in fns_in_other_dataset]
+            reverse_field_mapping = {v: k for k, v in field_mapping.items()}
+        fns_in_other_dataset = set(fns_in_other_dataset)
+        fn_not_seen = list(fns_in_this_dataset - fns_in_other_dataset)
+
+        if fn_not_seen:
+            raise RuntimeError(f"The following fields are not provided in the dataset:{fn_not_seen}")
+
+        if inplace:
+            ds = self
+        else:
+            ds = deepcopy(self)
+
+        for fn in fns_in_this_dataset:
+            ds.get_field(fn).content.extend(deepcopy(dataset.get_field(reverse_field_mapping.get(fn, fn)).content))
+
+        return ds
+
+    @classmethod
+    def from_pandas(cls, df):
+        """
+        Build a DataSet from a :class:`pandas.DataFrame` (one field per column).
+
+        :param df: DataFrame read with pandas
+        :return: a new DataSet (an instance of ``cls``)
+        """
+        df_dict = df.to_dict(orient='list')
+        return cls(df_dict)
+
+    def to_pandas(self):
+        """
+        Convert this DataSet into a :class:`pandas.DataFrame`, one column per field.
+
+        :return: the resulting DataFrame
+        """
+        import pandas as pd
+        dict_ = {key: value.content for key, value in self.field_arrays.items()}
+        return pd.DataFrame.from_dict(dict_)
+
+    def to_csv(self, path: str):
+        """
+        Save this DataSet as a UTF-8 encoded csv file.
+
+        :param path: destination path
+        :return: the return value of :meth:`pandas.DataFrame.to_csv`
+        """
+
+        df = self.to_pandas()
+        return df.to_csv(path, encoding="utf-8")
+
+    @property
+    def collator(self) -> Collator:
+        # Lazily create the Collator on first access.
+        if self._collator is None:
+            self._collator = Collator()
+        return self._collator
+
+ def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+ pad_fn: Callable = None) -> Collator:
+ """
+ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
+
+ :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的
+ field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``;
+ 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。
+ 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。
+ :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
+ field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``,
+ 该值无意义。
+ :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。
+ :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`,
+ :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。
+ 若 ``pad_val`` 为 ``None`` ,该值无意义 。
+ :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的
+ batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。
+ :return: 自身的 collator;
+ """
+ if isinstance(self.collator, Collator):
+ self.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+ return self.collator
+ else:
+ raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
+
+ def set_ignore(self, *field_names) -> Collator:
+ """
+ ``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class:`~fastNLP.core.collators.Collator`
+ 时该函数才有效。调用该函数可以设置忽略输出某些 field 的内容,被设置的 field 将在 batch 的输出中被忽略::
+
+ dataset.set_ignore('field1', 'field2')
+
+ :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的
+ field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``;
+ 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。
+ :return: 自身的 collator;
+ """
+ if isinstance(self.collator, Collator):
+ self.collator.set_ignore(*field_names)
+ return self.collator
+ else:
+ raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
+
+ @classmethod
+ def from_datasets(cls, dataset):
+ """
+ 将 Huggingface Dataset 转为 fastNLP 的 DataSet
+
+ :param dataset: 实例化好的 huggingface Dataset 对象
+ """
+ from datasets import Dataset
+ if not isinstance(dataset, Dataset):
+ raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!")
+
+ data_dict = dataset.to_dict()
+ return DataSet(data_dict)
\ No newline at end of file
diff --git a/fastNLP/core/dataset/field.py b/fastNLP/core/dataset/field.py
new file mode 100644
index 00000000..a272c1ee
--- /dev/null
+++ b/fastNLP/core/dataset/field.py
@@ -0,0 +1,246 @@
+r"""
+.. todo::
+ doc
+"""
+__all__ = [
+ 'FieldArray'
+]
+
+from collections import Counter
+from typing import Any, Union, List, Callable
+from ..log import logger
+
+import numpy as np
+
+
+class FieldArray:
+ """
+ :class:`~fastNLP.core.dataset.DatSet` 中用于表示列的数据类型。
+
+ :param name: 字符串的名称
+ :param content: 任意类型的数据
+ """
+
+ def __init__(self, name: str, content):
+ if len(content) == 0:
+ raise RuntimeError("Empty fieldarray is not allowed.")
+ _content = content
+ try:
+ _content = list(_content)
+ except BaseException as e:
+ logger.error(f"Cannot convert content(of type:{type(content)}) into list.")
+ raise e
+ self.name = name
+ self.content = _content
+
+ def append(self, val: Any) -> None:
+ r"""
+ :param val: 把该 ``val`` 添加到 fieldarray 中。
+ """
+ self.content.append(val)
+
+ def pop(self, index: int) -> None:
+ r"""
+ 删除该 field 中 ``index`` 处的元素
+
+ :param index: 从 ``0`` 开始的数据下标。
+ """
+ self.content.pop(index)
+
+ def __iter__(self):
+ for idx in range(len(self)):
+ yield self[idx]
+
+ def __getitem__(self, indices: Union[int, List[int]]):
+ return self.get(indices)
+
+ def __setitem__(self, idx: int, val: Any):
+ assert isinstance(idx, int)
+ self.content[idx] = val
+
+ def get(self, indices: Union[int, List[int]]):
+ r"""
+ 根据给定的 ``indices`` 返回内容。
+
+ :param indices: 获取 ``indices`` 对应的内容。
+ :return: 根据给定的 ``indices`` 返回的内容,可能是单个值 或 :class:`numpy.ndarray`
+ """
+ if isinstance(indices, int):
+ if indices == -1:
+ indices = len(self) - 1
+ assert 0 <= indices < len(self)
+ return self.content[indices]
+ try:
+ contents = [self.content[i] for i in indices]
+ except BaseException as e:
+ raise e
+ return np.array(contents)
+
+ def __len__(self):
+ r"""
+ 返回长度
+
+ :return:
+ """
+ return len(self.content)
+
+ def split(self, sep: str = None, inplace: bool = True):
+ r"""
+ 依次对自身的元素使用 ``.split()`` 方法,应该只有当本 field 的元素为 :class:`str` 时,该方法才有用。
+
+ :param sep: 分割符,如果为 ``None`` 则直接调用 ``str.split()``。
+ :param inplace: 如果为 ``True``,则将新生成值替换本 field。否则返回 :class:`list`。
+ :return: List[List[str]] or self
+ """
+ new_contents = []
+ for index, cell in enumerate(self.content):
+ try:
+ new_contents.append(cell.split(sep))
+ except Exception as e:
+ logger.error(f"Exception happens when process value in index {index}.")
+ raise e
+ return self._after_process(new_contents, inplace=inplace)
+
+ def int(self, inplace: bool = True):
+ r"""
+ 将本 field 中的值调用 ``int(cell)``. 支持 field 中内容为以下两种情况:
+
+ * ['1', '2', ...](即 field 中每个值为 :class:`str` 的),
+ * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。)
+
+ :param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。
+ :return: List[int], List[List[int]], self
+ """
+ new_contents = []
+ for index, cell in enumerate(self.content):
+ try:
+ if isinstance(cell, list):
+ new_contents.append([int(value) for value in cell])
+ else:
+ new_contents.append(int(cell))
+ except Exception as e:
+ print(f"Exception happens when process value in index {index}.")
+ raise e
+ return self._after_process(new_contents, inplace=inplace)
+
+ def float(self, inplace=True):
+ r"""
+ 将本 field 中的值调用 ``float(cell)``. 支持 field 中内容为以下两种情况:
+
+ * ['1', '2', ...](即 field 中每个值为 :class:`str` 的),
+ * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。)
+
+ :param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。
+ :return:
+ """
+ new_contents = []
+ for index, cell in enumerate(self.content):
+ try:
+ if isinstance(cell, list):
+ new_contents.append([float(value) for value in cell])
+ else:
+ new_contents.append(float(cell))
+ except Exception as e:
+ print(f"Exception happens when process value in index {index}.")
+ raise e
+ return self._after_process(new_contents, inplace=inplace)
+
+ def bool(self, inplace=True):
+ r"""
+ 将本field中的值调用 ``bool(cell)``. 支持 field 中内容为以下两种情况
+
+ * ['1', '2', ...](即 field 中每个值为 :class:`str` 的),
+ * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。)
+
+ :param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。
+ :return:
+ """
+ new_contents = []
+ for index, cell in enumerate(self.content):
+ try:
+ if isinstance(cell, list):
+ new_contents.append([bool(value) for value in cell])
+ else:
+ new_contents.append(bool(cell))
+ except Exception as e:
+ print(f"Exception happens when process value in index {index}.")
+ raise e
+
+ return self._after_process(new_contents, inplace=inplace)
+
+ def lower(self, inplace=True):
+ r"""
+ 将本 field 中的值调用 ``cell.lower()``, 支持 field 中内容为以下两种情况
+
+ * ['1', '2', ...](即 field 中每个值为 :class:`str` 的),
+ * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。)
+
+ :param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。
+ :return: List[int], List[List[int]], self
+ """
+ new_contents = []
+ for index, cell in enumerate(self.content):
+ try:
+ if isinstance(cell, list):
+ new_contents.append([value.lower() for value in cell])
+ else:
+ new_contents.append(cell.lower())
+ except Exception as e:
+ print(f"Exception happens when process value in index {index}.")
+ raise e
+ return self._after_process(new_contents, inplace=inplace)
+
+ def upper(self, inplace=True):
+ r"""
+ 将本 field 中的值调用 ``cell.upper()``, 支持 field 中内容为以下两种情况
+
+ * ['1', '2', ...](即 field 中每个值为 :class:`str` 的),
+ * [['1', '2', ..], ['3', ..], ...](即 field 中每个值为一个 :class:`list` ,:class:`list` 中的值会被依次转换。)
+
+ :param inplace: 如果为 ``True``,则将新生成值替换本 field,并返回当前 field 。否则返回 :class:`list`。
+ :return: List[int], List[List[int]], self
+ """
+ new_contents = []
+ for index, cell in enumerate(self.content):
+ try:
+ if isinstance(cell, list):
+ new_contents.append([value.upper() for value in cell])
+ else:
+ new_contents.append(cell.upper())
+ except Exception as e:
+ print(f"Exception happens when process value in index {index}.")
+ raise e
+ return self._after_process(new_contents, inplace=inplace)
+
+ def value_count(self) -> Counter:
+ r"""
+ 返回该 field 下不同 value 的数量。多用于统计 label 数量
+
+ :return: 计数结果,key 是 label,value 是出现次数
+ """
+ count = Counter()
+
+ def cum(cells):
+ if isinstance(cells, Callable) and not isinstance(cells, str):
+ for cell_ in cells:
+ cum(cell_)
+ else:
+ count[cells] += 1
+
+ for cell in self.content:
+ cum(cell)
+ return count
+
+ def _after_process(self, new_contents: list, inplace: bool):
+ r"""
+ 当调用处理函数之后,决定是否要替换 field。
+
+ :param new_contents:
+ :param inplace:
+ :return: self或者生成的content
+ """
+ if inplace:
+ self.content = new_contents
+ return self
+ else:
+ return new_contents
diff --git a/fastNLP/core/dataset/instance.py b/fastNLP/core/dataset/instance.py
new file mode 100644
index 00000000..b721472e
--- /dev/null
+++ b/fastNLP/core/dataset/instance.py
@@ -0,0 +1,79 @@
+r"""
+instance 模块实现了 Instance 类,即在 fastNLP 中 sample 对应的类型。一个 sample 可以认为是一个 Instance 类型的对象。
+便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset.dataset` 。
+"""
+
+__all__ = [
+ "Instance"
+]
+
+from typing import Mapping
+from fastNLP.core.utils.utils import pretty_table_printer
+
+
+class Instance(Mapping):
+    r"""
+    Instance is the class that represents one sample in fastNLP; every sample in a
+    :class:`~fastNLP.DataSet` is an Instance object. It is normally used together
+    with :class:`~fastNLP.DataSet` and constructed like::
+
+        >>> instance = Instance(input="this is a demo sentence", label='good')
+
+    """
+
+    def __init__(self, **fields):
+
+        self.fields = fields
+
+    def add_field(self, field_name: str, field: any):
+        r"""
+        Add a field to this Instance.
+
+        :param field_name: name of the new field
+        :param field: content of the new field
+        """
+        self.fields[field_name] = field
+
+    def items(self):
+        r"""
+        Return an iterable of ``(field_name, field_value)`` pairs.
+
+        :return: an iterable view
+        """
+        return self.fields.items()
+
+    def keys(self):
+        r"""
+        Return an iterable of the field names.
+
+        :return: an iterable view
+        """
+        return self.fields.keys()
+
+    def values(self):
+        r"""
+        Return an iterable of the field values.
+
+        :return: an iterable view
+        """
+        return self.fields.values()
+
+    def __contains__(self, item):
+        return item in self.fields
+
+    def __getitem__(self, name):
+        if name in self.fields:
+            return self.fields[name]
+        else:
+            raise KeyError("{} not found".format(name))
+
+    def __setitem__(self, name, field):
+        # Delegates to add_field (which returns None, matching dict semantics).
+        return self.add_field(name, field)
+
+    def __repr__(self):
+        # Render as an ASCII table for readable debugging output.
+        return str(pretty_table_printer(self))
+
+    def __len__(self):
+        return len(self.fields)
+
+    def __iter__(self):
+        return iter(self.fields)
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
deleted file mode 100644
index 74ac7028..00000000
--- a/fastNLP/core/dist_trainer.py
+++ /dev/null
@@ -1,521 +0,0 @@
-r"""
-分布式 Trainer
-使用步骤
-1. 在代码中调用 DistTrainer,类似 Trainer,传入模型和数据等等参数
-2. 在命令行中,将 python your_script.py 替换为 python -m torch.distributed.launch --nproc_per_node=N your_script.py
-"""
-import logging
-import os
-import time
-from datetime import datetime
-
-import contextlib
-import torch
-import torch.cuda
-import torch.distributed as dist
-import torch.optim
-from torch.serialization import default_restore_location
-from pkg_resources import parse_version
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data.distributed import DistributedSampler
-from tqdm import tqdm
-import time
-
-from ._logger import logger, init_logger_dist
-from .batch import DataSetIter, BatchIter
-from .callback import DistCallbackManager, CallbackException
-from .callback import _TesterCallback
-from .dataset import DataSet
-from .losses import _prepare_losser
-from .optimizer import Optimizer
-from .utils import _build_args
-from .utils import _build_fp16_env
-from .utils import _get_func_signature
-from .utils import _move_dict_value_to_device
-from .sampler import Sampler
-
-__all__ = [
- 'get_local_rank',
- 'DistTrainer',
-]
-
-def get_local_rank():
- r"""
- 返回当前进程的 local rank, 0 到 N-1 ,N为当前分布式总进程数
- """
- if 'LOCAL_RANK' in os.environ:
- return int(os.environ['LOCAL_RANK'])
- from argparse import ArgumentParser
- parser = ArgumentParser()
- parser.add_argument('--local_rank', type=int)
- args, _ = parser.parse_known_args()
- if 'local_rank' in args and args.local_rank:
- os.environ['LOCAL_RANK'] = str(args.local_rank) # for multiple calls for this function
- return args.local_rank
- raise RuntimeError('Please use "python -m torch.distributed.launch --nproc_per_node=N train_script.py')
-
-
-class DistTrainer:
- r"""
- 分布式的 Trainer,支持分布式训练和混合精度的训练。具体实现原理请阅读 pytorch 官方文档。
-
- Note: 使用分布式 Trainer 时会同时有多个进程执行训练代码。因此将单进程的训练代码改为多进程之前,
- 请仔细检查,确保训练代码中的同步和互斥操作能正确执行(如模型保持,打印日志等)
- """
- def __init__(self, train_data, model, optimizer=None, loss=None,
- callbacks_all=None, callbacks_master=None,
- batch_size_per_gpu=8, n_epochs=1,
- num_workers=1, drop_last=False,
- dev_data=None, metrics=None, metric_key=None,
- update_every=1, print_every=10, validate_every=-1,
- save_path=None, device='auto',
- fp16=False, use_tqdm=True, sampler=None, **kwargs):
- r"""
-
- :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。
- :param nn.modules, DDP model: 待训练的模型
- :param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器
- :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward`
- :param list callbacks_all: 用于在train过程中起调节作用的回调函数,作用于所有训练进程中。
- 可使用的callback参见 :mod:`callback模块 `
- :param list callbacks_master: 用于在train过程中起调节作用的回调函数,只作用于其中一个进程( Master 进程)。
- 可使用的callback参见 :mod:`callback模块 `
- :param int batch_size_per_gpu: 训练时,每个进程的 batch 大小。
- :param int n_epochs: 需要优化迭代多少次。
- :param num_workers: int, 有多少个线程来进行数据pad处理。
- :param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch
- :param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。
- :param metrics: 验证的评估函数。可以只使用一个 :class:`Metric` ,
- 也可以使用多个 :class:`Metric` ,通过列表传入。
- 如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None,
- 则保存当前模型。Metric种类详见 :mod:`metrics模块 ` 。仅在传入dev_data时有效。
- :param str,None metric_key: :class:`Metric` 有时会有多个指标,
- 比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需
- 要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表
- 明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。
- :param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128
- 会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。
- :param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。
- :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。
- :param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存
- 最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
- :param str device: 指定 device,可以是 gpu,cpu 或 auto
- :param bool fp16: 指定是否使用半精度训练。
- :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
- :param Sampler sampler: 使用的sampler,如果不指定,默认使用的DistributedSampler。使用这个参数的情况一般为,明确修改了每个
- rank的Dataset,使得每个rank上的dataset虽然sample数量一样多,但是sample其实不一样。
- :param kwargs: 支持配置可选参数
- bool test_use_tqdm: 在dev上验证的时候是否开启tqdm
- Sampler test_sampler: 在evaluate的时候使用的sampler
- int dev_batch_size: 在evaluate时,使用的evaluate的batch大小
- bool test_use_fp16: test时使用fp16
- bool set_grad_to_none: zero_grad时将grad设为None而不是0
- GradScaler grad_scaler: 自定义的梯度 scaler
- bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快数据速度。一般在tensor较多或tensor维度较大时,有速度增益。
- bool find_unused_parameters: 在将model转化为DistributedDataParallel类型的时候,需要填入该参数,除非model内确实有
- forward没用上的参数,否则应该不需要用到该参数。
- """
- assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
- if device == 'auto':
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- # init distributed
- if device == 'cuda':
- torch.cuda.set_device(get_local_rank())
- self.device = torch.device("cuda", get_local_rank())
- else:
- self.device = torch.device(device)
-
- init_logger_dist()
-
- self.world_size = dist.get_world_size()
- self.rank = dist.get_rank() # unique id for each process
-
- self.train_data = train_data
- self.kwargs = kwargs
- if kwargs.get('batch_size', None):
- batch_size_per_gpu = int(kwargs.get('batch_size'))
- self.batch_size_per_gpu = int(batch_size_per_gpu)
- self.n_epochs = int(n_epochs)
- self.num_data_workers = int(num_workers)
- self.drop_last = drop_last
- self.update_every = int(update_every)
- self.print_every = int(print_every)
- self.validate_every = int(validate_every)
- self.save_path = save_path
- self.losser = _prepare_losser(loss)
- self.fp16 = fp16
- self.local_rank = get_local_rank()
- self.callback_manager = DistCallbackManager(
- env={"trainer": self}, callbacks_all=callbacks_all,
- callbacks_master=callbacks_master)
- self.test_manager = DistCallbackManager(env={'trainer': self})
- self.metric_key = metric_key
- self.use_tqdm = use_tqdm
-
- # init fp16, must before DataParallel init
- autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
- self.auto_cast = autocast
- user_grad_scaler = kwargs.get('grad_scaler', None)
- if user_grad_scaler is not None:
- assert self.fp16, "must set fp16=True to enable grad_scaler"
- grad_scaler = user_grad_scaler
- else:
- grad_scaler = GradScaler()
- self.grad_scaler = grad_scaler
-
- self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
- # init DataParallel
- if isinstance(model, DDP):
- self.ddp_model = model
- else:
- model.to(self.device)
- if parse_version(torch.__version__)>=parse_version('1.1'):
- self.ddp_model = DDP(model, device_ids=[self.local_rank],
- output_device=self.local_rank,
- find_unused_parameters=kwargs.get('find_unused_parameters', False))
- else:
- self.ddp_model = DDP(model, device_ids=[self.local_rank],
- output_device=self.local_rank)
- self.model = self.ddp_model.module
-
- self._forward_func = self.model.forward
- self.model.to(self.device)
-
- optimizer = self._get_optimizer(optimizer)
- self.optimizer = optimizer
- if isinstance(self.train_data, DataSet):
- if sampler is None:
- self.sampler = DistributedSampler(self.train_data)
- else:
- # sampler check
- if sampler is not None and not isinstance(sampler, (Sampler, torch.utils.data.Sampler)):
- raise ValueError(
- f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}")
- elif hasattr(sampler, 'set_batch_size'):
- sampler.set_batch_size(batch_size_per_gpu)
- self.sampler = sampler
- # concerning issue from https://github.com/pytorch/pytorch/issues/57273
- self.pin_memory = kwargs.get('pin_memory', False if parse_version(torch.__version__)==parse_version('1.9') else True)
- self.data_iterator = self._get_data_iter(self.train_data)
- self.batch_size = self.world_size * self.batch_size_per_gpu
- self.n_steps = self._get_n_steps()
-
- self.dev_data = dev_data
- self.metrics = metrics
- self.test_use_tqdm = True
- self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
- dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)
-
- # for evaluation, only run eval on master proc
- if dev_data and metrics:
- cb = _TesterCallback(
- dev_data, self.model, metrics,
- batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
- use_tqdm=self.test_use_tqdm)
- self.test_manager.add_callback([cb], master=True)
- # Setup logging
- # 同步start_time
- sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device)
- dist.broadcast(sync_time, src=0)
- self.start_time = datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f')
- # print('sync_time: {}, start_time: {}'.format(sync_time, self.start_time))
-
- if self.save_path:
- self.cp_save_path = self.save_path
- else:
- self.cp_save_path = None
- # use INFO in the master, WARN for others
- self.logger = logger
- self.logger.info("Setup Distributed Trainer")
- self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
- os.getpid(), self.rank, self.local_rank, self.device, self.fp16))
- self.logger.info("Num of processes: {}".format(self.world_size))
- self.logger.info("Use device: {}".format(device))
-
- def _get_n_steps(self):
- return len(self.data_iterator) * self.n_epochs
-
- def _get_data_iter(self, dataset):
- if isinstance(dataset, DataSet):
- return DataSetIter(dataset=dataset, batch_size=self.batch_size_per_gpu, sampler=self.sampler,
- num_workers=self.num_data_workers, drop_last=self.drop_last,
- pin_memory=self.pin_memory)
- elif isinstance(dataset, BatchIter):
- return dataset
- else:
- raise TypeError("train_data type {} not support".format(type(dataset)))
-
- def _get_optimizer(self, optimizer):
- if isinstance(optimizer, torch.optim.Optimizer):
- return optimizer
- elif isinstance(optimizer, Optimizer):
- return optimizer.construct_from_pytorch(self.ddp_model.parameters())
- elif optimizer is None:
- return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
- else:
- if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
- raise TypeError("optimizer must have a callable step() function.")
- else:
- self.optimizer = optimizer
- @property
- def is_master(self):
- r"""是否是主进程"""
- return self.rank == 0
-
- def train(self, load_best_model=True, on_exception='auto'):
- r"""
- 使用该函数使Trainer开始训练。
-
- :param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。
- 支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出;
- 'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception.
- :return dict: 返回一个字典类型的数据,
- 内含以下内容::
-
- seconds: float, 表示训练时长
- 以下三个内容只有在提供了dev_data的情况下会有。
- best_eval: Dict of Dict, 表示evaluation的结果。第一层的key为Metric的名称,
- 第二层的key为具体的Metric
- best_epoch: int,在第几个epoch取得的最佳值
- best_step: int, 在第几个step(batch)更新取得的最佳值
-
- """
- try:
- self.logger.info("###### Training epochs started ######")
- self.logger.info('Total epochs: %d'% self.n_epochs)
- self.logger.info('Total steps: %d'% self.n_steps)
- self.logger.info('Num instances per GPU: %d'% self.batch_size_per_gpu)
- self.logger.info('Num of steps per update: %d' % self.update_every)
- self.logger.info('Total batch_size: %d'%
- (self.batch_size_per_gpu * dist.get_world_size() * self.update_every))
- self.logger.info('Total num of samples: %d'% len(self.train_data))
- self.logger.info("Num of callbacks for all workers: {}".format(
- len(self.callback_manager.callbacks_all)))
- self.logger.info("Num of callbacks for master workers: {}".format(
- len(self.callback_manager.callbacks_master)))
- self.logger.info("Callbacks for all workers: {}".format(
- [repr(cb) for cb in self.callback_manager.callbacks_all]))
- self.logger.info("Callbacks for master workers: {}".format(
- [repr(cb) for cb in self.callback_manager.callbacks_master]))
-
- start_time = time.time()
- results = {}
- if self.n_epochs <= 0:
- self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
- results['seconds'] = 0.
- return results
-
- try:
- self.callback_manager.on_train_begin()
- self._train()
- self.callback_manager.on_train_end()
-
- except BaseException as e:
- self.callback_manager.on_exception(e)
- if on_exception == 'auto':
- if not isinstance(e, (CallbackException, KeyboardInterrupt)):
- raise e
- else:
- self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
- elif on_exception == 'raise':
- raise e
-
- results['seconds'] = round(time.time() - start_time, 2)
- self.logger.info("###### Train finished ######")
- self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
- if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
- self.load_check_point(self._best_save_name())
- finally:
- pass
- dist.barrier()
- return results
-
- def _train(self):
- dist.barrier()
- if not self.use_tqdm:
- from .utils import _pseudo_tqdm as inner_tqdm
- else:
- inner_tqdm = tqdm
-
- self.step = 0
- self.epoch = 0
- self.pbar = inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
- leave=False, dynamic_ncols=True, disable=not self.is_master)
- pbar = self.pbar
- avg_loss = 0
- data_iterator = self.data_iterator
- self.ddp_model.zero_grad()
- self.batch_per_epoch = self.data_iterator.num_batches
- for epoch in range(1, self.n_epochs + 1):
- self.epoch = epoch
- pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
- # early stopping
- self.callback_manager.on_epoch_begin()
- for batch_x, batch_y in data_iterator:
- self.step += 1
- if self.step%self.update_every!=0:
- no_sync = self.ddp_model.no_sync
- else:
- no_sync = contextlib.ExitStack
- with no_sync():
- self.ddp_model.train()
- _move_dict_value_to_device(batch_x, batch_y, device=self.device, non_blocking=self.pin_memory)
- indices = data_iterator.get_batch_indices()
- # negative sampling; replace unknown; re-weight batch_y
- self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
- with self.auto_cast():
- prediction = self._data_forward(self.ddp_model, batch_x)
- # edit prediction
- self.callback_manager.on_loss_begin(batch_y, prediction)
- loss = self._compute_loss(prediction, batch_y)
-
- avg_loss += loss.detach()
-
- # Is loss NaN or inf? requires_grad = False
- self.callback_manager.on_backward_begin(loss)
- self._grad_backward(loss)
- self.callback_manager.on_backward_end()
- self._update()
- self.callback_manager.on_step_end()
-
- if self.step % self.print_every == 0:
- avg_loss = float(avg_loss) / self.print_every
- print_output = "loss:{:<6.5f}".format(avg_loss)
- pbar.update(self.print_every)
- pbar.set_postfix_str(print_output)
- avg_loss = 0
-
- self.callback_manager.on_batch_end()
-
- if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks):
- self._do_validation()
-
- # ================= mini-batch end ==================== #
- if self.validate_every < 0 and len(self.test_manager.callbacks):
- self._do_validation()
-
- # lr decay; early stopping
- self.callback_manager.on_epoch_end()
- # =============== epochs end =================== #
- pbar.close()
- self.pbar = None
- # ============ tqdm end ============== #
-
- def _clear_grad(self, optimizer):
- if self.set_grad_to_none:
- for group in optimizer.param_groups:
- for p in group['params']:
- if p.grad is not None:
- p.grad = None
- else:
- optimizer.zero_grad()
-
- def _grad_backward(self, loss):
- r"""Compute gradient with link rules.
-
- :param loss: a scalar where back-prop starts
-
- For PyTorch, just do "loss.backward()"
- """
- if (self.step-1) % self.update_every == 0:
- self._clear_grad(self.optimizer)
- self.grad_scaler.scale(loss).backward()
-
- def _update(self):
- r"""Perform weight update on a model.
-
- """
- if self.step % self.update_every == 0:
- self.grad_scaler.step(self.optimizer)
- self.grad_scaler.update()
-
- def _data_forward(self, network, x):
- x = _build_args(self._forward_func, **x)
- y = network(**x)
- if not isinstance(y, dict):
- raise TypeError(
- f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
- return y
-
- def _compute_loss(self, predict, truth):
- r"""Compute loss given prediction and ground truth.
-
- :param predict: prediction dict, produced by model.forward
- :param truth: ground truth dict, produced by batch_y
- :return: a scalar
- """
- loss = self.losser(predict, truth)
- if self.update_every > 1:
- loss = loss / self.update_every
- if loss.dim() > 0:
- loss = loss.mean()
- return loss
-
- def save_check_point(self, name=None, only_params=False):
- r"""保存当前模型"""
- # only master save models
- if name is None:
- name = 'checkpoint-{}.bin'.format(self.step)
- os.makedirs(self.cp_save_path, exist_ok=True)
- path = os.path.join(self.cp_save_path, name)
- self.logger.info("Save checkpoint to {}".format(path))
- model_to_save = self.ddp_model.module
- if only_params:
- model_to_save = model_to_save.state_dict()
- if self.is_master:
- torch.save(model_to_save, path)
-
- def load_check_point(self, name):
- path = os.path.join(self.cp_save_path, name)
- self.logger.info('reload best model from %s', path)
- model_load = torch.load(
- path,
- map_location=lambda s, l: default_restore_location(s, "cpu"))
- if not isinstance(model_load, dict):
- model_load = model_load.state_dict()
- self.model.load_state_dict(model_load)
-
- def _best_save_name(self, auto_fix=True):
- best_name = "best_" + "_".join([self.model.__class__.__name__, str(self.metric_key), self.start_time])
- return best_name
-
- def _do_validation(self):
- with self.ddp_model.no_sync():
- # 因为模型参数不更新,可以关闭同步
- self.callback_manager.on_valid_begin()
- eval_res = self.test_manager.on_valid_begin()
- eval_res = list(filter(lambda x: x is not None, eval_res))
- if len(eval_res):
- eval_res, is_better = list(zip(*eval_res))
- eval_res = eval_res[0]
- is_better = is_better[0]
- else:
- eval_res, is_better = None, None
- if self.metric_key is None and eval_res is not None:
- eval_res0 = list(eval_res.values())[0]
- self.metric_key = list(eval_res0.keys())[0]
- # logger.info('{}, {}'.format(eval_res, is_better))
- # save better model on master node
- if is_better is not None and self.cp_save_path:
- if is_better:
- self.save_check_point(self._best_save_name(), only_params=False)
- dist.barrier()
-
- if not self.is_master and self.metric_key is None:
- # 主进程自动得到了metric_key,而其它进程没有
- prefix = 'best_' + self.model.__class__.__name__
- suffix = self.start_time
- fn_list = os.listdir(self.cp_save_path)
- fn_list = [fn for fn in fn_list if fn.startswith(prefix) and fn.endswith(suffix)]
- if len(fn_list) == 1:
- best_name = fn_list[0]
- self.metric_key = best_name[len(prefix):-len(suffix)].strip('_')
- # print('RANK {} metric_key {}'.format(self.rank, self.metric_key))
- self.callback_manager.on_valid_end(
- eval_res, self.metric_key, self.optimizer, is_better)
- self.ddp_model.train()
-
- def close(self):
- r"""关闭Trainer,销毁进程"""
- dist.destroy_process_group()
diff --git a/fastNLP/core/drivers/__init__.py b/fastNLP/core/drivers/__init__.py
new file mode 100644
index 00000000..14c004d1
--- /dev/null
+++ b/fastNLP/core/drivers/__init__.py
@@ -0,0 +1,39 @@
+__all__ = [
+ 'Driver',
+ 'TorchDriver',
+ "TorchSingleDriver",
+ "TorchDDPDriver",
+ "FairScaleDriver",
+ "TorchFSDPDriver",
+ "DeepSpeedDriver",
+ "PaddleDriver",
+ "PaddleSingleDriver",
+ "PaddleFleetDriver",
+ "JittorDriver",
+ "JittorSingleDriver",
+ "JittorMPIDriver",
+ 'PaddleDriver',
+ 'PaddleSingleDriver',
+ 'PaddleFleetDriver',
+ 'JittorDriver',
+ 'JittorSingleDriver',
+ 'JittorMPIDriver',
+ 'OneflowDriver',
+ 'OneflowSingleDriver',
+ 'OneflowDDPDriver',
+ 'torch_seed_everything',
+ 'paddle_seed_everything',
+ 'oneflow_seed_everything',
+ 'optimizer_state_to_device'
+]
+
+from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, DeepSpeedDriver, FairScaleDriver, \
+ TorchFSDPDriver, torch_seed_everything, optimizer_state_to_device
+from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver
+from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything
+from .oneflow_driver import OneflowDriver, OneflowSingleDriver, OneflowDDPDriver, oneflow_seed_everything
+from .driver import Driver
+
+
+
+
diff --git a/fastNLP/core/drivers/choose_driver.py b/fastNLP/core/drivers/choose_driver.py
new file mode 100644
index 00000000..003eee90
--- /dev/null
+++ b/fastNLP/core/drivers/choose_driver.py
@@ -0,0 +1,49 @@
+from typing import Union, Optional, List
+
+from .driver import Driver
+from ..utils import is_torch_module, is_paddle_module, is_jittor_module, is_oneflow_module
+
+__all__ = []
+
+def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
+ r"""
+ 根据输入的参数 ``driver`` 和 ``device`` 的格式来决定具体的工作模式。
+
+ :param model: 运行过程中使用的具体的最原始的模型。
+ :param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:``["auto", "torch", "paddle", "jittor", "fairscale", "deepspeed", "oneflow", "torch_fsdp"]``,分别对应
+ 各种框架。值为 ``'auto'`` 时,将会根据模型的类型进行选择。
+ :param device: 训练使用的设备。详细的格式可以查阅 :class:`~fastNLP.core.controllers.Trainer` 中的说明。
+ :param kwargs: 其余的传给 `Driver` 的参数。
+ """
+
+ # 如果用户直接传进来一个 driver 实例,我们就直接返回回去,目前用户需要自己保证传进来的 driver 的正确性;
+ if isinstance(driver, Driver):
+ return driver
+
+ if driver == "auto":
+ if is_torch_module(model):
+ driver = "torch"
+ elif is_paddle_module(model):
+ driver = "paddle"
+ elif is_jittor_module(model):
+ driver = "jittor"
+ elif is_oneflow_module(model):
+ driver = "oneflow"
+ else:
+ raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.")
+
+ if driver in {"torch", "fairscale", "deepspeed", "torch_fsdp"}:
+ from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
+ return initialize_torch_driver(driver, device, model, **kwargs)
+ elif driver in {"jittor"}:
+ from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
+ return initialize_jittor_driver(driver, device, model, **kwargs)
+ elif driver in {"paddle"}:
+ from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
+ return initialize_paddle_driver(driver, device, model, **kwargs)
+ elif driver in {"oneflow"}:
+ from fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver import initialize_oneflow_driver
+ return initialize_oneflow_driver(driver, device, model, **kwargs)
+ else:
+ raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale', "
+ "'jittor', 'paddle', 'oneflow'].")
\ No newline at end of file
diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py
new file mode 100644
index 00000000..8bdc9e75
--- /dev/null
+++ b/fastNLP/core/drivers/driver.py
@@ -0,0 +1,453 @@
+import os
+import signal
+import sys
+from typing import Sequence, List, Optional, Callable, Dict, Union, Tuple
+from abc import ABC, abstractmethod
+from datetime import datetime
+from pathlib import Path
+from io import BytesIO
+import json
+
+__all__ = [
+ 'Driver'
+]
+
+from fastNLP.core.utils import nullcontext
+
+
+class Driver(ABC):
+ r"""
+ 用来初始化 `Driver` 的基类,所有定制的 `driver` 都需要继承此类;
+ **fastNLP** 提供的 driver 实例都会同时被 :class:`~fastNLP.core.controllers.Trainer` 和 :class:`~fastNLP.core.controllers.Evaluator` 调用。
+
+ :param model: 训练或者评测的模型,需要注意该模型可能为用户已经使用类似 :class:`torch.nn.DataParallel` 或者
+ :class:`torch.nn.parallel.DistributedDataParallel` 包裹过的模型。
+ """
+
+ def __init__(self, model):
+ self.model = model
+
+ # 这些属性用于 open_subprocess 和 on_exception 函数协同配合;
+ # self._consensus_file: Optional[Union[str, Path]] = None
+ self._pids: Optional[List[int]] = None
+
+ @abstractmethod
+ def setup(self):
+ r"""
+ 该函数用来初始化训练环境,例如将模型迁移到对应的设备上等。
+ 多卡的 ``driver`` 的该函数要更为复杂一些,例如其可能需要开启多进程之间的通信环境,以及设置一些环境变量和其余所需要的变量值;
+ """
+
+ def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible: bool = False):
+ r"""
+ 根据输入的 ``dataloader`` 得到一个 支持分布式 (**distributed**) 与 可复现的 (**reproducible**) 的 dataloader。
+
+ :param dataloader: 根据 ``dataloader`` 设置其对应的分布式版本以及可复现版本。
+ :param dist: 应当为一个字符串,其值应当为以下之一:``[None, "dist", "unrepeatdist"]``,并且根据在 :class:`~fastNLP.core.controllers.Trainer`
+ 和 :class:`~fastNLP.core.controllers.Evaluator` 中 *kwargs* 的参数 ``use_dist_sampler`` 和调用时机不同,对应不同的值:
+
+ * 当 ``use_dist_sampler`` 为 ``False`` ,且在 :class:`~fastNLP.core.controllers.Trainer` 或 :class:`~fastNLP.core.controllers.Evaluator`
+ **初始化** 中被调用时,参数值为 ``None`` ,表示不需要考虑当前 ``dataloader`` 切换为分布式状态;
+ * 当 ``use_dist_sampler`` 为 ``True`` ,且在 :class:`~fastNLP.core.controllers.Trainer` **初始化** 中被调用时,参数值为 ``"dist"`` ,表示该
+ ``dataloader`` 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample 在不同 gpu 上出现重复;
+ * 当 ``use_dist_sampler`` 为 ``True`` ,且在 :class:`~fastNLP.core.controllers.Evaluator` **初始化** 中被调用时,参数值为 ``"unrepeatdist"`` ,
+ 表示该 ``dataloader`` 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的数据,允许不同 gpu 上 batch 的数量不一致;
+ * 当 **断点重训加载** 中调用 :meth:`load_checkpoint` 时,该函数也会被调用,且 ``dist`` 值为 :class:`~fastNLP.core.samplers.ReproducibleSampler`
+ 或 :class:`~fastNLP.core.samplers.ReproducibleBatchSampler` ,此时表示需要用 ``dist`` 代表的 sampler 或 batch_sampler 重新实例化一个新的 dataloader;
+
+ :param reproducible: 如果为 ``False``,不要做任何考虑;如果为 ``True``,需要保证返回的 dataloader 可以保存当前的迭代状态,使得
+ 该状态可以加载到一个全新的 dataloader 中然后恢复其状态。
+ :return: 应当返回一个被替换 sampler 后的 **新的** dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外,
+ 如果传入的 ``dataloader`` 中是 :class:`~fastNLP.core.samplers.ReproducibleSampler` 或者 :class:`~fastNLP.core.samplers.ReproducibleBatchSampler`
+ 需要 **重新初始化** 一个放入返回的 dataloader 中。如果 ``dist`` 为空,且 ``reproducible`` 为 ``False``,可直接返回原对象。
+ """
+ if dist is None and reproducible is False:
+ return dataloader
+ raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `set_dist_repro_dataloader` "
+ f"function.")
+
+ def set_deterministic_dataloader(self, dataloader):
+ r"""
+ 为了确定性训练要对 ``dataloader`` 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的;例如对于 **pytorch** 的 ``dataloader``,其
+ 需要将 ``worker_init_fn`` 替换。
+ """
+
+ def set_sampler_epoch(self, dataloader, cur_epoch_idx):
+ r"""
+ 对于分布式的 ``sampler``,例如 **pytorch** 的 :class:`DistributedSampler`,其需要在每一个 ``epoch`` 前设置随机数种子,来保证每一个进程上的 ``shuffle`` 是一样的;
+ ``dataloader`` 中可能真正发挥作用的是 ``batch_sampler`` 也可能是 ``sampler``。
+
+ :param dataloader: 需要设置 ``epoch`` 的 ``dataloader``
+ :param cur_epoch_idx: 当前是第几个 ``epoch``
+ """
+
+ @abstractmethod
+ def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+ r"""
+ 通过调用 ``fn`` 来实现训练时的前向传播过程;
+ 注意 :class:`~fastNLP.core.controllers.Trainer` 和 :class:`~fastNLP.core.controllers.Evaluator` 会调用该函数来
+ 实现网络的前向传播过程,其中传入该函数的参数 ``fn`` 是函数 :meth:`get_model_call_fn` 所返回的函数。
+
+ :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型。
+ :param fn: 调用该函数进行一次计算。
+ :param signature_fn: 由 :class:`~fastNLP.core.controllers.Trainer` 传入的用于网络前向传播一次的签名函数,因为当
+ batch 是一个 :class:`Dict` 的时候,我们会自动调用 :func:`fastNLP.core.utils.auto_param_call` 函数,而一些被
+ 包裹的模型需要暴露其真正的函数签名,例如 :class:`DistributedDataParallel` 的调用函数是 ``forward``,但是需要其
+ 函数签名为 ``model.module.forward``。
+ :return: 由 ``fn`` 返回的结果(应当为一个 :class:`dict` 或者 :class:`dataclass` ,但是不需要我们去检查)。
+ """
+ raise NotImplementedError("Each specific driver should implemented its own `model_call` function.")
+
+ @abstractmethod
+ def get_model_call_fn(self, fn: str) -> Tuple:
+ r"""
+ 该函数会接受 :class:`~fastNLP.core.controllers.Trainer` 的 ``train_fn`` 或者 :class:`~fastNLP.core.controllers.Evaluator`
+ 的 ``evaluate_fn``,返回一个实际用于调用 :meth:`model_call` 时传入的函数参数;该函数会由 :class:`~fastNLP.core.controllers.Trainer`
+ 和 :class:`~fastNLP.core.controllers.Evaluator` 在 :func:`driver.setup` 函数之后调用。
+
+ 之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 ``Trainer`` 或者 ``Evaluator`` 身上;
+ 这样是因为在新版的设计中,使用 model 的哪种方法来进行 ``train step`` 或者 ``evaluate step`` 是通过额外的参数 ``train_fn`` 和
+ ``evaluate_fn`` 来确定的,而二者又分别是通过 ``Trainer`` 和 ``Evaluator`` 来控制的;因此不能将确定具体的 ``train step fn`` 和
+ ``evaluate step fn`` 的逻辑放在每一个 driver 的初始化的时候(因此在 ``Trainer`` 初始化第一个 driver 时,``Evaluator`` 还没有初始化,但是
+ ``evaluate step fn`` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中.
+
+ 这一函数应当通过参数 ``fn`` 来判断应当返回的实际的调用的函数,具体逻辑如下所示:
+
+ 1. 如果 ``fn`` == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 ``fn``,则默认调用模型的 :meth:`forward`
+ 函数,然后给出 warning;
+ 2. 如果 ``fn`` 是其他字符串,那么如果模型没有定义方法 ``fn`` 则直接报错;
+
+ 注意不同的 driver 需要做额外的检测处理,例如在 :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver` 中,当传入的模型本身就是
+ :class:`DistributedDataParallel` 时,我们只能调用模型的 :meth:`forward` 函数,因此需要额外的 warning;这一点特别需要注意的问题在于
+ driver 自己在 setup 时也会对模型进行改变( :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver` ),因此可能需要额外标记最初
+ 传入 driver 的模型是哪种形式的.
+
+ :param fn: 一个字符串,该函数通过该字符串判断要返回模型的哪种方法
+ :return: 一个元组,包含两个函数,用于在调用 :meth:`model_call` 时传入
+ """
+ raise NotImplementedError("Each specific driver should implemented its own `get_model_call_fn` function.")
+
+ @property
+ def model(self):
+ r"""
+ :return: driver 中在实际训练或者评测时所使用的模型。
+ """
+ return self._model
+
+ @model.setter
+ def model(self, model):
+ self._model = model
+
+ @property
+ def optimizers(self) -> List:
+ r"""
+ 如下所示,driver 返回的 :attr:`optimizers` 一定是一个 :class:`List`,如果用户直接向 :class:`~fastNLP.core.controllers.Trainer` 传入一个单独的 optimizer,
+ 我们会使用一个 List 将其包裹;
+
+ :return: List[optimizer0, optimizer1, optimizer2, ...]
+ """
+ return self._optimizers
+
+ @optimizers.setter
+ def optimizers(self, optimizers):
+ if not isinstance(optimizers, Sequence):
+ self._optimizers = [optimizers]
+ else:
+ self._optimizers = optimizers
+ self._check_optimizer_legality(self._optimizers)
+
+ @property
+ def model_device(self):
+ r"""
+ :return: driver 中模型实际所在的设备。
+ """
+ return self._model_device
+
+ @model_device.setter
+ def model_device(self, model_device):
+ self._model_device = model_device
+
+ @property
+ def data_device(self):
+ """
+ :return: driver 中数据默认会被迁移到的设备。
+ """
+ return self.model_device
+
+ @staticmethod
+ def _check_optimizer_legality(optimizers):
+ r"""
+ 对于用户传入 trainer 的每一个 optimizer,检测其是否合理,因为不同的深度学习框架所使用的的 optimizer 是不相同的。
+
+ :param optimizers: 需要检测的 `optimizers`。
+ """
+ raise NotImplementedError(
+ "Each specific driver should implemented its own `_check_optimizer_legality` function.")
+
+ def check_dataloader_legality(self, dataloader):
+ """
+ 检测 ``dataloader`` 是否合法,如果不合法,会 ``raise TypeError`` 。
+
+ :param dataloader:
+ """
+
+ def set_optimizers(self, optimizers=None):
+ r"""
+ trainer 会调用该函数将用户传入的 ``optimizers`` 挂载到 driver 实例上。
+ """
+ self.optimizers = optimizers
+
+ @abstractmethod
+ def backward(self, loss):
+ r"""
+ 实现深度学习中的反向传播过程。
+
+ :param loss: 用来实现反向传播的损失函数值
+ """
+ raise NotImplementedError("Each specific driver should implemented its own `backward` function.")
+
+ @abstractmethod
+ def step(self):
+ r"""
+ 实现深度学习中的参数的优化更新过程,应当直接通过优化器 :attr:`optimizers` 来更新参数。
+ """
+ raise NotImplementedError("Each specific driver should implemented its own `step` function.")
+
+ @abstractmethod
+ def zero_grad(self):
+ r"""
+ 实现深度学习中的梯度的置零操作,应当直接通过优化器 :attr:`optimizers` 来将梯度置零;
+ 注意梯度累积不需要在这里实现,trainer 已经在内部实现了梯度累积。
+
+ """
+ raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.")
+
+ def get_model_no_sync_context(self):
+ r"""
+ 返回一个用于关闭多进程之间 model 中的自动互相同步操作的 context 上下文对象;只有多卡的 driver 需要单独实现该函数,
+ 单卡的 driver 不需要。
+
+ :return: 一个类似于 ``DistributedDataParallel(model).no_sync`` 的 context 上下文对象
+ """
+ return nullcontext
+
+ def get_evaluate_context(self):
+ r"""
+ 返回一个不计算梯度的环境用来对模型进行评测。
+
+ :return: 一个类似 ``torch.no_grad`` 的 context 上下文对象
+ """
+ return nullcontext
+
+ @property
+ def auto_cast(self):
+ r"""
+ fp16 的上下文环境。
+
+ :return: 一个用于 fp16 计算的上下文环境
+ """
+ return self._auto_cast
+
+ @auto_cast.setter
+ def auto_cast(self, auto_cast):
+ self._auto_cast = auto_cast
+
+ @abstractmethod
+ def save_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = True, **kwargs):
+ r"""
+ 保存模型的函数;注意函数 :meth:`save_checkpoint` 是用来进行断点重训的函数。
+
+ :param filepath: 保存文件的文件位置(需要包括文件名)或一个 BytesIO 对象
+ :param only_state_dict: 是否只保存模型的 `state_dict`
+ """
+ raise NotImplementedError("Each specific driver should implemented its own `save_model` function.")
+
+ @abstractmethod
+ def load_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = False, **kwargs):
+ r"""
+ 加载模型的函数;将 ``filepath`` 中的模型加载并赋值给当前 model 。
+
+ :param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 ``BytesIO`` 对象。
+ :param only_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath
+ 模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。
+ """
+ raise NotImplementedError("Each specific driver should implemented its own `load_model` function.")
+
+ @abstractmethod
+ def save_checkpoint(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True,
+ **kwargs):
+
+ r"""
+ 断点重训的保存函数,该函数会负责保存优化器、fp16 状态和 sampler 的状态,以及模型的保存(若 ``should_save_model`` 为 ``True``)
+
+ :param folder: 保存断点重训的状态的文件夹;:meth:`save_checkpoint` 函数应该在该路径下面下面新增名为 ``FASTNLP_CHECKPOINT_FILENAME`` 与
+ ``FASTNLP_MODEL_FILENAME`` (如果 ``should_save_model`` 为 ``True`` )的文件。把 model 相关的内容放入到 ``FASTNLP_MODEL_FILENAME`` 文件
+ 中,将传入的 ``states`` 以及自身产生的其它状态一并保存在 ``FASTNLP_CHECKPOINT_FILENAME`` 里面。
+ :param states: 由 :class:`~fastNLP.core.controllers.Trainer` 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态。Driver 应该
+ 只需要保存该对象而不需要理解该对象,同时在 :meth:`load_checkpoint` 的时候需要将 ``states`` 返回回去,返回的值与这里传入的值保持一致。
+ :param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。
+ :param only_state_dict: 是否只保存模型的参数,当 ``should_save_model`` 为 ``False`` ,该参数无效。
+ :param should_save_model: 是否应该保存模型,如果为 ``False`` ,Driver 将不负责 model 的保存。
+ """
+ raise NotImplementedError("Each specific driver should implemented its own `save_checkpoint` function.")
+
+ @abstractmethod
+ def load_checkpoint(self, folder: Union[str, Path], dataloader, only_state_dict: bool = True, should_load_model: bool = True,
+ **kwargs) -> Dict:
+ r"""
+ 断点重训的加载函数,该函数会负责读取数据,并且恢复优化器 、sampler 的状态和模型(如果 ``should_load_model`` 为 True)以及其它在 :meth:`save_checkpoint`
+ 函数中执行的保存操作,然后将一个 state 字典返回给 :class:`~fastNLP.core.controllers.Trainer` ( 内容为 :meth:`save_checkpoint` 接受到的 ``states`` )。
+
+ 该函数应该在所有 rank 上执行。
+
+ :param folder: 读取该 folder 下的 ``FASTNLP_CHECKPOINT_FILENAME`` 文件与 ``FASTNLP_MODEL_FILENAME``
+ (如果 should_load_model 为True)。
+ :param dataloader: 当前给定 dataloader,需要根据保存的 dataloader 状态合理设置。若该值为 ``None`` ,则不需要返回 ``'dataloader'``
+ 以及 ``'batch_idx_in_epoch'`` 这两个值。
+ :param only_state_dict: 是否仅读取模型的 state_dict ,当 ``should_save_model`` 为 ``False`` ,该参数无效。如果为 ``True`` ,说明保存的内容为权重;如果为
+ False 说明保存的是模型,但也是通过当前 Driver 的模型去加载保存的模型的权重,而不是使用保存的模型替换当前模型。
+ :param should_load_model: 是否应该加载模型,如果为 ``False`` ,Driver 将不负责加载模型。若该参数为 ``True`` ,但在保存的状态中没有
+ 找到对应的模型状态,则报错。
+ :return: :meth:`save_checkpoint` 函数输入的 ``states`` 内容。除此之外,还返回的内容有:
+
+ * *dataloader* -- 根据传入的 ``dataloader`` 与读取出的状态设置为合理状态的 dataloader。在当前 ``dataloader`` 样本数与读取出的 sampler 样本数
+ 不一致时报错。
+ * *batch_idx_in_epoch* -- :class:`int` 类型的数据,表明当前 epoch 进行到了第几个 batch 。请注意,该值不能仅通过保存的数据中读取的,因为前后两次运行的
+ ``batch_size`` 可能有变化,而应该符合以下等式::
+
+ 返回的 dataloader 还会产生的 batch 数量 + batch_idx_in_epoch = 原来不断点训练时的 batch 的总数
+
+ 由于 ``返回的 dataloader 还会产生的batch数`` 在 ``batch_size`` 与 ``drop_last`` 参数给定的情况下,无法改变,因此只能通过调整 ``batch_idx_in_epoch``
+ 这个值来使等式成立。一个简单的计算原则如下:
+
+ * drop_last 为 ``True`` 时,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
+ * drop_last 为 ``False`` 时,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。
+ """
+
+ @staticmethod
+ def tensor_to_numeric(tensor, reduce: Optional[str] = None):
+ r"""
+ 将一个 ``tensor`` 对象(仅处理当前 driver 使用的 tensor 即可)转换为 python 的 ``numeric`` 对象;如果 ``tensor`` 只包含一个
+ 元素则返回 ``float`` 或 ``int``。
+
+ :param tensor: 需要被转换的 ``tensor`` 对象
+ :param reduce: 可选 ``['sum', 'max', 'mean', 'min']``,如果不为 ``None`` 将使用该 ``reduce`` 方法来处理当前 ``tensor`` 再返回
+ :class:`float` 或 :class:`int` 对象
+ :return: 转换后返回的结果
+ """
+ raise NotImplementedError("Each specific driver should implemented its own `tensor_to_numeric` function.")
+
+ @abstractmethod
+ def set_model_mode(self, mode: str):
+ r"""
+ 设置模型为 ``train`` 或 ``eval`` 的模式;目的是为切换模型的训练和推理(会关闭 dropout 等)模式。
+
+ :param mode: 应为二者之一:``["train", "eval"]``
+ """
+
+ def unwrap_model(self):
+ r"""
+ 保证用户拿到的模型一定是最原始的模型;
+ 注意因为我们把保存模型的主要逻辑和代码移到了 `Driver` 中,因此在 :meth:`save_model` 函数中,一定要先调用此函数来保证我们保存的模型一定是
+ 最为原始的模型;
+ 需要注意用户本身传入的模型就是经过类似 :class:`torch.nn.DataParallel` 或者 :class:`torch.nn.parallel.DistributedDataParallel` 包裹的模型,
+ 因此在该函数内需要先判断模型的类别。
+
+ :return: 最原始的模型,例如没有被 :class:`DistributedDataParallel` 包裹的模型。
+ """
+
+ @staticmethod
+ def move_model_to_device(model, device):
+ r"""
+ 用来将模型转移到指定的 ``device`` 上;
+ 之所以写成 :class:`staticmethod`,是因为一方面在 `Driver` 中我们要使用 :meth:`unwrap_model` 来拿到最原始的模型,另一方面,在 :meth`save_model`
+ 中,我们需要先将模型移到 cpu 后,又再移到 gpu 上,因此不适宜在该函数内部调用 :meth:`unwrap_model`,而是将 ``model`` 作为该函数的参数。
+ """
+
+ @abstractmethod
+ def move_data_to_device(self, batch):
+ r"""
+ 将数据迁移到指定的机器上;``batch`` 是包含了张量的数据集合,可以是 **List**、**Dict** 等嵌套类型。
+
+ :return: 移动到指定机器上的 ``batch`` 对象
+ """
+
+ def get_local_rank(self) -> int:
+ r"""
+ 返回当前的 ``local_rank``,本函数的返回值只在运行分布式训练的时候有实际含义。
+
+ :return: 一个整数值,表示当前进程在当前这台机器上的序号
+ """
+ return 0
+
+ def barrier(self):
+ r"""
+ 用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行;
+ 仅在多分布式训练场景中有使用。
+
+ 注意,该函数的行为会受到环境变量 ``FASTNLP_NO_SYNC`` 的影响。仅当 ``FASTNLP_NO_SYNC`` 在 ``os.environ`` 中不存在,或小于 **1** 时
+ 才真的执行 :meth:`barrier`。
+ """
+
+ def is_distributed(self) -> bool:
+ r"""
+ 当前的 driver 实例是否是分布式的。
+
+ :return: 一个 bool 值,如果当前的 driver 实例是用于分布式的,那么返回 ``True``
+ """
+ return False
+
+ def on_exception(self):
+ r"""
+ 该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 :meth:`open_subprocess` 的时候将每一个进程
+ 的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉。
+
+ 因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 :meth:`open_subprocess` (开启多进程的函数)中正确地记录每一个进程的
+ pid 的信息;单卡 driver 不需要这个函数。
+ """
+ # 单卡 driver 不需要这个函数;
+ if self._pids is not None:
+ exc_type, exc_value, exc_traceback_obj = sys.exc_info()
+ _write_exc_info = {
+ 'exc_type': str(exc_type.__name__),
+ 'exc_value': str(exc_value),
+ 'exc_time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
+ 'exc_global_rank': getattr(self, "global_rank", None),
+ 'exc_local_rank': self.get_local_rank(),
+ }
+ sys.stderr.write("\nException info:\n")
+ sys.stderr.write(json.dumps(_write_exc_info, indent=2) + "\n")
+
+ sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n")
+ for pid in self._pids:
+ if pid != os.getpid():
+ os.kill(pid, signal.SIGKILL)
+
+ def broadcast_object(self, obj, src: int = 0, group=None, **kwargs):
+ r"""
+ 从 ``src`` 端将 ``obj`` 对象(可能是 ``tensor``,可能是 ``object`` )broadcast 到其它所有进程。如果是非 ``tensor`` 的对象会尝试使用 ``pickle`` 进行打包进行
+ 传输,然后再 ``dst`` 处再加载回来。仅在分布式的 ``driver`` 中有实际意义。
+
+ :param obj: obj,可能是 ``Tensor`` 或 嵌套类型的数据
+ :param src: source 的 ``global rank``
+ :param group: 所属的通信组
+ :return: 输入的 ``obj``
+ """
+ if not self.is_distributed():
+ return obj
+ raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `broadcast_object` method right "
+ f"now.")
+
+ def all_gather(self, obj, group) -> List:
+ r"""
+ 将 ``obj`` 互相传送到其它所有的 rank 上,其中 ``obj`` 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过
+ pickle 进行序列化,接收到之后再反序列化。
+
+ :param obj: 可以是 ``float/int/bool/np.ndarray/{}/[]/Tensor`` 等类型的数据
+ :param group: 用于不同进程之间互相通信的通信组
+ :return: 返回值应该是 ``[obj0, obj1, ...]``,其中 ``obj0`` 是 ``rank0`` 上的对象,``obj1`` 是 ``rank1`` 上的对象。以此类推
+ """
+ if not self.is_distributed():
+ return [obj]
+ raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `all_gather` method right "
+ f"now.")
diff --git a/fastNLP/core/drivers/jittor_driver/__init__.py b/fastNLP/core/drivers/jittor_driver/__init__.py
new file mode 100644
index 00000000..701fb04b
--- /dev/null
+++ b/fastNLP/core/drivers/jittor_driver/__init__.py
@@ -0,0 +1,9 @@
+__all__ = [
+ "JittorDriver",
+ "JittorSingleDriver",
+ "JittorMPIDriver",
+]
+
+from .jittor_driver import JittorDriver
+from .single_device import JittorSingleDriver
+from .mpi import JittorMPIDriver
\ No newline at end of file
diff --git a/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py b/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py
new file mode 100644
index 00000000..4e37342b
--- /dev/null
+++ b/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py
@@ -0,0 +1,40 @@
+from typing import Union, List
+
+from fastNLP.core.drivers.jittor_driver.mpi import JittorMPIDriver
+from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver
+from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+
+if _NEED_IMPORT_JITTOR:
+ import jittor
+
+__all__ = []
+
+def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> JittorDriver:
+ r"""
+ 用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去。
+
+ .. todo::
+
+ 创建多卡的 driver
+
+ :param driver: 该参数的值应为以下之一:``["jittor"]``
+ :param device: ``jittor`` 运行的设备
+ :param model: 训练或者评测的具体的模型
+ :param kwargs:
+
+ :return: :class:`~fastNLP.core.JittorSingleDriver` 或 :class:`~fastNLP.core.JittorMPIDriver` 实例;
+ """
+
+ if driver not in {"jittor"}:
+ raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")
+
+ # TODO 实现更详细的判断
+ if device in ["cpu", "gpu", "cuda", None]:
+ return JittorSingleDriver(model, device, **kwargs)
+ elif type(device) is int:
+ return JittorMPIDriver(model, device, **kwargs)
+ elif type(device) is list:
+ return JittorMPIDriver(model, device, **kwargs)
+ else:
+ raise NotImplementedError(f"Device={device}")
diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
new file mode 100644
index 00000000..f2fe3c9e
--- /dev/null
+++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
@@ -0,0 +1,425 @@
+import os
+from pathlib import Path
+from typing import Union, Optional, Dict
+from dataclasses import dataclass
+
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+from fastNLP.core.drivers.driver import Driver
+from fastNLP.core.dataloaders import JittorDataLoader
+from fastNLP.core.dataloaders import OverfitDataLoader
+from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
+from fastNLP.core.log import logger
+from fastNLP.core.utils import apply_to_collection, nullcontext
+from fastNLP.envs import (
+ FASTNLP_MODEL_FILENAME,
+ FASTNLP_CHECKPOINT_FILENAME,
+)
+
+if _NEED_IMPORT_JITTOR:
+ import jittor as jt
+ from jittor import Module
+ from jittor.optim import Optimizer
+ from jittor.dataset import Dataset
+ from jittor.dataset import (
+ BatchSampler as JittorBatchSampler,
+ Sampler as JittorSampler,
+ RandomSampler as JittorRandomSampler,
+ SequentialSampler as JittorSequentialSampler
+ )
+
+ _reduces = {
+ 'max': jt.max,
+ 'min': jt.min,
+ 'mean': jt.mean,
+ 'sum': jt.sum
+ }
+
+__all__ = [
+ "JittorDriver",
+]
+
+class JittorDriver(Driver):
+ r"""
+ 实现了 **jittor** 框架训练功能的基本 ``Driver``。这个类被以下子类继承:
+
+ 1. :class:`~fastNLP.core.drivers.jittor_driver.JittorSingleDriver` :实现了使用单卡和 ``cpu`` 训练的具体功能;
+ 2. :class:`~fastNLP.core.drivers.jittor_driver.JittorMPIDriver` :实现了使用 ``mpi`` 启动 **jittor** 分布式训练的功能;
+
+ .. warning::
+
+ 您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``JittorSingleDriver`` 和 ``JittorMPIDriver``,而不是
+ 该类本身。
+
+ .. note::
+
+ 您可以在使用 ``JittorSingleDriver`` 和 ``JittorMPIDriver`` 时使用 ``JittorDriver`` 提供的接口。
+
+ :param model: 训练时使用的 **jittor** 模型
+ :param fp16: 是否开启混合精度训练
+ :param jittor_kwargs:
+ """
+ def __init__(self, model, fp16: bool = False, jittor_kwargs: Dict = None, **kwargs):
+ if not isinstance(model, Module):
+ raise ValueError(f"Parameter `model` can not be `{type(model)}` in `JittorDriver`, it should be exactly "
+ f"`jittor.Module` type.")
+ super(JittorDriver, self).__init__(model)
+
+ if fp16:
+ jt.flags.auto_mixed_precision_level = 6
+ else:
+ jt.flags.auto_mixed_precision_level = 0
+ self.fp16 = fp16
+ self._auto_cast = nullcontext
+ self._jittor_kwargs = jittor_kwargs if jittor_kwargs is not None else {}
+
+ # 用来设置是否关闭 auto_param_call 中的参数匹配问题;
+ self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
+
+ def check_dataloader_legality(self, dataloader):
+ """
+ 检测 DataLoader 是否合法。支持的类型包括 :class:`~fastNLP.core.dataloaders.JittorDataLoader`、 :class:`jittor.dataset.Dataset` 。
+
+ :param dataloader:
+ """
+ if not isinstance(dataloader, (Dataset, JittorDataLoader, OverfitDataLoader)):
+ raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`")
+ if len(dataloader) == 0:
+ logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
+ "may cause some unexpected exceptions.", once=True)
+
+ @staticmethod
+ def _check_optimizer_legality(optimizers):
+ for each_optimizer in optimizers:
+ if not isinstance(each_optimizer, Optimizer):
+ raise TypeError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, "
+ f"not {type(each_optimizer)}.")
+
+ def step(self):
+ r"""
+ 实现参数的优化更新过程
+ """
+ for optimizer in self.optimizers:
+ optimizer.step()
+
+ def backward(self, loss):
+ """
+ 对 ``loss`` 进行反向传播
+ """
+ for optimizer in self.optimizers:
+ optimizer.backward(loss)
+
+ def zero_grad(self):
+ """
+ 实现梯度置零的过程
+ """
+ for optimizer in self.optimizers:
+ optimizer.zero_grad()
+
+ def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs):
+ r"""
+ 将模型保存到 ``filepath`` 中。
+
+ :param filepath: 保存文件的文件位置
+ :param only_state_dict: 在 **Jittor** 中,该参数无效,因为 **Jittor** 仅支持保存模型的 ``state_dict``。
+ """
+ if not only_state_dict:
+ logger.rank_zero_warning(
+ "Jittor only supports saving state_dict, and we will also save state_dict for you.",
+ once=True
+ )
+ if isinstance(filepath, Path):
+ filepath = str(filepath)
+ model = self.unwrap_model()
+ model.save(filepath)
+
+ def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs):
+ r"""
+ 加载模型的函数;将 ``filepath`` 中的模型加载并赋值给当前 ``model`` 。
+
+ :param filepath: 保存文件的文件位置
+ :param only_state_dict: 在 **Jittor** 中,该参数无效,**Jittor** 仅支持加载模型的 ``state_dict``。
+ """
+ if not only_state_dict:
+ logger.rank_zero_warning(
+ "Jittor only supports loading state_dict, and we will also load state_dict for you.",
+ once=True
+ )
+ if isinstance(filepath, Path):
+ filepath = str(filepath)
+ model = self.unwrap_model()
+ model.load(filepath)
+
+ def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+ r"""
+ 断点重训的保存函数,该函数会负责保存 **优化器** 和 **sampler** 的状态,以及 **模型** (若 ``should_save_model`` 为 ``True``)
+
+ :param folder: 保存断点重训的状态的文件夹;:meth:`save_checkpoint` 函数应该在该路径下面下面新增名为 ``FASTNLP_CHECKPOINT_FILENAME`` 与
+ ``FASTNLP_MODEL_FILENAME`` (如果 ``should_save_model`` 为 ``True`` )的文件。把 model 相关的内容放入到 ``FASTNLP_MODEL_FILENAME`` 文件
+ 中,将传入的 ``states`` 以及自身产生的其它状态一并保存在 ``FASTNLP_CHECKPOINT_FILENAME`` 里面。
+ :param states: 由 :class:`~fastNLP.core.controllers.Trainer` 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态。
+ :param dataloader: 正在使用的 dataloader。
+ :param only_state_dict: 是否只保存模型的参数,当 ``should_save_model`` 为 ``False`` ,该参数无效。
+ :param should_save_model: 是否应该保存模型,如果为 ``False`` ,Driver 将不负责 model 的保存。
+ """
+ dataloader_args = self.get_dataloader_args(dataloader)
+ if dataloader_args.sampler:
+ sampler = dataloader_args.sampler
+ else:
+ raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
+
+ num_consumed_batches = states.pop('num_consumed_batches')
+ if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
+ sampler_states = sampler.state_dict()
+ if dataloader_args.batch_size is not None:
+ sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+ * num_consumed_batches
+ else:
+ logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+ "it may cause missing some samples when reload.")
+
+ states['sampler_states'] = sampler_states
+ else:
+ raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training '
+ 'state.')
+
+ # 2. 保存模型的状态;
+ if should_save_model:
+ if not os.path.exists(folder):
+ os.mkdir(folder)
+ model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
+ self.save_model(model_path, only_state_dict=only_state_dict)
+
+ # 3. 保存 optimizers 的状态;
+ states["optimizers_state_dict"] = self.get_optimizer_state()
+
+ # 4. 保存fp16的状态
+
+ logger.debug("Save optimizer state dict")
+ jt.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
+
+ def get_optimizer_state(self):
+ optimizers_state_dict = {}
+ for i in range(len(self.optimizers)):
+ optimizer: Optimizer = self.optimizers[i]
+ optimizers_state_dict[f"optimizer{i}"] = optimizer.state_dict() # 注意这里没有使用 deepcopy,测试是不需要的;
+ return optimizers_state_dict
+
+ def load_optimizer_state(self, states):
+ assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
+ f"checkpoint it is:{len(states)}"
+ for i in range(len(self.optimizers)):
+ optimizer: Optimizer = self.optimizers[i]
+ optimizer.load_state_dict(states[f"optimizer{i}"])
+ logger.debug("Load optimizer state dict.")
+
+ def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
+ r"""
+ 断点重训的加载函数,该函数会负责读取数据,并且恢复 **优化器** 、**sampler** 的状态和 **模型** (如果 ``should_load_model`` 为 True)以及其它
+ 在 :meth:`save_checkpoint` 函数中执行的保存操作,然后将一个 state 字典返回给 :class:`~fastNLP.core.controllers.Trainer` ( 内容为 :meth:`save_checkpoint`
+ 接受到的 ``states`` )。
+
+ 该函数应该在所有 rank 上执行。
+
+ :param folder: 读取该 folder 下的 ``FASTNLP_CHECKPOINT_FILENAME`` 文件与 ``FASTNLP_MODEL_FILENAME``
+ (如果 should_load_model 为True)。
+ :param dataloader: 当前给定 dataloader,需要根据保存的 dataloader 状态合理设置。若该值为 ``None`` ,则不需要返回 ``'dataloader'``
+ 以及 ``'batch_idx_in_epoch'`` 这两个值。
+ :param only_state_dict: 是否仅读取模型的 state_dict ,当 ``should_save_model`` 为 ``False`` ,该参数无效。如果为 ``True`` ,说明保存的内容为权重;如果为
+ False 说明保存的是模型,但也是通过当前 Driver 的模型去加载保存的模型的权重,而不是使用保存的模型替换当前模型。
+ :param should_load_model: 是否应该加载模型,如果为 ``False`` ,Driver 将不负责加载模型。若该参数为 ``True`` ,但在保存的状态中没有
+ 找到对应的模型状态,则报错。
+ :return: :meth:`save_checkpoint` 函数输入的 ``states`` 内容。除此之外,还返回的内容有:
+
+ * *dataloader* -- 根据传入的 ``dataloader`` 与读取出的状态设置为合理状态的 dataloader。在当前 ``dataloader`` 样本数与读取出的 sampler 样本数
+ 不一致时报错。
+ * *batch_idx_in_epoch* -- :class:`int` 类型的数据,表明当前 epoch 进行到了第几个 batch 。请注意,该值不能仅通过保存的数据中读取的,因为前后两次运行的
+ ``batch_size`` 可能有变化,而应该符合以下等式::
+
+ 返回的 dataloader 还会产生的 batch 数量 + batch_idx_in_epoch = 原来不断点训练时的 batch 的总数
+
+ 由于 ``返回的 dataloader 还会产生的batch数`` 在 ``batch_size`` 与 ``drop_last`` 参数给定的情况下,无法改变,因此只能通过调整 ``batch_idx_in_epoch``
+ 这个值来使等式成立。一个简单的计算原则如下:
+
+ * drop_last 为 ``True`` 时,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
+ * drop_last 为 ``False`` 时,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。
+ """
+ states = jt.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))
+
+ # 1. 加载 optimizers 的状态;
+ optimizers_state_dict = states.pop("optimizers_state_dict")
+ self.load_optimizer_state(optimizers_state_dict)
+
+ # 2. 加载模型状态;
+ if should_load_model:
+ self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)
+
+ # 3. 加载fp16的状态
+
+ # 4. 恢复 sampler 的状态;
+ dataloader_args = self.get_dataloader_args(dataloader)
+ if dataloader_args.sampler is None:
+            sampler = RandomSampler(dataloader_args.dataset, shuffle=dataloader_args.shuffle)
+ elif isinstance(dataloader_args.sampler, ReproducibleSampler):
+ sampler = dataloader_args.sampler
+ elif isinstance(dataloader_args.sampler, JittorRandomSampler):
+ sampler = RandomSampler(dataloader_args.sampler.dataset)
+ logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.")
+ elif isinstance(dataloader_args.sampler, JittorSequentialSampler):
+ sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=False)
+ logger.debug("Replace jittor Sampler into fastNLP RandomSampler without shuffle.")
+ elif self.is_distributed():
+            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
+                               "`ReproducibleSampler`.")
+ else:
+ raise RuntimeError(f"Jittor sampler {type(dataloader_args.sampler)} is not supported now.")
+ sampler.load_state_dict(states.pop('sampler_states'))
+ states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
+
+        # 5. 修改 trainer_state.batch_idx_in_epoch
+ # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler;
+ if dataloader_args.drop_last:
+ batch_idx_in_epoch = len(
+ sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
+ else:
+ batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
+ (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
+
+ states["batch_idx_in_epoch"] = batch_idx_in_epoch
+
+ return states
+
+ def get_evaluate_context(self):
+ r"""
+ 返回一个不计算梯度的上下文环境用来对模型进行评测;
+
+ :return: 上下文对象 ``jittor.no_grad``
+ """
+ return jt.no_grad
+
+ @staticmethod
+ def move_model_to_device(model: "jt.Module", device):
+ r"""
+ 将模型转移到指定的设备上。由于 **Jittor** 会自动为数据分配设备,因此该函数实际上无效。
+ """
+ ...
+
+    def move_data_to_device(self, batch: 'jt.Var'):
+        """
+        将数据 ``batch`` 转移到指定的设备上。由于 **Jittor** 会自动为数据分配设备,
+        因此该函数实际上无效,只是简单地返回 ``batch``。
+        """
+        # NOTE(review): the patch originally defined ``move_data_to_device``
+        # twice with identical behaviour (the second, untyped definition
+        # silently shadowed this one). Collapsed into a single definition;
+        # behaviour is unchanged — the batch is returned as-is.
+        return batch
+
+ @staticmethod
+ def tensor_to_numeric(tensor, reduce=None):
+ r"""
+ 将一个 :class:`jittor.Var` 对象转换为 转换成 python 中的数值类型。
+
+ :param tensor: :class:`jittor.Var` 类型的对象
+ :param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``。
+ :return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等
+ """
+ if tensor is None:
+ return None
+
+ def _translate(_data):
+ # 如果只含有一个元素,则返回元素本身,而非list
+ if _data.numel() == 1:
+ return _data.item()
+ if reduce is None:
+ return _data.tolist()
+ return _reduces[reduce](_data).item()
+
+ return apply_to_collection(
+ data=tensor,
+ dtype=jt.Var,
+ function=_translate
+ )
+
+ def set_model_mode(self, mode: str):
+ r"""
+ 设置模型为 ``train`` 或 ``eval`` 的模式;目的是为切换模型的训练和推理(会关闭 dropout 等)模式。
+
+ :param mode: 应为二者之一:``["train", "eval"]``
+ """
+ assert mode in {"train", "eval"}
+ getattr(self.model, mode)()
+
+ @property
+ def data_device(self):
+ """
+ :return: 数据默认会被迁移到的设备
+ """
+ return self.model_device
+
+ def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]):
+ r"""
+ 为了确定性训练要对 ``dataloader`` 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的。 **jittor** 暂时不提供
+ 该功能。
+ """
+ ...
+
+ def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int):
+ r"""
+ 对于分布式的 ``sampler``,需要在每一个 ``epoch`` 前设置随机数种子,来保证每一个进程上的 ``shuffle`` 是一样的。
+
+ :param dataloader: 需要设置 ``epoch`` 的 ``dataloader``
+ :param cur_epoch_idx: 当前是第几个 ``epoch``
+ """
+ # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的;
+ if callable(getattr(dataloader.sampler, "set_epoch", None)):
+ dataloader.sampler.set_epoch(cur_epoch_idx)
+
+ @staticmethod
+ def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]):
+ """
+ 从 ``dataloader`` 中获取参数 ``dataset``, ``batch_sampler``, ``sampler``, ``batch_size``, ``shuffle``
+ 和 ``drop_last`` 。
+ """
+ @dataclass
+ class Res:
+ dataset: Optional[Dataset] = None
+ batch_sampler: Optional[JittorBatchSampler] = None
+ sampler: Optional[JittorSampler] = None
+ batch_size: Optional[int] = None
+ shuffle: Optional[bool] = None
+ drop_last: Optional[bool] = None
+
+ res = Res()
+ from fastNLP.core.dataloaders.jittor_dataloader.fdl import _JittorDataset
+ if isinstance(dataloader, JittorDataLoader):
+ # JittorDataLoader 实际上是迭代 dataset 成员的
+ dataloader = dataloader.dataset
+ if isinstance(dataloader, _JittorDataset):
+ # 获取最原始的 dataset
+ res.dataset = dataloader.dataset
+ else:
+ res.dataset = dataloader
+
+ # jittor 现在不支持 batch_sampler,所以除了 shuffle 都可以直接获取
+ res.batch_size = dataloader.batch_size
+ res.drop_last = dataloader.drop_last
+ if dataloader.sampler is None:
+ # sampler 是 None,那么就从 Dataset 的属性中获取
+ res.shuffle = dataloader.shuffle
+        elif isinstance(next(iter(dataloader.sampler)), (list, tuple)):
+ # jittor 目前不支持 batch_sampler
+ raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, "
+ "please check if you have set `Dataset.sampler` as `BatchSampler`")
+ else:
+ # sampler 不为 None
+ res.sampler = dataloader.sampler
+ if hasattr(dataloader.sampler, "shuffle"):
+ # 这种情况一般出现在 fastNLP 的 ReproduceSampler 中
+ res.shuffle = dataloader.sampler.shuffle
+ elif isinstance(dataloader.sampler, JittorRandomSampler):
+ res.shuffle = True
+ else:
+ res.shuffle = False
+
+ return res
\ No newline at end of file
diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py
new file mode 100644
index 00000000..54b8bd1e
--- /dev/null
+++ b/fastNLP/core/drivers/jittor_driver/mpi.py
@@ -0,0 +1,157 @@
+import os
+from typing import Optional, Union, Callable, Dict, Tuple
+
+from .jittor_driver import JittorDriver
+from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.log import logger
+
+if _NEED_IMPORT_JITTOR:
+ import jittor as jt
+
+__all__ = [
+ "JittorMPIDriver",
+]
+
+class JittorMPIDriver(JittorDriver):
+ """
+ 执行 ``Jittor`` 框架下分布式训练的 ``Driver``。
+
+ .. note::
+
+ 这是一个正在开发中的功能,敬请期待。
+
+ .. todo:
+
+ 实现断点重训中替换 dataloader 的 set_dist_repro_dataloader 函数
+
+ """
+ def __init__(
+ self,
+ model,
+ parallel_device: None,
+ is_pull_by_jittor_run: bool = False,
+ fp16: bool = False,
+ jittor_kwargs: Dict = None,
+ **kwargs
+ ):
+
+ super(JittorMPIDriver, self).__init__(model, fp16=fp16, jittor_kwargs=jittor_kwargs, **kwargs)
+ raise NotImplementedError("MPI for Jittor is not supported right now.")
+
+ self.is_pull_by_jittor_run = is_pull_by_jittor_run
+ self.parallel_device = parallel_device
+
+ self.outside_mpi = False
+
+ def setup(self):
+ self.__fork_with_mpi__()
+
+ def __fork_with_mpi__(self):
+ import sys
+ if jt.in_mpi:
+ # you can mult other process output
+ if jt.rank != 0:
+                sys.stdout = open(os.devnull, "w")
+ return
+ else:
+ if self.parallel_device == -1: # device 为 -1,那么默认使用全部的显卡
+ raise NotImplementedError(f"Device={self.parallel_device}")
+ elif type(self.parallel_device) is int: # device 为 *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练
+ num_procs = 1
+ devices = self.parallel_device
+ elif type(self.parallel_device) is list: # device 为 *list(int)*: 多于 1 个device,应当通过该种方式进行设定
+ num_procs = len(self.parallel_device)
+ devices = str(self.parallel_device)[1:-1]
+ else:
+ raise NotImplementedError(f"Device={self.parallel_device}")
+ print(sys.argv)
+ cmd = " ".join(["CUDA_VISIBLE_DEVICES='%s'" % devices, "mpirun", "-np", str(num_procs), sys.executable] + sys.argv)
+ print("[RUN CMD]:", cmd)
+ os.system(cmd)
+ exit(0)
+
+ def configure_mpi(self):
+ pass
+
+ @property
+ def world_size(self) -> int:
+ return self._world_size
+
+ @world_size.setter
+ def world_size(self, size: int):
+ self._world_size = size
+
+ @property
+ def global_rank(self) -> int:
+ return self._global_rank
+
+ @global_rank.setter
+ def global_rank(self, rank: int) -> None:
+ self._global_rank = rank
+
+ @property
+ def local_rank(self) -> int:
+ return int(os.environ.get("LOCAL_RANK", 0))
+
+ @property
+    def data_device(self):  # NOTE(review): dead — shadowed by the later ``data_device`` property below
+ if self.outside_mpi:
+ return self._data_device
+ return self.parallel_device
+
+ def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+ if isinstance(batch, Dict) and not self.wo_auto_param_call:
+ return auto_param_call(fn, batch, signature_fn=signature_fn)
+ else:
+ return fn(batch)
+
+ def get_model_call_fn(self, fn: str) -> Tuple:
+ if hasattr(self.model, fn):
+ fn = getattr(self.model, fn)
+ if not callable(fn):
+ raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+ logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+ return fn, None
+ elif fn in {"train_step", "evaluate_step"}:
+ logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+ return self.model, self.model.execute
+ else:
+ raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
+
+ def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
+ reproducible: bool = False, sampler_or_batch_sampler=None):
+ return dataloader
+
+ def is_global_zero(self):
+ return self.global_rank == 0
+
+ def get_model_no_sync_context(self):
+ return self.model.no_sync
+
+ def unwrap_model(self):
+ """
+ 返回训练使用的模型。
+ """
+ return self.model
+
+ def get_local_rank(self) -> int:
+ return self.local_rank
+
+ def barrier(self):
+ pass
+
+ def is_distributed(self):
+ """
+ 判断是否为分布式的 **Driver** ,在 ``JittorMPIDriver`` 中,返回 ``True``。
+ """
+ return True
+
+ @property
+ def data_device(self) -> str:
+ """
+ :return: 数据所在的设备。
+ """
+ return self.model_device
\ No newline at end of file
diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py
new file mode 100644
index 00000000..0abadd45
--- /dev/null
+++ b/fastNLP/core/drivers/jittor_driver/single_device.py
@@ -0,0 +1,137 @@
+from typing import Dict, Union, Tuple, Callable, Optional
+
+from .jittor_driver import JittorDriver
+from .utils import replace_batch_sampler, replace_sampler
+from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
+ ReproduceBatchSampler
+from fastNLP.core.samplers import RandomSampler
+from fastNLP.core.log import logger
+
+if _NEED_IMPORT_JITTOR:
+ import jittor as jt
+ from jittor.dataset import (
+ RandomSampler as JittorRandomSampler,
+ SequentialSampler as JittorSequentialSampler,
+ )
+
+__all__ = [
+ "JittorSingleDriver",
+]
+
+class JittorSingleDriver(JittorDriver):
+ r"""
+ ``Jittor`` 框架下用于 ``cpu`` 和单卡 ``gpu`` 运算的 ``Driver``。
+
+ :param model: 传入给 ``Trainer`` 的 ``model`` 参数。
+ :param device: 训练和模型所在的设备,在 **Jittor** 中,应当为以下值之一:``[None, 'cpu', 'gpu', 'cuda']``:
+
+ * 为 ``None`` 或 ``cpu`` 时,表示在 ``cpu`` 上进行训练;
+ * 为 ``gpu`` 或 ``cuda`` 时,表示在显卡设备上进行训练;
+
+ :param fp16: 是否开启 fp16 混合精度训练。
+ :param jittor_kwargs:
+ :kwargs:
+ * *model_wo_auto_param_call* (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为
+
+ .. note::
+
+ 关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。
+
+ """
+
+ def __init__(self, model, device=None, fp16: bool = False, jittor_kwargs: Dict = None, **kwargs):
+ if device not in [None, "cpu", "gpu", "cuda"]:
+ raise RuntimeError("Parameter `device` should be one of [None, 'cpu', 'gpu', 'cuda'] .")
+ super(JittorSingleDriver, self).__init__(model, fp16, jittor_kwargs=jittor_kwargs)
+
+ self.model_device = device if device is not None else "cpu"
+
+ self.local_rank = 0
+ self.global_rank = 0
+ self.world_size = 1
+
+ def setup(self):
+ r"""
+ 初始化训练环境;根据传入的 ``device`` 值设置模型的训练场景为 ``cpu`` 或 ``gpu``。
+ """
+ if self.model_device in ["cpu", None]:
+ jt.flags.use_cuda = 0 # 使用 cpu
+ else:
+ jt.flags.use_cuda = 1 # 使用 cuda
+
+ def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+ if isinstance(batch, Dict) and not self.wo_auto_param_call:
+ return auto_param_call(fn, batch, signature_fn=signature_fn)
+ else:
+ return fn(batch)
+
+ def get_model_call_fn(self, fn: str) -> Tuple:
+ if hasattr(self.model, fn):
+ fn = getattr(self.model, fn)
+ if not callable(fn):
+ raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+ logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+ return fn, None
+ elif fn in {"train_step", "evaluate_step"}:
+ logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+ return self.model, self.model.execute
+ else:
+ raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
+
+ def unwrap_model(self):
+ """
+ 返回训练使用的模型。
+ """
+ return self.model
+
+ def is_distributed(self):
+ """
+ 判断是否为分布式的 **Driver** ,在 ``JittorSingleDriver`` 中,返回 ``False``。
+ """
+ return False
+
+ def set_dist_repro_dataloader(self, dataloader,
+ dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
+ reproducible: bool = False):
+ # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用;
+ if isinstance(dist, ReproducibleBatchSampler):
+ return replace_batch_sampler(dataloader, dist)
+ elif isinstance(dist, ReproducibleSampler):
+ return replace_sampler(dataloader, dist)
+
+ # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用;
+ args = self.get_dataloader_args(dataloader)
+ if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+ batch_sampler = re_instantiate_sampler(args.batch_sampler)
+ return replace_batch_sampler(dataloader, batch_sampler)
+ elif isinstance(args.sampler, ReproducibleSampler):
+ sampler = re_instantiate_sampler(args.sampler)
+ return replace_sampler(dataloader, sampler)
+
+ if reproducible:
+ if args.sampler is None:
+ sampler = RandomSampler(args.dataset, args.shuffle)
+ return replace_sampler(dataloader, sampler)
+ elif type(args.sampler) is JittorRandomSampler:
+ if getattr(args.sampler, '_num_samples', None) is None \
+ and getattr(args.sampler, 'rep', False) is False:
+ # 如果本来就是随机的,并且没有定制,直接替换掉吧。
+ sampler = RandomSampler(args.sampler.dataset, shuffle=True)
+ logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.")
+ return replace_sampler(dataloader, sampler)
+ elif type(args.sampler) is JittorSequentialSampler:
+ # 需要替换为不要 shuffle 的。
+ sampler = RandomSampler(args.sampler.dataset, shuffle=False)
+ logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.")
+ return replace_sampler(dataloader, sampler)
+ batch_sampler = ReproduceBatchSampler(
+ batch_sampler=args.batch_sampler,
+ batch_size=args.batch_size,
+ drop_last=args.drop_last
+ )
+ return replace_batch_sampler(dataloader, batch_sampler)
+ else:
+ return dataloader
diff --git a/fastNLP/core/drivers/jittor_driver/utils.py b/fastNLP/core/drivers/jittor_driver/utils.py
new file mode 100644
index 00000000..3caeb3ef
--- /dev/null
+++ b/fastNLP/core/drivers/jittor_driver/utils.py
@@ -0,0 +1,80 @@
+import inspect
+import os
+import random
+from copy import deepcopy
+from typing import Union
+
+import numpy as np
+
+from fastNLP.core.dataloaders import JittorDataLoader
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+from fastNLP.envs.utils import get_global_seed
+from fastNLP.envs import (
+ get_global_rank,
+ FASTNLP_BACKEND_LAUNCH,
+ FASTNLP_GLOBAL_SEED,
+)
+from fastNLP.core.samplers import ReproducibleBatchSampler
+from fastNLP.core.log import logger
+
+if _NEED_IMPORT_JITTOR:
+ import jittor as jt
+ from jittor.dataset import Dataset
+
+__all__ = [
+ "jittor_seed_everything",
+]
+
+def jittor_seed_everything(seed: int = None, add_global_rank_to_seed: bool = True) -> int:
+ r"""
+ 为 **jittor**、**numpy**、**python.random** 伪随机数生成器设置种子。
+
+ :param seed: 全局随机状态的整数值种子。如果为 ``None`` 则会根据时间戳生成一个种子。
+ :param add_global_rank_to_seed: 在分布式训练中,是否在不同 **rank** 中使用不同的随机数。
+ 当设置为 ``True`` 时,**fastNLP** 会将种子加上当前的 ``global_rank``。
+ """
+ max_seed_value = np.iinfo(np.uint32).max
+ min_seed_value = np.iinfo(np.uint32).min
+
+ if seed is None:
+ if os.getenv(FASTNLP_BACKEND_LAUNCH) == "1":
+ seed = 42
+ else:
+ seed = get_global_seed()
+ logger.info(f"'FASTNLP_GLOBAL_SEED' is set to {seed} automatically.")
+ if not isinstance(seed, int):
+ seed = int(seed)
+
+ if not (min_seed_value <= seed <= max_seed_value):
+ logger.rank_zero_warning("Your seed value is too big or too small for numpy, we will choose a random seed for you.")
+ seed %= max_seed_value
+
+ os.environ[FASTNLP_GLOBAL_SEED] = f"{seed}"
+ if add_global_rank_to_seed:
+ seed += get_global_rank()
+
+ random.seed(seed)
+ np.random.seed(seed)
+ jt.set_global_seed(seed)
+ return seed
+
+def replace_batch_sampler(dataloader, batch_sampler):
+ raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, "
+ "please check if you have set `Dataset.sampler` as `BatchSampler`"
+ "or report this bug to us.")
+
+def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler):
+ batch_sampler = getattr(dataloader, "sampler")
+ if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
+ raise RuntimeError("It should not be running here, please report a bug to us.")
+ if isinstance(dataloader, JittorDataLoader):
+ init_params = dict(inspect.signature(dataloader.__init__).parameters)
+ reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()}
+ reconstruct_args["dataset"] = replace_sampler(reconstruct_args["dataset"].dataset, reconstruct_args["dataset"].sampler)
+ new_dataloader = type(dataloader)(**reconstruct_args)
+ new_dataloader.dataset.set_attrs(sampler=sampler)
+ else:
+ new_dataloader = deepcopy(dataloader)
+ new_dataloader.set_attrs(sampler=sampler)
+
+ return new_dataloader
\ No newline at end of file
diff --git a/fastNLP/core/drivers/oneflow_driver/__init__.py b/fastNLP/core/drivers/oneflow_driver/__init__.py
new file mode 100644
index 00000000..167c0431
--- /dev/null
+++ b/fastNLP/core/drivers/oneflow_driver/__init__.py
@@ -0,0 +1,17 @@
+__all__ = [
+ "OneflowDriver",
+ "OneflowSingleDriver",
+ "OneflowDDPDriver",
+ "oneflow_seed_everything",
+]
+
+from .ddp import OneflowDDPDriver
+from .single_device import OneflowSingleDriver
+from .oneflow_driver import OneflowDriver
+from .utils import oneflow_seed_everything
+
+
+
+
+
+
diff --git a/fastNLP/core/drivers/oneflow_driver/ddp.py b/fastNLP/core/drivers/oneflow_driver/ddp.py
new file mode 100644
index 00000000..974c0b69
--- /dev/null
+++ b/fastNLP/core/drivers/oneflow_driver/ddp.py
@@ -0,0 +1,351 @@
+import os
+from typing import List, Optional, Union, Dict
+
+from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
+
+if _NEED_IMPORT_ONEFLOW:
+ import oneflow
+ import oneflow.comm as comm
+ import oneflow.env as dist_env
+ from oneflow.nn.parallel import DistributedDataParallel
+ from oneflow.utils.data import BatchSampler
+
+__all__ = [
+ "OneflowDDPDriver"
+]
+
+from .oneflow_driver import OneflowDriver
+from fastNLP.core.drivers.oneflow_driver.utils import (
+ replace_sampler,
+ replace_batch_sampler
+)
+from fastNLP.core.utils import check_user_specific_params
+from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, \
+ ReproducibleBatchSampler, \
+ re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
+from fastNLP.envs import FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
+from fastNLP.core.log import logger
+from fastNLP.core.drivers.oneflow_driver.dist_utils import fastnlp_oneflow_all_gather, fastnlp_oneflow_broadcast_object
+from .utils import _check_dataloader_args_for_distributed
+
+
+class OneflowDDPDriver(OneflowDriver):
+ r"""
+ ``OneflowDDPDriver`` 实现了动态图下使用 ``DistributedDataParallel`` 进行的数据并行分布式训练。
+
+ .. note::
+
+ 您在绝大多数情况下不需要自己使用到该类,通过向 ``Trainer`` 传入正确的参数,您可以方便快速地部署您的分布式训练。
+
+ ``OneflowDDPDriver`` 目前支持两种启动方式:
+
+ 1. 用户不做任何处理,通过运行 ``python -m oneflow.distributed.launch --nproc_per_node 2 train.py`` 启动;
+ 2. 用户将模型通过 ``DistributedDataParallel`` 处理后,通过运行 ``python -m oneflow.distributed.launch --nproc_per_node 2 train.py`` 启动;
+
+ 注意多机的启动强制要求用户在每一台机器上使用 ``python -m oneflow.distributed.launch`` 启动;因此我们不会在 ``OneflowDDPDriver`` 中保存
+ 任何当前有多少台机器的信息。
+
+ :param model: 传入给 ``Trainer`` 的 ``model`` 参数
+ :param parallel_device: 该参数无效,**fastNLP** 会自动获取当前进程的设备
+ :param fp16: 是否开启 fp16 训练;目前该参数无效
+ :param oneflow_kwargs:
+ * *ddp_kwargs* -- 用于 ``DistributedDataParallel`` 的其它参数,详情可查阅 **oneflow** 的官方文档
+ :kwargs:
+ * *model_wo_auto_param_call* (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为
+
+ .. note::
+
+ 关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。
+
+ """
+
+ def __init__(
+ self,
+ model,
+ parallel_device: Optional["oneflow.device"],
+ fp16: bool = False,
+ oneflow_kwargs: Dict = None,
+ **kwargs
+ ):
+
+ super(OneflowDDPDriver, self).__init__(model, fp16=fp16, oneflow_kwargs=oneflow_kwargs, **kwargs)
+
+ # oneflow 会自己初始化通信组,因此 parallel_device 实际上不起作用,可以通过 current_device 获取设备
+ self.model_device = oneflow.device("cuda", oneflow.cuda.current_device())
+ self._data_device = self.model_device
+
+ self.global_rank = int(os.environ["RANK"])
+ self.world_size = int(os.environ["WORLD_SIZE"])
+
+ self._ddp_kwargs = self._oneflow_kwargs.get("ddp_kwargs", {})
+ check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__)
+ if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None:
+ logger.info("Notice your model has buffers and you are using `OneflowDDPDriver`, but you do not set "
+ "'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set"
+ " to 'False' to avoid redundant data communication between different processes.")
+
+ self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
+ assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
+ if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
+ os.makedirs(name=self.output_from_new_proc, exist_ok=True)
+ self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
+
+ self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的;
+        self._has_ddpwrapped = False  # set to True once the model is wrapped by DistributedDataParallel
+
+ def setup(self):
+ r"""
+ 将模型用 ``DistributedDataParallel`` 进行处理。
+ """
+ if self._has_setup:
+ return
+ self._has_setup = True
+
+ self.configure_ddp()
+ self.barrier()
+ # 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作;
+ # self._pids = [oneflow.tensor(0, dtype=oneflow.int).to(self.data_device) for _ in range(dist_env.get_world_size())]
+ # comm.all_gather(self._pids, oneflow.tensor(os.getpid(), dtype=oneflow.int).to(self.data_device))
+ # local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
+ # if local_world_size is None:
+ # local_world_size = oneflow.tensor(int(os.environ.get("LOCAL_RANK")), dtype=oneflow.int).to(self.data_device)
+ # comm.all_reduce(local_world_size, op=dist_env.ReduceOp.MAX)
+ # local_world_size = local_world_size.tolist() + 1
+
+ # node_rank = self.global_rank // local_world_size
+ # self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size]
+ # self._pids = self.tensor_to_numeric(self._pids)
+
+ def configure_ddp(self):
+ if not hasattr(self.model, "_ddp_state_for_reversed_params"):
+ self.model.to(self.model_device)
+ self.model = DistributedDataParallel(
+ # 注意这里的 self.model_device 是 `oneflow.device` type,因此 self.model_device.index;
+ self.model,
+ **self._ddp_kwargs
+ )
+ self._has_ddpwrapped = True
+
+ @property
+ def master_address(self) -> str:
+ """
+ 分布式训练中的地址 ``MASTER_ADDR``
+ """
+ return os.environ.get("MASTER_ADDR")
+
+ @property
+ def master_port(self) -> str:
+ """
+ 分布式训练使用的端口 ``MASTER_PORT``
+ """
+ return os.environ.get("MASTER_PORT")
+
+ @property
+ def world_size(self) -> int:
+ """
+ 分布式训练的进程总数 ``WORLD_SIZE``
+ """
+ return self._world_size
+
+ @world_size.setter
+ def world_size(self, size: int):
+ self._world_size = size
+
+ @property
+ def global_rank(self) -> int:
+ """
+ 当前进程的全局编号 ``global_rank``
+ """
+ return self._global_rank
+
+ @global_rank.setter
+ def global_rank(self, rank: int) -> None:
+ self._global_rank = rank
+
+ @property
+ def local_rank(self) -> int:
+ """
+ 当前进程的局部编号 ``local_rank``
+ """
+ return int(os.environ.get("LOCAL_RANK", 0))
+
+ @property
+ def data_device(self):
+ """
+ 数据所在的设备。由于 **oneflow** 可以通过 :func:`oneflow.cuda.current_device` 获取当前进程的设备,因此
+ 该属性和 ``model_device`` 表现相同。
+ """
+ return self._data_device
+
+ def set_dist_repro_dataloader(self, dataloader,
+ dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]] = None,
+ reproducible: bool = False):
+ # 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load_checkpoint 函数调用;
+ # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 OneflowDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数;
+ if isinstance(dist, ReproducibleBatchSampler):
+ dist.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank,
+ pad=True
+ )
+ return replace_batch_sampler(dataloader, dist)
+ if isinstance(dist, ReproducibleSampler):
+ dist.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank,
+ pad=True
+ )
+ return replace_sampler(dataloader, dist)
+
+ # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用;
+ # trainer, evaluator
+ if dist is None:
+ if reproducible:
+ raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.")
+ else:
+ args = self.get_dataloader_args(dataloader)
+ if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+ return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
+ if isinstance(args.sampler, ReproducibleSampler):
+ return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))
+ return dataloader
+ # trainer
+ elif dist == "dist":
+ args = self.get_dataloader_args(dataloader)
+ # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为;
+ if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+ batch_sampler = re_instantiate_sampler(args.batch_sampler)
+ batch_sampler.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank,
+ pad=True
+ )
+ return replace_batch_sampler(dataloader, batch_sampler)
+ elif isinstance(args.sampler, ReproducibleSampler):
+ sampler = re_instantiate_sampler(args.sampler)
+ sampler.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank,
+ pad=True
+ )
+ return replace_sampler(dataloader, sampler)
+ else:
+ _check_dataloader_args_for_distributed(args, controller="Trainer")
+ sampler = RandomSampler(
+ dataset=args.dataset,
+ shuffle=args.shuffle,
+ seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
+ )
+ sampler.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank,
+ pad=True
+ )
+ return replace_sampler(dataloader, sampler)
+ # evaluator
+ elif dist == "unrepeatdist":
+ args = self.get_dataloader_args(dataloader)
+ if type(args.batch_sampler) != BatchSampler:
+ # TODO 这里的目的是判断用户的 batch_sampler 是定制的,可能需要完善
+                logger.warning("Note that you are using customized ``batch_sampler`` in evaluate dataloader or " \
+                                "train dataloader while testing ``overfit_batches``, which may cause that " \
+                                "the data for distributed evaluation is not unrepeated.")
+ if isinstance(args.sampler, ReproducibleSampler):
+ sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
+ elif not isinstance(args.sampler, UnrepeatedSampler):
+ _check_dataloader_args_for_distributed(args, controller='Evaluator')
+ sampler = UnrepeatedSequentialSampler(
+ dataset=args.dataset
+ )
+ else:
+ sampler = re_instantiate_sampler(args.sampler)
+ sampler.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank
+ )
+ # TODO 这里暂时统一替换为 BatchSampler
+ batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
+ return replace_batch_sampler(dataloader, batch_sampler)
+ else:
+ raise ValueError(
+ "Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
+
+ def is_global_zero(self):
+ r"""
+ :return: 当前的进程是否在全局上是进程 0
+ """
+ return self.global_rank == 0
+
+ def get_model_no_sync_context(self):
+ r"""
+ :return: 一个 ``context`` 上下文环境,用于关闭各个进程之间的同步;该功能暂时无效,返回一个空的上下文环境
+ """
+ # TODO 暂时没有在 oneflow 中找到类似的功能;
+ from fastNLP.core.utils import nullcontext
+ return nullcontext
+        # (dead code removed: ``return self.model.no_sync`` was unreachable after the return above)
+
+ def unwrap_model(self):
+ r"""
+ :return: 使用的原始模型
+ """
+ return self.model
+
+ def get_local_rank(self) -> int:
+ r"""
+ :return: 当前进程局部的进程编号
+ """
+ return self.local_rank
+
+ def barrier(self):
+ r"""
+ 同步各个进程之间的操作
+ """
+ if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行
+ comm.barrier()
+
+ def is_distributed(self):
+ r"""
+ :return: 当前使用的 driver 是否是分布式的 driver,对于 ``OneflowDDPDriver`` 来说,该函数一定返回 ``True``
+ """
+ return True
+
+ def broadcast_object(self, obj, src: int = 0, group=None, **kwargs):
+ r"""
+ 从 ``src`` 端将 ``obj`` 对象(可能是 tensor ,可能是 object )广播到其它进程。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行
+ 传输,然后在接收处处再加载回来。仅在分布式的 driver 中有实际意义。
+
+ :param obj: obj,可能是 Tensor 或 嵌套类型的数据
+ :param src: 发送方的 ``global_rank``
+ :param group: 该参数无效
+ :return: 如果当前 rank 是接收端,则返回接收到的参数;如果是 source 端则返回发送的内容。如果环境变量 ``FASTNLP_NO_SYNC`` 为 **2** 则
+ 返回 ``None``
+ """
+ if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。
+ return
+ return fastnlp_oneflow_broadcast_object(obj, src, device=self.data_device)
+
+ def all_gather(self, obj, group) -> List:
+ r"""
+ 将 ``obj`` 互相传送到其它所有的 rank 上,其中 ``obj`` 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,将会尝试通过
+ pickle 进行序列化,接收到之后再反序列化。
+
+ example::
+
+ >>> # rank 0
+ >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 1}}
+ >>> # rank 1
+ >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+ >>> # after all_gather():
+ >>> result = [
+ {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
+ {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+ ]
+
+ :param obj: 需要传输的对象,在每个 rank 上都应该保持相同的结构。
+ :param group: 该参数无效。
+ :return: 所有 rank 发送的 ``obj`` 聚合在一起的内容;如果环境变量 ``FASTNLP_NO_SYNC`` 为 **2** 则不会执行,直接返回 ``[obj]`` 。
+ """
+ if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行
+ return [obj]
+ return fastnlp_oneflow_all_gather(obj)
diff --git a/fastNLP/core/drivers/oneflow_driver/dist_utils.py b/fastNLP/core/drivers/oneflow_driver/dist_utils.py
new file mode 100644
index 00000000..c4570120
--- /dev/null
+++ b/fastNLP/core/drivers/oneflow_driver/dist_utils.py
@@ -0,0 +1,305 @@
+import io
+import pickle
+import os
+from typing import Any, List
+
+from fastNLP.core.utils import apply_to_collection, get_oneflow_device
+from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
+from fastNLP.envs.env import FASTNLP_NO_SYNC
+if _NEED_IMPORT_ONEFLOW:
+ import oneflow
+ import oneflow.comm as comm
+ import oneflow.env as dist_env
+
# Version tag embedded in every pickled payload so a receiver can detect
# data produced by an incompatible sender.
PROTOCOL_VERSION = 1

# Internal helper module: nothing is exported.
__all__ = []
+
+def _validate_output_list_for_rank(my_rank, dst, gather_list):
+ if dst == my_rank:
+ if not gather_list:
+ raise ValueError(
+ "Argument ``gather_list`` must be specified on destination rank."
+ )
+ elif gather_list:
+ raise ValueError(
+ "Argument ``gather_list`` must NOT be specified "
+ "on non-destination ranks."
+ )
+
+ obj = {"protocol_version": PROTOCOL_VERSION, "data": obj}
+ pickled_bytes = pickle.dumps(obj)
+
def fastnlp_oneflow_gather_object(obj, dst=0):
    """
    Gather ``obj`` from every rank onto rank ``dst``.

    Example::
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> fastnlp_oneflow_gather_object(
                gather_objects[dist.get_rank()],
                output if dist.get_rank() == 0 else None,
                dst=0
            )
        >>> # On rank 0
        >>> output
        ['foo', 12, {1: 2}]

    :param obj: the object to send; must be picklable
    :param dst: the destination rank
    :return: on rank ``dst``, a list of length ``world_size`` holding the ``obj``
        of rank 0, rank 1, ...; ``None`` on every other rank.  When the
        environment variable ``FASTNLP_NO_SYNC`` equals 2, ``[obj]`` is returned
        without any communication.
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        return [obj]

    if dist_env.get_rank() == dst:
        object_gather_list = [None for _ in range(dist_env.get_world_size())]
    else:
        object_gather_list = None

    # Ensure object_gather_list is specified appropriately.
    my_rank = dist_env.get_rank()
    _validate_output_list_for_rank(my_rank, dst, object_gather_list)
    # Move tensors to cpu first so they do not unpickle onto the sending gpu.
    obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu"))
    input_tensor, local_size = _object_to_tensor(obj)
    current_device = oneflow.device("cuda")
    input_tensor = input_tensor.to(current_device)
    local_size = local_size.to(current_device)
    # Gather all local sizes. This is so that we can find the max size, and index
    # until the correct size when deserializing the tensors.
    group_size = dist_env.get_world_size()
    object_sizes_tensor = oneflow.zeros(group_size, dtype=oneflow.long, device=current_device)
    object_size_list = [
        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
    ]
    # Allgather tensor sizes. An all-gather is needed here despite this being a
    # gather, since each rank needs to broadcast a tensor of the same (maximal)
    # size.
    comm.all_gather(object_size_list, local_size)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Resize tensor to max size across all ranks.
    input_tensor = input_tensor.reshape(max_object_size)
    # Avoid populating output tensors if the result won't be gathered on this rank.
    if my_rank == dst:
        coalesced_output_tensor = oneflow.empty(
            max_object_size * group_size, dtype=oneflow.uint8, device=current_device
        )
        # Output tensors are nonoverlapping views of coalesced_output_tensor
        output_tensors = [
            coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)]
            for i in range(group_size)
        ]
    # All ranks call gather with equal-sized tensors.
    comm.gather(
        input_tensor,
        gather_list=output_tensors if my_rank == dst else None,
        dst=dst,
    )
    if my_rank != dst:
        return
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(oneflow.uint8)  # type: ignore[call-overload]
        tensor_size = object_size_list[i]
        object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
    # bugfix: the gathered objects were deserialized but never returned,
    # so the documented return value on ``dst`` was always ``None``.
    return object_gather_list
+
+
def _object_to_tensor(obj, device=None):
    """
    Serialize ``obj`` with pickle and wrap the bytes in a ``ByteTensor``.

    The payload is versioned with ``PROTOCOL_VERSION`` so the receiving side
    (:func:`_tensor_to_object`) can verify compatibility.

    :param obj: any picklable python object
    :param device: optional device to place the resulting tensors on
    :return: ``(byte_tensor, local_size)`` where ``local_size`` is a 1-element
        ``LongTensor`` holding the byte count
    """
    # (removed: an unused ``io.BytesIO`` buffer that was allocated but never written to)
    payload = {"protocol_version": PROTOCOL_VERSION, "data": obj}
    pickled_bytes = pickle.dumps(payload)

    byte_tensor = oneflow.ByteTensor(list(pickled_bytes))
    local_size = oneflow.LongTensor([byte_tensor.numel()])
    if device is not None:
        byte_tensor = byte_tensor.to(device)
        local_size = local_size.to(device)
    return byte_tensor, local_size
+
def _tensor_to_object(tensor, tensor_size):
    """
    Inverse of :func:`_object_to_tensor`: unpickle the first ``tensor_size``
    bytes of ``tensor`` and return the wrapped object, asserting that the
    payload's protocol version matches ``PROTOCOL_VERSION``.
    """
    raw = tensor.detach().cpu().numpy().tobytes()[:tensor_size]
    unpacked = pickle.loads(raw)
    assert unpacked["protocol_version"] == PROTOCOL_VERSION
    return unpacked["data"]
+
def send_recv_object(obj, src, cur_rank, device):
    r"""
    Point-to-multipoint distribution helper for oneflow: broadcast a picklable
    python object from process ``src`` to every other process.

    Example::

        cur_rank = int(os.environ.get('LOCAL_RANK', 0))

        # obtain local_device first

        send_recv_object(object, 0, cur_rank, local_device)

    :param obj: a picklable python object (only read on rank ``src``);
    :param src: the rank that sends to all other ranks;
    :param cur_rank: the rank of the calling process;
    :param device: the device of the calling process;
    :return: the received (or, on ``src``, the original) object
    """
    # src rank send to all other ranks
    size = oneflow.LongTensor([0]).to(device)

    if cur_rank == src:
        world_size = dist_env.get_world_size()
        tensor, size = _object_to_tensor(obj)
        tensor = tensor.to(device)
        size = size.to(device)

        # broadcast the payload size first so receivers can allocate buffers;
        comm.broadcast(size, src)
        for subrank in range(world_size):
            if subrank != src:
                comm.send(tensor=tensor, dst=subrank)
    else:
        comm.broadcast(size, src)
        tensor = oneflow.ByteTensor([0] * size).to(device)
        comm.recv(tensor=tensor, src=src)

    return _tensor_to_object(tensor.cpu(), size)
+
+
+def _to_device(tensor, device):
+ return tensor.contiguous().to(device)
+
+
def fastnlp_oneflow_all_gather(obj: Any, device=None) -> List:
    """
    All-gather for data of any type.  Non-tensor data is transmitted by pickling
    on the sender and unpickling on the receivers.

    example::

        >>> # rank 0
        >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 1}}
        >>> # rank 1
        >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
        >>> # after all_gather():
        >>> result = [
                {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
                {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
            ]

    :param obj: data of arbitrary structure; a tensor must have the same shape on
        every device, while any non-tensor object is simply serialized and sent.
    :param device: currently unused.
    :return: ``[obj0, obj1, ...]`` where ``obj_i`` is the ``obj`` of rank ``i``;
        when the environment variable ``FASTNLP_NO_SYNC`` equals 2, ``[obj]`` is
        returned without communication.
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2:
        return [obj]

    if isinstance(obj, oneflow.Tensor):
        objs = [oneflow.zeros_like(obj) for _ in range(dist_env.get_world_size())]
        comm.all_gather(objs, obj)
    else:
        objs = [None for _ in range(dist_env.get_world_size())]
        # move tensors to cpu first so they do not unpickle onto the sending gpu
        obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu"))
        all_gather_object(objs, obj)
    return objs
+
+
def fastnlp_oneflow_broadcast_object(obj, src, device=None):
    """
    Broadcast ``obj`` from rank ``src`` to every other rank.

    :param obj: the object to send (only read on rank ``src``)
    :param src: the sending rank
    :param device: device to stage the transfer buffers on; defaults to the
        current cuda device
    :return: the broadcast object; when the environment variable
        ``FASTNLP_NO_SYNC`` equals 2, rank ``src`` gets ``obj`` back and every
        other rank gets ``None``
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2:
        if src == dist_env.get_rank():
            return obj
        else:
            return None

    cur_rank = dist_env.get_rank()
    if cur_rank == src:
        # move any tensors to cpu before pickling, otherwise they may unpickle
        # onto the sender's device on the receiving side
        obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu"))
    if device is None:
        device = oneflow.cuda.current_device()
    device = get_oneflow_device(device)

    if cur_rank == src:
        tensor, size = _object_to_tensor(obj, device=device)
    else:
        size = oneflow.LongTensor([0]).to(device)

    # broadcast the payload size first so receivers can allocate a buffer
    comm.broadcast(size, src=src)
    if cur_rank != src:
        tensor = oneflow.empty(
            size.int().item(),  # type: ignore[arg-type]
            dtype=oneflow.uint8,
            device=device
        )
    comm.broadcast(tensor, src=src)

    return _tensor_to_object(tensor, tensor_size=size.item())
+
def all_gather_object(object_list, obj):
    """
    All-gather arbitrary picklable objects into ``object_list`` (mirrors
    ``torch.distributed.all_gather_object``).

    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> all_gather_object(output, gather_objects[dist.get_rank()])
        >>> output
        ['foo', 12, {1: 2}]

    :param object_list: pre-sized output list, filled in place with one entry per rank
    :param obj: this rank's contribution; must be picklable
    :return: ``object_list``; when the environment variable ``FASTNLP_NO_SYNC``
        equals 2, ``[obj]`` is returned without communication
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2:
        return [obj]

    current_device = get_oneflow_device(oneflow.cuda.current_device())

    input_tensor, local_size = _object_to_tensor(obj, device=current_device)

    # Gather all local sizes. This is so that we can find the max size, and index
    # until the correct size when deserializing the tensors.
    group_size = dist_env.get_world_size()
    object_sizes_tensor = oneflow.zeros(
        group_size, dtype=oneflow.long, device=current_device
    )
    object_size_list = [
        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
    ]
    # Allgather tensor sizes
    comm.all_gather(object_size_list, local_size)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Resize tensor to max size across all ranks.
    input_tensor = input_tensor.reshape(max_object_size)
    coalesced_output_tensor = oneflow.empty(
        max_object_size * group_size, dtype=oneflow.uint8, device=current_device
    )
    # Output tensors are nonoverlapping views of coalesced_output_tensor
    output_tensors = [
        coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)]
        for i in range(group_size)
    ]
    comm.all_gather(output_tensors, input_tensor)
    # Deserialize outputs back to object.
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(oneflow.uint8)
        if tensor.device != oneflow.device("cpu"):
            tensor = tensor.cpu()
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size)
    return object_list
diff --git a/fastNLP/core/drivers/oneflow_driver/initialize_oneflow_driver.py b/fastNLP/core/drivers/oneflow_driver/initialize_oneflow_driver.py
new file mode 100644
index 00000000..dd0597b2
--- /dev/null
+++ b/fastNLP/core/drivers/oneflow_driver/initialize_oneflow_driver.py
@@ -0,0 +1,69 @@
+import os
+from typing import Optional, Union, List, Sequence
+from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
+if _NEED_IMPORT_ONEFLOW:
+ import oneflow
+
+from .oneflow_driver import OneflowDriver
+from .single_device import OneflowSingleDriver
+from .ddp import OneflowDDPDriver
+from fastNLP.core.log import logger
+from fastNLP.envs import FASTNLP_BACKEND_LAUNCH
+
+__all__ = []
+
+
def initialize_oneflow_driver(driver: str, device: Optional[Union[str, "oneflow.device", int, List[int]]],
                              model: "oneflow.nn.Module", **kwargs) -> OneflowDriver:
    r"""
    Determine and initialize a concrete ``Driver`` instance from the ``driver``
    and ``device`` parameters and return it.

    :param driver: must be one of: ``["oneflow"]``
    :param device: same format as the ``device`` parameter of ``Trainer``
    :param model: the concrete model used for training or evaluation;
    :raises ValueError: when ``driver`` or ``device`` is invalid
    :return: a :class:`~fastNLP.core.OneflowSingleDriver` or
        :class:`~fastNLP.core.OneflowDDPDriver` instance
    """
    # world_size and rank: under ``oneflow.distributed.launch`` the launcher
    # already decided the placement, so ``device`` is ignored.
    if FASTNLP_BACKEND_LAUNCH in os.environ:
        if device is not None:
            logger.rank_zero_warning("Parameter `device` would be ignored when you are using `oneflow.distributed.launch` to pull "
                                     "up your script. ", once=True)
        return OneflowDDPDriver(model, None, **kwargs)

    if driver not in {"oneflow"}:
        raise ValueError("Parameter `driver` can only be one of these values: ['oneflow'].")

    _could_use_device_num = oneflow.cuda.device_count()
    if isinstance(device, str):
        device = oneflow.device(device)
    elif isinstance(device, int):
        if device < 0:
            # -1 is the sentinel for "use every visible gpu"
            if device != -1:
                raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
            device = [oneflow.device(f"cuda:{w}") for w in range(_could_use_device_num)]
        elif device >= _could_use_device_num:
            raise ValueError("The gpu device that parameter `device` specifies is not existed.")
        else:
            device = oneflow.device(f"cuda:{device}")
    elif isinstance(device, Sequence):
        # deduplicate and validate each requested gpu index
        device = list(set(device))
        for each in device:
            if not isinstance(each, int):
                raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.")
            elif each < 0:
                raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.")
            elif each >= _could_use_device_num:
                raise ValueError(f"When parameter `device` is 'Sequence' type, the value in it should not be bigger than"
                                 f" the available gpu number:{_could_use_device_num}.")
        device = [oneflow.device(f"cuda:{w}") for w in device]
    elif device is not None and not isinstance(device, oneflow.device):
        raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")

    # single device: start directly; multiple devices require the launcher.
    # (use the builtin ``list`` here — ``typing.List`` in isinstance is deprecated)
    if not isinstance(device, list):
        return OneflowSingleDriver(model, device, **kwargs)
    raise RuntimeError("If you want to run distributed training, please use "
                       "'python -m oneflow.distributed.launch xxx.py'.")
\ No newline at end of file
diff --git a/fastNLP/core/drivers/oneflow_driver/oneflow_driver.py b/fastNLP/core/drivers/oneflow_driver/oneflow_driver.py
new file mode 100644
index 00000000..35d8d8bf
--- /dev/null
+++ b/fastNLP/core/drivers/oneflow_driver/oneflow_driver.py
@@ -0,0 +1,518 @@
+import os
+from typing import Union, Dict, Optional, Callable, Tuple
+from functools import partial
+import numpy as np
+import random
+from dataclasses import dataclass
+from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
+from pathlib import Path
+if _NEED_IMPORT_ONEFLOW:
+ import oneflow
+ from oneflow.utils.data import DataLoader, Sampler, BatchSampler, Dataset
+ from oneflow.optim import Optimizer
+ from oneflow.utils.data import RandomSampler as OneflowRandomSampler
+ _reduces = {
+ "sum": oneflow.sum,
+ "min": oneflow.min,
+ "max": oneflow.max,
+ "mean": oneflow.mean
+ }
+
+
+__all__ = [
+ "OneflowDriver"
+]
+
+from .utils import optimizer_state_to_device, DummyGradScaler
+from fastNLP.core.drivers.driver import Driver
+from fastNLP.core.utils.utils import _get_fun_msg, nullcontext
+from fastNLP.core.utils import apply_to_collection, oneflow_move_data_to_device, auto_param_call
+from fastNLP.envs import rank_zero_call
+from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
+from fastNLP.core.log import logger
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
+from fastNLP.core.dataloaders import OverfitDataLoader
+
+
class OneflowDriver(Driver):
    r"""
    Base ``Driver`` implementing the common training functionality for the
    **oneflow** framework.  It is inherited by:

    1. :class:`~fastNLP.core.drivers.oneflow_driver.OneflowSingleDriver`: cpu / single-gpu training;
    2. :class:`~fastNLP.core.drivers.oneflow_driver.OneflowDDPDriver`: distributed **oneflow**
       training started through ``DistributedDataParallel``;

    .. warning::

        You should not instantiate this class directly and pass it to ``Trainer``;
        use the subclasses ``OneflowSingleDriver`` and ``OneflowDDPDriver`` instead.

    .. note::

        The interfaces provided by ``OneflowDriver`` are available when using either
        ``OneflowSingleDriver`` or ``OneflowDDPDriver``.

    :param model: the model used for training
    :param fp16: currently has no effect (eager-mode oneflow has no AMP support)
    :param oneflow_kwargs: extra oneflow-specific options
    """
    def __init__(self, model, fp16: Optional[bool] = False, oneflow_kwargs: Dict = None, **kwargs):
        super(OneflowDriver, self).__init__(model)

        # fp16 setup: install no-op stand-ins (nullcontext / DummyGradScaler) so the
        # rest of the training loop stays uniform even without AMP support.
        self._oneflow_kwargs = oneflow_kwargs if oneflow_kwargs is not None else {}

        self.fp16 = fp16
        if fp16:
            logger.warn("OneflowDriver of eager mode does not support fp16 now.")
        # self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not self.fp16)
        # self.grad_scaler = _grad_scaler(**self._oneflow_kwargs.get("gradscaler_kwargs", {}))
        self.auto_cast = nullcontext
        self.grad_scaler = DummyGradScaler()
        self.set_grad_to_none = self._oneflow_kwargs.get("set_grad_to_none")

        self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

    def zero_grad(self):
        """
        Zero the gradients of every registered optimizer.
        """
        for optimizer in self.optimizers:
            optimizer.zero_grad(self.set_grad_to_none)

    def backward(self, loss):
        """
        Run the backward pass on ``loss``.
        """
        loss.backward()
        # self.grad_scaler.scale(loss).backward()

    def step(self):
        r"""
        Perform the parameter-update step of every optimizer.
        """
        for optimizer in self.optimizers:
            self.grad_scaler.step(optimizer)
            self.grad_scaler.update()

    def check_dataloader_legality(self, dataloader):
        """
        Check whether the DataLoader is legal.  Supported types are
        :class:`~fastNLP.core.dataloaders.OneflowDataLoader` and
        :class:`oneflow.utils.data.DataLoader`.

        :param dataloader: the dataloader to validate
        """
        if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
            raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
        if len(dataloader) == 0:
            logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
                                     "may cause some unexpected exceptions.", once=True)

    @staticmethod
    def _check_optimizer_legality(optimizers):
        # every optimizer handed to the driver must be a oneflow ``Optimizer``
        for each_optimizer in optimizers:
            if not isinstance(each_optimizer, Optimizer):
                raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, "
                                f"not {type(each_optimizer)}.")

    @staticmethod
    def tensor_to_numeric(tensor, reduce: str = None):
        r"""
        Convert ``oneflow.Tensor`` values into plain python numbers.

        :param tensor: an ``oneflow.Tensor`` or a collection containing tensors
        :param reduce: when a tensor holds more than one value, the reduction used to
            collapse it to a single number; one of ``['max', 'min', 'sum', 'mean']``
        :return: python numeric values (``int``, ``float``, ...) mirroring the input structure
        """

        if tensor is None:
            return None

        def _translate(_data):
            # single-element tensors become scalars; multi-element tensors become a
            # nested list, or a single reduced scalar when ``reduce`` is given
            if _data.numel() == 1:
                return _data.item()
            if reduce is None:
                return _data.tolist()
            return _reduces[reduce](_data).item()

        return apply_to_collection(
            data=tensor,
            dtype=oneflow.Tensor,
            function=_translate
        )

    def set_model_mode(self, mode: str):
        r"""
        Put the model into ``train`` or ``eval`` mode (switches dropout etc. between
        training and inference behaviour).

        :param mode: one of ``["train", "eval"]``
        """
        assert mode in {"train", "eval"}
        getattr(self.model, mode)()

    @rank_zero_call
    def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs):
        """
        Save the current driver's model to ``filepath``.

        :param filepath: location of the saved file
        :param only_state_dict: whether to save only the weights; must be ``True`` when
            training was launched with ``DistributedDataParallel``
        :return:
        """
        model = self.unwrap_model()
        if not only_state_dict and self.is_distributed():
            # fixed message: it previously started with a stray backtick
            logger.warn("Cannot save ddp model directly, we will save its state_dict for you.")
            only_state_dict = True

        if only_state_dict:
            # detach and copy to cpu so the checkpoint does not reference gpu memory
            states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
            oneflow.save(states, filepath)
        else:
            if self.model_device is not None:
                if not self.is_distributed():
                    self.move_model_to_device(model, oneflow.device("cpu"))
                oneflow.save(model, filepath)
                if not self.is_distributed():
                    self.move_model_to_device(model, self.model_device)
            else:
                oneflow.save(model, filepath)

    def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs):
        """
        Load the model stored at ``filepath`` into the current ``model``.

        :param filepath: location of the saved file
        :param only_state_dict: whether the saved content is only the weights; must be
            ``True`` when training was launched with ``DistributedDataParallel``
        """
        model = self.unwrap_model()
        res = oneflow.load(filepath)
        if isinstance(res, dict) and only_state_dict is False:
            logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use "
                                     f"`only_state_dict=True`")
        elif not isinstance(res, dict) and only_state_dict is True:
            logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use "
                                     f"`only_state_dict=False`")
        if not isinstance(res, dict):
            res = res.state_dict()
        # NOTE(review): kwargs.get("strict") yields None when the caller does not pass
        # `strict`, and oneflow treats a falsy strict as non-strict loading — confirm
        # whether strict loading should be the default here.
        _strict = kwargs.get("strict")
        model.load_state_dict(res, _strict)

    @rank_zero_call
    def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
        r"""
        Checkpoint function for resumable training; saves the state of the **optimizers**
        and the **sampler**, plus the **model** when ``should_save_model`` is ``True``.

        :param folder: directory for the checkpoint state; this method creates
            ``FASTNLP_CHECKPOINT_FILENAME`` (and ``FASTNLP_MODEL_FILENAME`` when
            ``should_save_model`` is ``True``) inside it.  Model-related content goes into
            ``FASTNLP_MODEL_FILENAME``; the given ``states`` plus the driver's own state go
            into ``FASTNLP_CHECKPOINT_FILENAME``.
        :param states: dict passed in by :class:`~fastNLP.core.controllers.Trainer`, already
            containing the other objects' states needed to resume training.
        :param dataloader: the dataloader currently in use.
        :param only_state_dict: whether to save only model parameters; ignored when
            ``should_save_model`` is ``False``.
        :param should_save_model: whether the model should be saved; when ``False`` the
            driver is not responsible for persisting the model.
        """
        # ``dataloader`` is trainer.dataloader: the driver never mutates its own
        # dataloaders; the trainer swaps trainer.dataloader instead to fit the
        # training / evaluation setting.

        # 1. sampler state;
        num_consumed_batches = states.pop("num_consumed_batches")
        states["sampler_states"] = self.get_sampler_state(dataloader, num_consumed_batches)

        # 2. model state;
        if should_save_model:
            if not os.path.exists(folder):
                os.mkdir(folder)
            # wrap with Path() so a plain ``str`` folder works too (consistent with the
            # states path below, which already did this)
            model_path = Path(folder).joinpath(FASTNLP_MODEL_FILENAME)
            self.save_model(model_path, only_state_dict=only_state_dict)

        # 3. optimizer states;
        states["optimizers_state_dict"] = self.get_optimizer_state()
        logger.debug("Save optimizer state dict.")

        # # 4. fp16 state
        # if not isinstance(self.grad_scaler, DummyGradScaler):
        #     grad_scaler_state_dict = self.grad_scaler.state_dict()
        #     states['grad_scaler_state_dict'] = grad_scaler_state_dict

        oneflow.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))

    def get_sampler_state(self, dataloader, num_consumed_batches):
        # collect a reloadable state dict from the dataloader's reproducible
        # (batch-)sampler; raises when no reproducible sampler is attached
        dataloader_args = self.get_dataloader_args(dataloader)
        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
            sampler = dataloader_args.batch_sampler
        elif dataloader_args.sampler:
            sampler = dataloader_args.sampler
        else:
            raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")

        if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
            sampler_states = sampler.state_dict()
            if dataloader_args.batch_size is not None:
                # recompute the consumed-sample count from whole batches so a partly
                # consumed batch is not counted as finished
                sampler_states["num_consumed_samples"] = sampler.num_replicas * dataloader_args.batch_size \
                                                         * num_consumed_batches
            else:
                logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's "
                                         "`num_consumed_samples`, it may cause missing some samples when reload.")
        else:
            raise RuntimeError("The sampler has no `state_dict()` method, fastNLP cannot save the training "
                               "state.")

        return sampler_states

    def load_sampler_state(self, dataloader, sampler_states):
        states = {}
        dataloader_args = self.get_dataloader_args(dataloader)
        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
            sampler = dataloader_args.batch_sampler
        elif isinstance(dataloader_args.sampler, ReproducibleSampler):
            sampler = dataloader_args.sampler
        elif isinstance(dataloader_args.sampler, OneflowRandomSampler):
            sampler = RandomSampler(dataloader_args.sampler.data_source)
            logger.debug("Replace oneflow RandomSampler into fastNLP RandomSampler.")
        elif self.is_distributed():
            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our"
                               "`ReproducibleSampler`.")
        else:
            sampler = ReproduceBatchSampler(
                batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
                batch_size=dataloader_args.batch_size,
                drop_last=dataloader_args.drop_last
            )
        sampler.load_state_dict(sampler_states)
        states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

        # adjust trainer_state.batch_idx_in_epoch
        # ``sampler`` is a RandomSampler-like sampler, not a batch_sampler;
        if not isinstance(sampler, ReproducibleBatchSampler):
            if dataloader_args.drop_last:
                batch_idx_in_epoch = len(
                    sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
            else:
                batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
                                     (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
        # ``sampler`` is a batch_sampler;
        else:
            batch_idx_in_epoch = sampler.batch_idx_in_epoch

        states["batch_idx_in_epoch"] = batch_idx_in_epoch
        return states

    def get_optimizer_state(self):
        optimizers_state_dict = {}
        for i in range(len(self.optimizers)):
            optimizer: oneflow.optim.Optimizer = self.optimizers[i]
            optimizer_state = optimizer.state_dict()
            # move optimizer state to cpu so the checkpoint is device-independent
            optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], oneflow.device("cpu"))
            optimizers_state_dict[f"optimizer{i}"] = optimizer_state  # no deepcopy here; testing showed it is unnecessary
        return optimizers_state_dict

    def load_optimizer_state(self, states):
        assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
                                                    f"checkpoint it is:{len(states)}"
        for i in range(len(self.optimizers)):
            optimizer: oneflow.optim.Optimizer = self.optimizers[i]
            optimizer.load_state_dict(states[f"optimizer{i}"])
        logger.debug("Load optimizer state dict.")

    def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
        r"""
        Counterpart of :meth:`save_checkpoint`: restores the **optimizers**, the
        **sampler** and, when ``should_load_model`` is ``True``, the **model**, then
        returns the ``states`` dict that :meth:`save_checkpoint` received back to
        :class:`~fastNLP.core.controllers.Trainer`.

        This method should be executed on every rank.

        :param folder: directory containing ``FASTNLP_CHECKPOINT_FILENAME`` and, when
            ``should_load_model`` is ``True``, ``FASTNLP_MODEL_FILENAME``.
        :param dataloader: the current dataloader; it is re-configured from the saved
            sampler state.  When ``None``, neither ``'dataloader'`` nor
            ``'batch_idx_in_epoch'`` needs to be returned.
        :param only_state_dict: whether only the model's state_dict was saved; ignored
            when ``should_load_model`` is ``False``.  When ``False`` a whole model was
            saved, but its weights are still loaded into the current driver's model
            rather than replacing that model object.
        :param should_load_model: whether to load the model; when ``True`` but no model
            state is found in the checkpoint, an error is raised.
        :return: the ``states`` dict passed to :meth:`save_checkpoint`, extended with:

            * *dataloader* -- the input ``dataloader`` configured to resume from the
              saved sampler state; an error is raised when the sample counts disagree.
            * *batch_idx_in_epoch* -- :class:`int`, which batch of the current epoch we
              are at.  It cannot simply be read from the checkpoint because
              ``batch_size`` may differ between runs; instead it satisfies::

                  batches still produced by the returned dataloader + batch_idx_in_epoch
                      == total batch count of an uninterrupted run

              Since the number of remaining batches is fixed once ``batch_size`` and
              ``drop_last`` are given, ``batch_idx_in_epoch`` is computed as:

              * ``drop_last=True``: floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
              * ``drop_last=False``: ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size).
        """
        states = oneflow.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))

        # 1. load optimizer states;
        optimizers_state_dict = states.pop("optimizers_state_dict")
        self.load_optimizer_state(optimizers_state_dict)

        # 2. load model state;
        if should_load_model:
            self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)

        # # 3. load fp16 state
        # if "grad_scaler_state_dict" in states:
        #     grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
        #     if not isinstance(self.grad_scaler, DummyGradScaler):
        #         self.grad_scaler.load_state_dict(grad_scaler_state_dict)
        #         logger.debug("Load grad_scaler state dict...")
        # elif not isinstance(self.grad_scaler, DummyGradScaler):
        #     logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
        #                              f"the training process may be unstable.")

        # 4. restore sampler state;
        sampler_states = states.pop("sampler_states")
        states_ret = self.load_sampler_state(dataloader, sampler_states)
        states.update(states_ret)

        return states

    def get_evaluate_context(self):
        r"""
        :return: the no-grad context used to evaluate the model, ``oneflow.no_grad``
        """
        return oneflow.no_grad

    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
        # route dict batches through auto_param_call so keys are matched to the
        # function's parameters, unless disabled via ``model_wo_auto_param_call``
        if isinstance(batch, Dict) and not self.wo_auto_param_call:
            return auto_param_call(fn, batch, signature_fn=signature_fn)
        else:
            return fn(batch)

    def get_model_call_fn(self, fn: str) -> Tuple:
        # resolve the train / evaluate entry point on the model: a user-defined
        # method with that name wins; otherwise fall back to ``forward``
        if hasattr(self.model, fn):
            fn = getattr(self.model, fn)
            if not callable(fn):
                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
            logger.debug(f"Use {_get_fun_msg(fn, with_fp=False)}...")
            return fn, None
        elif fn in {"train_step", "evaluate_step"}:
            logger.debug(f"Use {_get_fun_msg(self.model.forward, with_fp=False)}...")
            return self.model, self.model.forward
        else:
            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

    @staticmethod
    def move_model_to_device(model: "oneflow.nn.Module", device: "oneflow.device"):
        r"""
        Move the model to the given device.  A ``None`` device is a no-op.
        """
        if device is not None:
            model.to(device)

    def move_data_to_device(self, batch):
        """
        Move a ``batch`` of data to the driver's data device.

        :param batch: a collection containing :class:`oneflow.Tensor`, possibly nested
            (**List**, **Dict**, ...)
        :return: the ``batch`` moved to the target device
        """
        return oneflow_move_data_to_device(batch, self.data_device)

    @staticmethod
    def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
        # derive independent, reproducible seeds for numpy / oneflow / stdlib random in
        # every dataloader worker, keyed on (process seed, worker id, global rank)
        global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))
        process_seed = oneflow.initial_seed()

        base_seed = process_seed - worker_id
        ss = np.random.SeedSequence([base_seed, worker_id, global_rank])

        np.random.seed(ss.generate_state(4))

        oneflow_ss, stdlib_ss = ss.spawn(2)
        oneflow.manual_seed(oneflow_ss.generate_state(1, dtype=np.uint64)[0])

        # stdlib random accepts arbitrarily large ints: fold two 64-bit words into one seed
        stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
        random.seed(stdlib_seed)

    def set_deterministic_dataloader(self, dataloader: "DataLoader"):
        """
        Patch ``dataloader`` for deterministic training, so that retraining with the
        same random seed reproduces the same results.
        """
        if dataloader.worker_init_fn is None:
            dataloader.worker_init_fn = partial(self.worker_init_function,
                                                rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)))

    def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx: int):
        r"""
        For a distributed ``sampler`` the random seed must be set before every
        ``epoch`` so that the ``shuffle`` order agrees across processes.

        :param dataloader: the ``dataloader`` whose ``epoch`` is set
        :param cur_epoch_idx: index of the current ``epoch``
        """
        # keeps shuffle=True correct under ddp: every process must use the same
        # shuffle seed for the epoch;
        if callable(getattr(dataloader.sampler, "set_epoch", None)):
            dataloader.sampler.set_epoch(cur_epoch_idx)

    @staticmethod
    def get_dataloader_args(dataloader: "DataLoader"):
        """
        Extract the ``dataset``, ``batch_sampler``, ``sampler``, ``batch_size``,
        ``shuffle`` and ``drop_last`` settings from ``dataloader``.
        """
        @dataclass
        class Res:
            dataset: Optional[Dataset] = None
            batch_sampler: Optional[BatchSampler] = None
            sampler: Optional[Sampler] = None
            batch_size: Optional[int] = None
            shuffle: Optional[bool] = None
            drop_last: Optional[bool] = None

        res = Res()

        # a oneflow DataLoader always has a ``dataset`` attribute;
        res.dataset = dataloader.dataset

        # the dataloader uses a plain sampler;
        if dataloader.batch_sampler is None:
            res.sampler = dataloader.sampler
            res.batch_size = 1
            res.shuffle = True if isinstance(dataloader.sampler, RandomSampler) else False
            res.drop_last = False
        # the dataloader uses a batch_sampler;
        else:
            res.batch_sampler = dataloader.batch_sampler
            if hasattr(dataloader.batch_sampler, "batch_size"):
                res.batch_size = getattr(dataloader.batch_sampler, "batch_size")
            # a user-supplied batch_sampler without a "batch_size" attribute:
            # infer the batch size from the first produced batch;
            else:
                dataloader_iter = iter(dataloader)
                pre_sample = next(dataloader_iter)
                res.batch_size = pre_sample.shape[0]

            if hasattr(dataloader.batch_sampler, "sampler"):
                res.sampler = dataloader.batch_sampler.sampler
                if hasattr(dataloader.batch_sampler.sampler, "shuffle"):
                    res.shuffle = dataloader.batch_sampler.sampler.shuffle
                elif isinstance(dataloader.batch_sampler.sampler, OneflowRandomSampler):
                    res.shuffle = True
                else:
                    res.shuffle = False
            # the ReproduceBatchSampler case: it wraps the original batch_sampler;
            elif hasattr(dataloader.batch_sampler, "batch_sampler"):
                batch_sampler = dataloader.batch_sampler.batch_sampler
                res.sampler = batch_sampler.sampler
                if hasattr(batch_sampler.sampler, "shuffle"):
                    # bugfix: read shuffle from the wrapped ``batch_sampler.sampler``;
                    # the old code read ``dataloader.batch_sampler.sampler``, which does
                    # not exist in this branch
                    res.shuffle = batch_sampler.sampler.shuffle
                elif isinstance(batch_sampler.sampler, OneflowRandomSampler):
                    res.shuffle = True
                else:
                    res.shuffle = False
            else:
                # dataloader.batch_sampler has neither "sampler" nor "batch_sampler":
                # a fully custom batch_sampler; the DataLoader still initializes a
                # default sampler itself, so attach that one to ``res``;
                res.sampler = dataloader.sampler
                res.shuffle = False

            if hasattr(dataloader.batch_sampler, "drop_last"):
                res.drop_last = getattr(dataloader.batch_sampler, "drop_last")
            # a user-supplied batch_sampler without a "drop_last" attribute;
            else:
                res.drop_last = False

        return res
diff --git a/fastNLP/core/drivers/oneflow_driver/single_device.py b/fastNLP/core/drivers/oneflow_driver/single_device.py
new file mode 100644
index 00000000..078b36b3
--- /dev/null
+++ b/fastNLP/core/drivers/oneflow_driver/single_device.py
@@ -0,0 +1,121 @@
+import os
+from typing import Dict, Union
+from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
+
+if _NEED_IMPORT_ONEFLOW:
+ import oneflow
+ from oneflow.utils.data import SequentialSampler as OneflowSequentialSampler
+ from oneflow.utils.data import BatchSampler as OneflowBatchSampler
+
+__all__ = [
+ "OneflowSingleDriver"
+]
+
+from .oneflow_driver import OneflowDriver
+from fastNLP.core.drivers.oneflow_driver.utils import replace_sampler, replace_batch_sampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
+ ReproduceBatchSampler
+from fastNLP.core.samplers import RandomSampler
+from fastNLP.core.log import logger
+
+
class OneflowSingleDriver(OneflowDriver):
    r"""
    Driver executing ``oneflow`` eager-mode computation on cpu or a single gpu.

    :param model: the ``model`` argument passed to ``Trainer``
    :param device: oneflow.device, the device used by the current process
    :param fp16: whether to enable fp16; currently has no effect for single-card eager mode.
    :param oneflow_kwargs:
    :kwargs:
        * *model_wo_auto_param_call* (``bool``) -- whether to disable calling our
          ``auto_param_call`` function during training to automatically match the batch
          against the forward function's parameters.

    .. note::

        For a detailed description of this argument, see :class:`~fastNLP.core.controllers.Trainer`;
        the function ``auto_param_call`` is documented at :func:`fastNLP.core.utils.auto_param_call`.

    """

    def __init__(self, model, device: "oneflow.device", fp16: bool = False, oneflow_kwargs: Dict = None, **kwargs):
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices == "":
            device = oneflow.device("cpu")
            # fix: the two adjacent literals are concatenated at compile time, so the
            # first one must end with a space ("...gonna to use...", not "...gonna touse...").
            logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to "
                        "use `cpu` instead of `gpu` device.")

        super(OneflowSingleDriver, self).__init__(model, fp16=fp16, oneflow_kwargs=oneflow_kwargs, **kwargs)

        if device is None:
            logger.debug("device is not set, fastNLP will try to automatically get it.")
            try:
                device = next(model.parameters()).device
                assert isinstance(device, oneflow.device)
            except Exception as e:
                # fix: a bare ``except:`` would also swallow KeyboardInterrupt/SystemExit
                # and discard the original cause; chain it instead.
                raise ValueError("fastNLP cannot get device automatically, please set device explicitly.") from e

        # device both the model and the data will live on
        self.model_device = device

        # single-process run: this driver is always rank 0 of a world of size 1
        self.local_rank = 0
        self.global_rank = 0
        self.world_size = 1

    def setup(self):
        r"""
        Move the model to the corresponding device.
        """
        if self.model_device is not None:
            self.model.to(self.model_device)

    def set_dist_repro_dataloader(self, dataloader,
                                  dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
                                  reproducible: bool = False):

        # If dist is a ReproducibleBatchSampler / ReproducibleSampler, this call comes from
        # driver.load_checkpoint while resuming an interrupted run;
        if isinstance(dist, ReproducibleBatchSampler):
            return replace_batch_sampler(dataloader, dist)
        elif isinstance(dist, ReproducibleSampler):
            return replace_sampler(dataloader, dist)

        # If dist is a str or None, this call happens during trainer initialization;
        args = self.get_dataloader_args(dataloader)
        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
            batch_sampler = re_instantiate_sampler(args.batch_sampler)
            return replace_batch_sampler(dataloader, batch_sampler)
        elif isinstance(args.sampler, ReproducibleSampler):
            sampler = re_instantiate_sampler(args.sampler)
            return replace_sampler(dataloader, sampler)

        if reproducible:
            if type(args.batch_sampler) is OneflowBatchSampler:
                if type(args.sampler) is OneflowSequentialSampler:
                    # Replace it with a reproducible sampler that does not shuffle.
                    sampler = RandomSampler(args.sampler.data_source, shuffle=False)
                    logger.debug("Replace oneflow SequentialSampler into fastNLP RandomSampler.")
                    return replace_sampler(dataloader, sampler)
            batch_sampler = ReproduceBatchSampler(
                batch_sampler=args.batch_sampler,
                batch_size=args.batch_size,
                drop_last=args.drop_last
            )
            return replace_batch_sampler(dataloader, batch_sampler)
        else:
            return dataloader

    def unwrap_model(self):
        r"""
        :return: the model used for training
        """
        return self.model

    @property
    def data_device(self):
        r"""
        :return: the device that both the data and the model live on.
        """
        return self.model_device

    def is_distributed(self):
        r"""
        :return: whether the driver in use is distributed; ``False`` for ``OneflowSingleDriver``.
        """
        return False
diff --git a/fastNLP/core/drivers/oneflow_driver/utils.py b/fastNLP/core/drivers/oneflow_driver/utils.py
new file mode 100644
index 00000000..175c7714
--- /dev/null
+++ b/fastNLP/core/drivers/oneflow_driver/utils.py
@@ -0,0 +1,302 @@
+import os
+
+from typing import Any, Dict, Optional
+from enum import IntEnum
+import contextlib
+import random
+import numpy as np
+import inspect
+
+from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
+from fastNLP.envs.utils import get_global_seed
+from fastNLP.envs import (
+ get_global_rank,
+ FASTNLP_BACKEND_LAUNCH,
+ FASTNLP_GLOBAL_SEED,
+)
+from fastNLP.core.samplers import ReproducibleBatchSampler
+from fastNLP.core.utils import auto_param_call
+from fastNLP.core.log import logger
+
+if _NEED_IMPORT_ONEFLOW:
+ import oneflow
+ from oneflow.nn import Module
+ from oneflow.utils.data import DataLoader
+ from oneflow.utils.data import RandomSampler as oneflowRandomSampler
+ from oneflow.utils.data import SequentialSampler as oneflowSequentialSampler
+ from oneflow.utils.data import BatchSampler as oneflowBatchSampler
+else:
+ from fastNLP.core.utils.dummy_class import DummyClass as Module
+
+
+__all__ = [
+ 'oneflow_seed_everything',
+]
+
def oneflow_seed_everything(seed: int = None, add_global_rank_to_seed: bool = True) -> int:
    r"""
    Seed the **oneflow**, **numpy** and **python.random** pseudo-random generators.

    :param seed: integer seed for the global random state. When ``None``, a seed is
        derived from the timestamp.
    :param add_global_rank_to_seed: whether different ranks should draw different random
        numbers in distributed training. When ``True``, **fastNLP** adds the current
        ``global_rank`` to the seed.
    """
    seed_info = np.iinfo(np.uint32)

    if seed is None:
        # Under a backend launch every worker must agree, so fall back to a fixed seed.
        seed = 42 if os.getenv(FASTNLP_BACKEND_LAUNCH) == "1" else get_global_seed()
        logger.info(f"'FASTNLP_GLOBAL_SEED' is set to {seed} automatically.")
    seed = seed if isinstance(seed, int) else int(seed)

    # Clamp out-of-range values into what numpy's uint32 generator accepts.
    if seed < seed_info.min or seed > seed_info.max:
        logger.rank_zero_warning("Your seed value is too big or too small for numpy, we will choose a random seed for you.")
        seed %= seed_info.max

    # Record the (pre-rank-offset) seed so child processes can reuse it.
    os.environ[FASTNLP_GLOBAL_SEED] = f"{seed}"
    if add_global_rank_to_seed:
        seed += get_global_rank()

    random.seed(seed)
    np.random.seed(seed)
    oneflow.manual_seed(seed)
    oneflow.cuda.manual_seed_all(seed)
    return seed
+
+
class ForwardState(IntEnum):
    # Integer codes distinguishing which kind of step a forward pass serves.
    TRAIN = 0
    VALIDATE = 1
    TEST = 2
    PREDICT = 3
+
+
class _DDPWrappingModel(Module):
    """
    Handles the user's customized ``train_step`` etc. during DDP training.

    This extra wrapper model is needed because, under DDP, only calls routed through
    DistributedDataParallel's ``forward`` run correctly. On the other hand, our framework asks
    users to implement a separate function per mode, e.g. 'train_step', 'evaluate_step';
    once the model is wrapped by DistributedDataParallel those methods are no longer visible,
    and extracting `model = model.module` during training is also wrong — it would make the
    parameters diverge across gpus.

    Hence this wrapper; for a more detailed explanation see the ddp design of
    'pytorch_lightning'.
    """

    def __init__(self, model: Module):
        super(_DDPWrappingModel, self).__init__()
        self.model = model

    def forward(self, batch, **kwargs) -> Dict:
        """
        pytorch lightning unwraps the model first here, but that does not seem necessary
        for us; keeping this note in case the need arises later.
        """
        # fastNLP injects which step function to run and how to inspect its signature.
        fn = kwargs.pop("fastnlp_fn")
        signature_fn = kwargs.pop("fastnlp_signature_fn")
        wo_auto_param_call = kwargs.pop("wo_auto_param_call")

        if isinstance(batch, Dict) and not wo_auto_param_call:
            # Automatically match dict entries against the step function's parameters.
            return auto_param_call(fn, batch, signature_fn=signature_fn)
        else:
            return fn(batch)
+
+
class DummyGradScaler:
    """No-op stand-in for a GradScaler, used when amp/fp16 is disabled.

    It mirrors the GradScaler interface so callers can follow the same code path
    whether or not mixed precision is active.
    """

    def __init__(self, *args, **kwargs):
        pass

    def is_enabled(self):
        # A dummy scaler always reports amp as off.
        return False

    def get_scale(self):
        # A disabled scaler never rescales, so the scale factor is always 1.
        return 1.0

    def scale(self, outputs):
        # Pass-through: nothing to scale.
        return outputs

    def unscale_(self, optimizer):
        pass

    def step(self, optimizer, *args, **kwargs):
        # Delegate straight to the optimizer, as a real scaler would after unscaling.
        optimizer.step(*args, **kwargs)

    def update(self, new_scale=None):
        pass

    def state_dict(self):
        # Nothing to persist for a no-op scaler.
        return {}

    def load_state_dict(self, state_dict):
        pass
+
+
def _build_fp16_env(dummy=False):
    # Build the (autocast, GradScaler) pair used for fp16 training, or no-op
    # substitutes when ``dummy`` is True.
    #
    # NOTE(review): the bare ``return`` below short-circuits the whole function,
    # so callers always receive ``None`` instead of the ``(autocast, GradScaler)``
    # tuple built below, and everything after it is dead code. Presumably amp
    # support is intentionally disabled for oneflow — confirm with the call
    # sites before removing this early return.
    return
    if dummy:
        autocast = contextlib.ExitStack
        GradScaler = DummyGradScaler
    else:
        if not oneflow.cuda.is_available():
            raise RuntimeError("Oneflow is not installed in gpu version, please use device='cpu'.")
        if oneflow.cuda.get_device_capability(0)[0] < 7:
            logger.rank_zero_warning(
                "NOTE: your device does NOT support faster training with fp16, "
                "please switch to FP32 which is likely to be faster"
            )
        try:
            # NOTE(review): the second import overwrites the GradScaler from the
            # first line, which looks redundant — verify which module is intended.
            from oneflow.amp import GradScaler
            from oneflow.cuda.amp import autocast, GradScaler
        except ImportError:
            # NOTE(review): the message mentions "torch" but this is the oneflow
            # driver — looks copied from the torch driver; confirm and reword.
            raise RuntimeError("torch version too low (less than 1.6)")
    return autocast, GradScaler
+
+
def replace_sampler(dataloader: "DataLoader", sampler):
    r"""
    Replace the sampler of ``dataloader`` (the logic here builds a brand-new dataloader):

    The user may have subclassed DataLoader and customized their own dataloader class; this is
    why we call `inspect.signature(dataloader)` rather than `inspect.signature(DataLoader)`
    directly, and also why the new dataloader created here uses the class of the dataloader
    the user passed in instead of DataLoader itself.

    If you customize your own dataloader, make sure that:

    1. ``__init__`` accepts ``**kwargs``, so that we can inject the sampler into the concrete
       DataLoader construction;
    2. every parameter appearing in ``__init__`` is stored as an instance attribute of the same
       name, e.g. ``self.one_arg_name = one_arg_name``, because attributes are the only way we
       can recover the actual argument values;

    """

    # Collect the public instance attributes;
    instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}

    # 'multiprocessing_context' is a user-defined function;
    if getattr(dataloader, 'multiprocessing_context', None) is not None:
        instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context

    # Get the default signature of the dataloader's '__init__';
    init_params = dict(inspect.signature(dataloader.__init__).parameters)

    # Guard against a user DataLoader that subclasses the oneflow DataLoader but still
    # forwards arguments to the parent class via **kwargs
    has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
    if has_variadic_kwargs and isinstance(dataloader, DataLoader):
        # In case the user wrote super().__init__(**kwargs)
        for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
            if key not in init_params and key != 'self':
                init_params[key] = value

    # If the dataloader was initialized with non-default arguments, record them so they can
    # be set again when re-initializing;
    non_default_params = {name for name, p in init_params.items() if
                          name in instance_attrs and p.default != instance_attrs[name]}
    # add `dataset` as it might have been replaced with `*args`
    non_default_params.add("dataset")

    reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
    if isinstance(dataloader, DataLoader):
        reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})

    batch_sampler = getattr(dataloader, "batch_sampler")
    if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
        raise RuntimeError("It should not be running here, please report a bug to us.")

    required_args = {
        p.name
        for p in init_params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
        and p.name not in reconstruct_args
    }

    # These parameters were not found among the attributes, so re-initialization is impossible
    if required_args:
        required_args = sorted(required_args)
        dataloader_self_name = dataloader.__class__.__name__
        raise Exception(
            f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
            f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its "
            f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be "
            f"`{dataloader_self_name}`'s attribute."
        )

    # This error targets dataloaders that are not plain DataLoader but customized subclasses
    # whose __init__ has no **kwargs;
    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = reconstruct_args.keys() - init_params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_self_name = dataloader.__class__.__name__
            raise Exception(
                f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found."
            )
        # Without **kwargs, make sure only the accepted parameters are passed
        if not isinstance(dataloader, DataLoader):
            reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}

    return type(dataloader)(**reconstruct_args)
+
+
def replace_batch_sampler(dataloader, new_batch_sampler):
    r"""
    Re-instantiate ``dataloader`` with ``new_batch_sampler`` as its batch sampler.
    """
    # Attributes that conflict with an explicit ``batch_sampler`` and must not be forwarded.
    excluded = {"batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"}
    params = {
        name: getattr(dataloader, name)
        for name in dataloader.__dict__
        if not name.startswith("_") and name not in excluded
    }
    params["batch_sampler"] = new_batch_sampler

    if not isinstance(dataloader, DataLoader):
        # A subclass without **kwargs in __init__ can only accept the parameters
        # its signature explicitly declares.
        signature_params = inspect.signature(dataloader.__init__).parameters
        accepts_kwargs = any(p.kind is p.VAR_KEYWORD for p in signature_params.values())
        if not accepts_kwargs:
            params = {name: value for name, value in params.items() if name in signature_params}

    return type(dataloader)(**params)
+
+
def optimizer_state_to_device(state, device):
    r"""
    Move an optimizer ``state_dict`` to the given device.

    :param state: state dict obtained from :func:`optimzier.state_dict`
    :param device: the target device.
    :return: a new state dict whose tensors live on ``device``.
    """
    moved = {}
    for key, value in state.items():
        if isinstance(value, dict):
            # Recurse into nested per-parameter state.
            moved[key] = optimizer_state_to_device(value, device)
        elif isinstance(value, oneflow.Tensor):
            moved[key] = value.to(device).clone()
        else:
            # Non-tensor entries (step counters, hyper-parameters) are kept as-is.
            moved[key] = value
    return moved
+
+
def _check_dataloader_args_for_distributed(args, controller='Trainer'):
    """
    Check the sampler of the dataloader; raise if the user substituted a customized
    sampler, to prevent errors during distributed training.

    :param args: namespace-like object exposing the dataloader's ``sampler`` and
        ``batch_sampler`` (as returned by ``get_dataloader_args``).
    :param controller: ``'Trainer'`` or ``'Evaluator'``; selects the error wording and
        whether the batch sampler's type is also checked.
    :raises TypeError: when a non-default sampler (or batch sampler, for the Trainer)
        is detected.
    """
    # A sampler counts as "customized" unless it is exactly one of oneflow's defaults.
    error_flag = (type(args.sampler) not in {oneflowRandomSampler, oneflowSequentialSampler})
    if controller == 'Trainer':
        mode = 'training'
        substitution = 'fastNLP.RandomSampler'
        # For training the batch sampler must also be the oneflow default.
        error_flag = (type(args.batch_sampler) != oneflowBatchSampler) or error_flag
    else:  # Evaluator
        mode = 'evaluation'
        substitution = 'fastNLP.UnrepeatedSequentialSampler'
    if error_flag:
        raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
                        f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
                        f"``{substitution}``. The customized sampler should set for distributed running "
                        f"before initializing ``{controller}`` , and then set the "
                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
                        f"\n Current batch_sampler: {type(args.batch_sampler)}"
                        f"\n Current sampler: {type(args.sampler)}")
diff --git a/fastNLP/core/drivers/paddle_driver/__init__.py b/fastNLP/core/drivers/paddle_driver/__init__.py
new file mode 100644
index 00000000..0dc85934
--- /dev/null
+++ b/fastNLP/core/drivers/paddle_driver/__init__.py
@@ -0,0 +1,11 @@
+__all__ = [
+ "PaddleDriver",
+ "PaddleSingleDriver",
+ "PaddleFleetDriver",
+ "paddle_seed_everything",
+]
+
+from .paddle_driver import PaddleDriver
+from .single_device import PaddleSingleDriver
+from .fleet import PaddleFleetDriver
+from .utils import paddle_seed_everything
\ No newline at end of file
diff --git a/fastNLP/core/drivers/paddle_driver/dist_utils.py b/fastNLP/core/drivers/paddle_driver/dist_utils.py
new file mode 100644
index 00000000..28182ca3
--- /dev/null
+++ b/fastNLP/core/drivers/paddle_driver/dist_utils.py
@@ -0,0 +1,302 @@
+import io
+import pickle
+import os
+from typing import Any, List
+
+import numpy as np
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.envs.env import FASTNLP_NO_SYNC
+from fastNLP.core.utils import paddle_move_data_to_device
+
+if _NEED_IMPORT_PADDLE:
+ import paddle
+ import paddle.distributed as dist
+ from paddle.framework.io import (
+ _is_state_dict,
+ _build_saved_state_dict,
+ _unpack_saved_dict,
+ _pickle_save,
+ _pack_loaded_dict,
+ _ndarray_to_tensor,
+ _parse_load_result,
+ )
+
+__all__ = []
+
+def _validate_output_list_for_rank(my_rank, dst, gather_list):
+ if dst == my_rank:
+ if not gather_list:
+ raise ValueError(
+ "Argument ``gather_list`` must be specified on destination rank."
+ )
+ elif gather_list:
+ raise ValueError(
+ "Argument ``gather_list`` must NOT be specified "
+ "on non-destination ranks."
+ )
+
def paddle_pickle_dump(obj, stream, protocol):
    """
    Pickle ``obj`` into ``stream`` following the same logic as `paddle.save`.

    :param obj: object to serialize; state dicts receive paddle's special packing.
    :param stream: writable binary file-like object.
    :param protocol: pickle protocol number.
    """
    if _is_state_dict(obj):
        # State dicts are converted to plain numpy-backed dicts before pickling
        # so they can be restored without rebuilding the original program.
        saved_obj = _build_saved_state_dict(obj)
        saved_obj = _unpack_saved_dict(saved_obj, protocol)
        pickle.dump(saved_obj, stream, protocol=protocol)
    else:
        _pickle_save(obj, stream, protocol)
+
def paddle_pickle_load(stream):
    """
    Unpickle an object from ``stream`` following the same logic as `paddle.load`.

    :param stream: readable binary file-like object produced by :func:`paddle_pickle_dump`.
    :return: the deserialized object; saved state-dict entries are converted back to tensors.
    """
    load_result = pickle.load(stream)
    if isinstance(load_result, dict):
        load_result = _pack_loaded_dict(load_result)
        # "StructuredToParameterName@@" marks a saved state dict and maps
        # structured names back to parameter names.
        if "StructuredToParameterName@@" in load_result:

            for key in load_result["StructuredToParameterName@@"]:
                if isinstance(load_result[key], np.ndarray):
                    load_result[key] = _ndarray_to_tensor(
                        load_result[key], return_numpy=False)

            # Drop the bookkeeping key once all arrays are converted.
            if "StructuredToParameterName@@" in load_result:
                del load_result["StructuredToParameterName@@"]
        else:
            load_result = _parse_load_result(load_result, return_numpy=False)

    else:
        load_result = _parse_load_result(load_result, return_numpy=False)

    return load_result
+
def _object_to_tensor(obj, device=None):
    # Serialize ``obj`` with pickle and pack the raw bytes into a tensor, returning
    # (byte_tensor, local_size) where local_size is a 1-element tensor holding the
    # byte count — needed later to strip the padding added before collective ops.
    f = io.BytesIO()
    paddle_pickle_dump(obj, f, protocol=2)
    byte_data = list(f.getvalue())
    # NOTE(review): int32 is used here (4x memory per payload byte), presumably
    # because the collective ops require it — confirm before switching to uint8.
    byte_tensor = paddle.to_tensor(byte_data, dtype=paddle.int32)
    local_size = paddle.to_tensor([byte_tensor.numel()])
    if device is not None:
        byte_tensor = paddle_move_data_to_device(byte_tensor, device)
        local_size = paddle_move_data_to_device(local_size, device)
    return byte_tensor, local_size
+
def _tensor_to_object(tensor, tensor_size):
    # Inverse of ``_object_to_tensor``: truncate the (possibly padded) byte tensor
    # to ``tensor_size`` bytes and unpickle the payload.
    buf = tensor.astype(paddle.uint8).detach().cpu().numpy().tobytes()[:tensor_size]
    return paddle_pickle_load(io.BytesIO(buf))
+
def fastnlp_paddle_gather_object(obj, dst=0, group=None):
    """
    Gather ``obj`` from every rank onto rank ``dst``.

    Example::
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = fastnlp_paddle_gather_object(gather_objects[dist.get_rank()], dst=0)
        >>> # On rank 0
        >>> output
        ['foo', 12, {1: 2}]

    :param obj: the object to send; it must be picklable
    :param dst: the destination rank.
    :param group: the communication group on which to run this function.
    :return: on ``dst``, a list of length ``world_size`` holding the objects from
        rank 0, rank 1, ...; ``None`` on the other ranks.
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        # Synchronization fully disabled: behave as a single-process gather.
        return [obj]

    if dist.get_rank() == dst:
        object_gather_list = [None for _ in range(dist.get_world_size())]
    else:
        object_gather_list = None

    # if group is None:
    #     TODO there is a bug in paddle 2.2
    #     group = dist.collective._get_global_group()

    if group is not None and not group.is_member():
        return

    # Ensure object_gather_list is specified appopriately.
    my_rank = dist.get_rank()
    _validate_output_list_for_rank(my_rank, dst, object_gather_list)
    # Move to cpu first so unpickling never lands on the sending gpu.
    obj = paddle_move_data_to_device(obj, device="cpu")
    input_tensor, local_size = _object_to_tensor(obj)
    # paddle groups currently only support nccl, so the tensors must live on gpu
    input_tensor = paddle_move_data_to_device(input_tensor, device=paddle.device.get_device())
    local_size = paddle_move_data_to_device(local_size, device=paddle.device.get_device())

    # Gather every rank's local_size and find the largest
    object_size_list = []
    dist.all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # NOTE(review): reshape_ cannot grow a tensor's numel — ranks whose payload is
    # smaller than max_object_size presumably need padding (as all_gather_object
    # does) rather than a reshape; confirm against a multi-rank run.
    input_tensor.reshape_(max_object_size)
    # TODO paddle has no equivalent of torch.distributed.gather, so all_gather is used
    output_tensors = []
    dist.all_gather(output_tensors, input_tensor, group)
    if my_rank != dst:
        return
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.astype(paddle.uint8)
        tensor_size = object_size_list[i]
        object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
    # Bug fix: the gathered list was previously built but never returned, so dst
    # always received ``None`` despite the documented return value.
    return object_gather_list
+
def send_recv_object(obj, src, cur_rank, device, group=None, use_calc_stream=True):
    # src rank send to all other ranks; every rank returns the object.
    #
    # ``size`` is a placeholder for the payload byte count, filled by the broadcast below.
    size = paddle_move_data_to_device(paddle.to_tensor([0]), device)

    if cur_rank == src:
        world_size = dist.get_world_size()
        tensor, size = _object_to_tensor(obj)
        # NOTE(review): uses Tensor.to(device) here while the rest of the module
        # uses paddle_move_data_to_device — confirm .to is available on all
        # supported paddle versions.
        tensor = tensor.to(device)
        size = size.to(device)

        # First synchronize the size of obj;
        dist.broadcast(size, src, group=group)
        for subrank in range(world_size):
            if subrank != src:
                dist.send(tensor=tensor, dst=subrank, group=group, use_calc_stream=use_calc_stream)
    else:
        dist.broadcast(size, src, group=group)
        # Allocate a receive buffer of exactly ``size`` elements.
        tensor = paddle_move_data_to_device(paddle.to_tensor([0] * size), device)
        dist.recv(tensor=tensor, src=src, group=group, use_calc_stream=use_calc_stream)

    return _tensor_to_object(tensor.cpu(), size)
+
def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) ->List:
    """
    all_gather data of any type across ranks. Non-tensor data is transferred by
    pickling and unpickling.

    example::

        >>> # rank 0
        >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 1}}
        >>> # rank 1
        >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
        >>> # after all_gather():
        >>> result = [
                {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
                {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
            ]

    :param obj: data of any structure; if it is a tensor, its shape must be identical on
        every card. Non-tensor objects are serialized before the transfer.
    :param device: currently unused.
    :param group:
    :return: a list [obj0, obj1, ...] where obj_i is the obj on rank i.
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        # Synchronization fully disabled: single-process behaviour.
        return [obj]

    # if group is None:
    #     TODO there is a bug in paddle 2.2
    #     group = dist.collective._get_global_group()
    if isinstance(obj, paddle.Tensor):
        # Same-shape tensors can be gathered directly without serialization.
        objs = []
        dist.all_gather(objs, obj, group=group)
    else:
        objs = [None for _ in range(dist.get_world_size())]
        # Move to cpu first so unpickling does not land on the sending gpu
        obj = paddle_move_data_to_device(obj, "cpu")
        objs = all_gather_object(objs, obj, group=group)

    return objs
+
+
def fastnlp_paddle_broadcast_object(obj, src, device=None, group=None):
    """
    Broadcast the ``obj`` on rank ``src`` to the other ranks.

    :param obj: the object to send
    :param src: the rank to send from.
    :param device: device for the intermediate byte tensor; defaults to the current device.
    :param group: the communication group this call belongs to
    :return:
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        # Synchronization disabled: only src keeps the object.
        if src == dist.get_rank():
            return obj
        else:
            return None

    cur_rank = dist.get_rank()
    if cur_rank == src:
        # Move all tensors to cpu to make pickling safe; otherwise unpickling may
        # end up on the card the data was sent from
        obj = paddle_move_data_to_device(obj, "cpu")

    if device is None:
        device = paddle.device.get_device()

    if cur_rank == src:
        tensor, size = _object_to_tensor(obj, device=device)
    else:
        # placeholder for the payload size, filled by the broadcast below
        size = paddle_move_data_to_device(paddle.to_tensor([0]), device)

    dist.broadcast(size, src=src, group=group)
    if cur_rank != src:
        # Allocate a receive buffer matching the broadcast size.
        tensor = paddle.empty(
            size.astype(paddle.int32),  # type: ignore[arg-type]
            dtype=paddle.int32,
        )
    dist.broadcast(tensor, src=src, group=group)

    return _tensor_to_object(tensor, tensor_size=size.item())
+
def all_gather_object(object_list, obj, group=None):
    """
    Gather picklable objects from every rank into ``object_list``.

    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> all_gather_object(output, gather_objects[dist.get_rank()])
        >>> output
        ['foo', 12, {1: 2}]

    :param object_list: output list of length ``world_size``; filled in place.
    :param obj: the picklable object contributed by this rank.
    :param group:
    :return: ``object_list`` with entry ``i`` set to rank ``i``'s object.
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        # Synchronization disabled: single-process behaviour.
        return [obj]

    if group is not None and not group.is_member():
        return

    current_device = paddle.device.get_device()

    input_tensor, local_size = _object_to_tensor(obj, device=current_device)

    # Gather all tensor sizes and find the largest
    object_size_list = []
    # Allgather tensor sizes
    dist.all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Pad every tensor up to the maximum size so all shapes match for all_gather
    pad_dims = []
    pad_by = (max_object_size - local_size).detach().cpu()
    for val in reversed(pad_by):
        pad_dims.append(0)
        pad_dims.append(val.item())
    tensor_padded = paddle.nn.functional.pad(input_tensor, pad_dims)

    # Output tensors are nonoverlapping views of coalesced_output_tensor
    output_tensors = []
    dist.all_gather(output_tensors, tensor_padded, group=group)
    dist.barrier()
    # Deserialize outputs back to object.
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.astype(paddle.uint8)
        if not tensor.place.is_cpu_place():
            tensor = tensor.cpu()
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size)
    return object_list
\ No newline at end of file
diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
new file mode 100644
index 00000000..6a38af5e
--- /dev/null
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -0,0 +1,579 @@
+r"""
+用于实现 **PaddlePaddle** 框架下使用 ``fleet`` 分布式训练 API 进行集群式(*collective*)多卡训练的 Driver。
+
+.. note::
+
+ 在 **PaddlePaddle** 框架中,使用分布式训练的方式可以参见 **PaddlePaddle** 的
+ `官方文档 `_ 。
+ 简言之,分布式训练的过程可以概括为:导入 ``fleet`` 包 -> 使用 :func:`fleet.init` 初始化分布式环境 -> 初始化模型,转换为并行模型开始训练。
+
+**fastNLP** 支持三种启动分布式训练的方式(假设执行训练的文件名为 ``train.py``):
+
+ A. 用户自己不进行分布式的任何操作,直接使用我们的 :class:`~fastNLP.core.Trainer` 进行训练,此时将参数 ``device``
+ 设置为一个列表,然后使用 ``python train.py`` 的方式开始训练;
+ B. 用户自己不进行分布式的任何操作,但是使用 ``python -m paddle.distributed.launch train.py`` 开始训练;
+ C. 用户自己在外面初始化分布式环境,并且通过 ``python -m paddle.distributed.launch train.py`` 开始训练;
+
+.. note::
+
+ 在后两种启动方式中,您需要通过参数 ``--gpus`` 来指定训练使用的设备,在 ``trainer`` 中设置的参数是无效的。
+
+不过在使用该 Driver 之前,我们需要向您说明 **fastNLP** 实现 ``PaddleFleetDriver`` 的思路,以便于您理解代码编写过程中可能出现的问题。
+
+在 **fastNLP** 中,为了尽可能减少单卡向分布式训练转换过程中的代码变动,我们需要在 ``PaddleFleetDriver`` 中进行 **分布式环境初始化**
+和 **将模型转换为并行模式** 等操作,同时实现多卡训练的方法是从主进程(``rank=0``)中创建其它的所有子进程(``rank=1,2,...``)。
+在这个过程中,我们发现由于 **PaddlePaddle** 框架的特性,会出现下面的问题:
+
+ 1. **fastNLP** 中,初始化模型一定会在初始化 ``Driver`` 之前,因此调用 :func:`fleet.init` 的时机会在初始化模型之后;
+ 此时子进程中模型将无法正常地初始化,提示无法找到设备 ``gpu:0``;
+ 2. 在训练的过程中,会出现训练一个 ``batch`` 后程序卡住或程序会占用所有可见显卡的情况;
+
+考虑到这些问题,我们为 **PaddlePaddle** 的分布式训练制定了这样的约束:在导入 **fastNLP** 之前,必须设置环境变量 ``FASTNLP_BACKEND``
+为 ``paddle``。执行方法有两种::
+
+ >>> import os
+ >>> os.environ["FASTNLP_BACKEND"] = "paddle" # 设置环境变量
+ >>> import fastNLP # 设置之后才可以导入 fastNLP
+
+或是在执行脚本(假设文件名为 ``train.py`` )时设置::
+
+ FASTNLP_BACKEND=paddle python train.py
+    FASTNLP_BACKEND=paddle python -m paddle.distributed.launch train.py
+
+设置 ``FASTNLP_BACKEND=paddle`` 后,**fastNLP** 会在 ``import paddle`` 之前通过 ``CUDA_VISIBLE_DEVICES`` 将设备限制在所有可见设备的第
+**0** 张卡上,以此绕开通信和同步上的种种限制。我们会将用户希望可见的设备(如用户自己设置了 ``CUDA_VISIBLE_DEVICES`` 的情况)保存在另一个环境变量
+``USER_CUDA_VISIBLE_DEVICES`` 中来确保 **fastNLP** 能够知道用户的设置。假设用户希望在 ``[0,2,3]`` 三张显卡上进行分布式训练,那么在三个训练进程中,
+``CUDA_VISIBLE_DEVICES`` 就分别为 0、2 和 3 。
+
+.. note::
+
+ 我们会事先将设备限制在所有可见设备的第 **0** 张卡上,因此多卡训练的参数 ``device`` 一定要以 **0** 开始,否则会无法正常地启动。
+ 如果您希望调整使用的第一张显卡,请使用 ``CUDA_VISIBLE_DEVICES`` 进行限制。
+
+.. note::
+
+ 根据 **PaddlePaddle** 的说明,设置 ``CUDA_VISIBLE_DEVICES`` 之后启动分布式训练时,情况A与情况BC设置设备的方式会有所不同。
+ 情况A应设置为实际设备相对可见设备的索引,而情况BC应设置为实际的设备号:
+
+ 1. 情况A中, ``CUDA_VISIBLE_DEVICES=3,4,5,6`` 且参数 ``device=[0,2,3]`` 代表使用 **3号、5号和6号** 显卡;
+ 2. 情况BC中,``CUDA_VISIBLE_DEVICES=3,4,5,6`` 且参数 ``--gpu=3,5,6`` 代表使用 **3号、5号和6号** 显卡;
+
+.. note::
+
+ 多机的启动强制要求用户在每一台机器上使用 ``python -m paddle.distributed.launch`` 启动;因此我们不会在 ``PaddleFleetDriver``
+ 中保存任何当前有多少台机器的信息。
+
+"""
+import os
+from typing import List, Union, Optional, Dict, Tuple, Callable
+
+from .paddle_driver import PaddleDriver
+from .fleet_launcher import FleetLauncher
+from .utils import (
+ _FleetWrappingModel,
+ replace_sampler,
+ replace_batch_sampler,
+ _check_dataloader_args_for_distributed
+)
+from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object
+
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.core.utils import (
+ auto_param_call,
+ check_user_specific_params,
+ is_in_paddle_dist,
+ get_paddle_device_id,
+)
+from fastNLP.core.utils.paddle_utils import _convert_data_device
+from fastNLP.envs.distributed import rank_zero_rm
+from fastNLP.core.samplers import (
+ ReproduceBatchSampler,
+ ReproducibleSampler,
+ ReproducibleBatchSampler,
+ RandomSampler,
+ UnrepeatedSampler,
+ UnrepeatedSequentialSampler,
+ re_instantiate_sampler,
+ conversion_between_reproducible_and_unrepeated_sampler,
+)
+from fastNLP.envs.env import (
+ FASTNLP_DISTRIBUTED_CHECK,
+ FASTNLP_GLOBAL_SEED,
+ FASTNLP_NO_SYNC,
+ USER_CUDA_VISIBLE_DEVICES,
+)
+from fastNLP.core.log import logger
+
+if _NEED_IMPORT_PADDLE:
+ import paddle
+ from paddle import DataParallel
+ import paddle.distributed.fleet as fleet
+ import paddle.distributed as paddledist
+ from paddle.optimizer import Optimizer
+ from paddle.fluid.reader import _DatasetKind
+ from paddle.fluid.dygraph import parallel_helper
+ from paddle.io import BatchSampler
+
+__all__ = [
+ "PaddleFleetDriver",
+]
+
+class PaddleFleetDriver(PaddleDriver):
+ """
+ :param model: 训练使用的模型。
+
+ * 如果不想自己初始化分布式环境,类型应为 :class:`paddle.nn.Layer`;
+ * 如果已经在外面初始化了分布式环境,类型应为 :class:`paddle.DataParallel`;
+
+ :param parallel_device: 多卡训练时使用的设备,必须是一个列表。
+ 当使用 ``python -m paddle.distributed.launch`` 启动时,该参数无效。
+ :param is_pull_by_paddle_run: 标记当前进程是否为通过 ``python -m paddle.distributed.launch`` 启动的。
+ 这个参数仅在 :class:`~fastNLP.core.Trainer` 中初始化 driver 时使用
+ :param fp16: 是否开启混合精度训练
+ :param paddle_kwargs:
+ * *fleet_kwargs* -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括:
+
+ * *is_collective* -- 是否使用 paddle 集群式的分布式训练方法,目前仅支持为 ``True`` 的情况。
+ * *role_maker* -- 初始化 ``fleet`` 分布式训练 API 时使用的 ``RoleMaker``。
+ * 其它用于初始化 ``DataParallel`` 的参数。
+ * *gradscaler_kwargs* -- 用于 ``fp16=True`` 时,提供给 :class:`paddle.amp.GradScaler` 的参数
+
+ :kwargs:
+ * *wo_auto_param_call* (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为
+
+ .. note::
+
+ 关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。
+
+ """
+ def __init__(
+ self,
+ model,
+ parallel_device: Optional[Union[List[str], str]],
+ is_pull_by_paddle_run: bool = False,
+ fp16: bool = False,
+ paddle_kwargs: Dict = None,
+ **kwargs
+ ):
+ if USER_CUDA_VISIBLE_DEVICES not in os.environ:
+ raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using fastNLP.")
+ super(PaddleFleetDriver, self).__init__(model, fp16=fp16, paddle_kwargs=paddle_kwargs, **kwargs)
+
+ # 如果不是通过 launch 启动,要求用户必须传入 parallel_device
+ if not is_pull_by_paddle_run:
+ if parallel_device is None:
+ raise ValueError("Parameter `parallel_device` can not be None when using `PaddleFleetDriver`. This error is caused "
+ "when your value of parameter `device` is `None` in your `Trainer` instance.")
+ if not isinstance(parallel_device, List):
+ raise ValueError("Parameter `parallel_device`'s type must be List when using `PaddleFleetDriver`, "
+ f"not {type(parallel_device)}.")
+ if get_paddle_device_id(parallel_device[0]) != 0:
+ raise ValueError("The first device of `parallel_device` must be 'gpu:0' in fastNLP.")
+
+ # 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的
+        # 这个参数会在 initialize_paddle_driver 中设置。
+ self.is_pull_by_paddle_run = is_pull_by_paddle_run
+ self.parallel_device = parallel_device
+ # 在初始化时,如果发现 is_pull_by_paddle_run ,则将 parallel_device 设置成当前进程的gpu
+ if is_pull_by_paddle_run:
+ self.model_device = parallel_device
+ else:
+ self.model_device = parallel_device[self.local_rank]
+
+ # 如果用户自己在外面初始化了并行模型;
+ self.outside_fleet = False
+ if parallel_helper._is_parallel_ctx_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
+ "fastnlp_paddle_launch_not_fleet" not in os.environ:
+ # 如果用户自己在外面初始化了 Fleet,那么我们要求用户传入的模型一定是已经由 DistributedDataParallel 包裹后的模型;
+ if not isinstance(model, DataParallel):
+ raise RuntimeError(
+                    "It is not allowed to input a normal model instead of `paddle.DataParallel` when "
+                    "you initialize the paddle distributed process out of our control.")
+
+ self.outside_fleet = True
+
+ self.world_size = None
+ self.global_rank = 0
+ self.gloo_rendezvous_dir = None
+
+ self._fleet_kwargs = self._paddle_kwargs.get("fleet_kwargs", {})
+ check_user_specific_params(self._fleet_kwargs, DataParallel.__init__, DataParallel.__name__)
+ # fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档
+ self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy())
+ self.is_collective = self._fleet_kwargs.pop("is_collective", True)
+ if not self.is_collective:
+ raise NotImplementedError("fastNLP only support `collective` for distributed training now.")
+ self.role_maker = self._fleet_kwargs.pop("role_maker", None)
+
+ self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
+ assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
+ if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
+ os.makedirs(name=self.output_from_new_proc, exist_ok=True)
+ self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
+
+ self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的;
+ self._has_fleetwrapped = False # 判断传入的模型是否经过 _has_fleetwrapped 包裹;
+
+ def setup(self):
+ """
+ 初始化分布式训练的环境。
+
+ 1. 如果是通过 ``paddle.distributed.launch`` 方法启动的,则根据已经设置好的环境获取分布式的属性。
+ 2. 否则启动子进程。
+ """
+ if self._has_setup:
+ return
+ self._has_setup = True
+ # 如果用户需要使用多机模式,那么一定进入到这里;
+ if self.is_pull_by_paddle_run:
+
+ if self.outside_fleet:
+ # 已经初始化了多机环境
+ self._set_from_fleet_environment()
+ else:
+ # 用户没有初始化多机环境
+ # TODO 绕一下
+ # dist.get_world_size() 只能在初始化之后进行调用;
+ self.world_size = int(os.environ.get("PADDLE_TRAINERS_NUM"))
+ self.global_rank = int(os.environ.get("PADDLE_TRAINER_ID"))
+ logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
+ if not parallel_helper._is_parallel_ctx_initialized():
+ fleet.init(self.role_maker, self.is_collective, self.strategy)
+
+ os.environ["fastnlp_paddle_launch_not_fleet"] = "yes"
+
+ else:
+ # 在用户只使用了一个分布式 trainer 的情况下
+ # 此时 parallel_helper._is_parallel_ctx_initialized() 一定为 False
+ # parallel_device 是 list,
+ if not parallel_helper._is_parallel_ctx_initialized():
+ # 拉起子进程并设置相应的属性
+ self._init_fleet_and_set()
+ # 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 PaddleFleetDriver;
+ else:
+ # 已经设置过一次,保证参数必须是一样的
+ pre_gpus = os.environ[FASTNLP_DISTRIBUTED_CHECK]
+ pre_gpus = [int(x) for x in pre_gpus.split(",")]
+ cur_gpus = [get_paddle_device_id(g) for g in self.parallel_device]
+                if sorted(pre_gpus) != sorted(cur_gpus):
+                    raise RuntimeError("Notice you are using `PaddleFleetDriver` after one instantiated `PaddleFleetDriver`, it is not "
+ "allowed that your second `PaddleFleetDriver` has a new setting of parameters `parallel_device`.")
+ self.world_size = paddledist.get_world_size()
+ self.global_rank = paddledist.get_rank()
+
+ if not self.outside_fleet:
+ # self.model.to(self.model_device)
+ self.configure_fleet()
+
+ self.barrier()
+
+ # 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作;
+ # TODO 不用.to会怎么样?
+ self._pids = []
+ paddledist.all_gather(self._pids, paddle.to_tensor(os.getpid(), dtype="int32"))
+ # TODO LOCAL_WORLD_SIZE
+ local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
+ if local_world_size is None:
+ local_world_size = paddle.to_tensor(self.local_rank, dtype="int32")
+ paddledist.all_reduce(local_world_size, op=paddledist.ReduceOp.MAX)
+ local_world_size = local_world_size.item() + 1
+
+ node_rank = self.global_rank // local_world_size
+ self._pids = self._pids[node_rank*local_world_size: (node_rank+1)*local_world_size]
+ self._pids = self.tensor_to_numeric(self._pids)
+
+ def _init_fleet_and_set(self):
+ """
+ 使用 FleetLauncher 拉起子进程
+ """
+ if self.local_rank == 0:
+ logger._set_distributed()
+ # 是 rank0 的话,则拉起其它子进程
+ launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
+ launcher.launch()
+ self.gloo_rendezvous_dir = launcher.gloo_rendezvous_dir
+ # 设置参数和初始化分布式环境
+ fleet.init(self.role_maker, self.is_collective, self.strategy)
+ self.global_rank = int(os.getenv("PADDLE_TRAINER_ID"))
+ self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))
+
+ # 正常情况下不会 Assert 出问题,但还是保险一下
+ assert self.global_rank is not None
+ assert self.world_size is not None
+ assert self.world_size == len(self.parallel_device)
+
+ def _set_from_fleet_environment(self):
+ """
+ 当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要
+ 根据 paddle 设置的环境变量来获得各种属性
+ """
+ self.world_size = paddledist.get_world_size()
+ self.global_rank = paddledist.get_rank()
+
+ def barrier(self):
+ """
+ 同步进程之间的操作
+ """
+ if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行
+ paddledist.barrier()
+
+ def configure_fleet(self):
+ # 将模型用 DataParallel 和自定义的类型包裹起来
+ if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
+ self.model = DataParallel(
+ _FleetWrappingModel(self.model),
+ **self._fleet_kwargs
+ )
+ self._has_fleetwrapped = True
+
+ def on_exception(self):
+ rank_zero_rm(self.gloo_rendezvous_dir)
+ super().on_exception()
+
+ @property
+ def world_size(self) -> int:
+ """
+        分布式训练的进程总数 ``WORLD_SIZE``
+ """
+ return self._world_size
+
+ @world_size.setter
+ def world_size(self, size: int) -> None:
+ self._world_size = size
+
+ @property
+ def global_rank(self) -> int:
+ """
+ 当前进程的全局编号 ``global_rank``
+ """
+ return self._global_rank
+
+ @global_rank.setter
+ def global_rank(self, rank: int) -> None:
+ self._global_rank = rank
+
+ @property
+ def local_rank(self) -> int:
+ """
+ 当前进程的局部编号 ``local_rank``
+ """
+ return int(os.getenv("PADDLE_RANK_IN_NODE", "0"))
+
+ @property
+ def data_device(self):
+ """
+ 数据所在的设备;由于 **PaddlePaddle** 可以通过环境变量获取当前进程的设备,因此该属性
+ 和 ``model_device`` 表现相同。
+ """
+ return self.model_device
+
+ def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+ if self._has_fleetwrapped:
+ return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
+ wo_auto_param_call=self.wo_auto_param_call)
+ else:
+ if isinstance(batch, Dict) and not self.wo_auto_param_call:
+ return auto_param_call(fn, batch, signature_fn=signature_fn)
+ else:
+ return fn(batch)
+
+ def get_model_call_fn(self, fn: str) -> Tuple:
+ model = self.unwrap_model()
+ if self._has_fleetwrapped:
+ if hasattr(model, fn):
+ fn = getattr(model, fn)
+ if not callable(fn):
+ raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
+ return fn, None
+ elif fn in {"train_step", "evaluate_step"}:
+ return model, model.forward
+ else:
+ raise RuntimeError(f"There is no `{fn}` method in your model.")
+ else:
+ if hasattr(model, fn):
+ logger.warning("Notice your model is a `DataParallel` model. And your model also implements "
+ f"the `{fn}` method, which we can not call actually, we will"
+ " call `forward` function instead of `train_step` and you should note that.")
+ elif fn not in {"train_step", "evaluate_step"}:
+ raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
+ "`DistributedDataParallel` model, which means that we will only call model.forward "
+ "function when we are in forward propagation.")
+
+ return self.model, model.forward
+
+ def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]],
+ reproducible: bool = False):
+ # 暂时不支持iterableDataset
+ assert dataloader.dataset_kind != _DatasetKind.ITER, \
+ "FastNLP does not support `IteratorDataset` now."
+ # 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load_checkpoint 函数调用;
+ if isinstance(dist, ReproducibleBatchSampler):
+ dist.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank,
+ pad=True
+ )
+ return replace_batch_sampler(dataloader, dist)
+ if isinstance(dist, ReproducibleSampler):
+ dist.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank,
+ pad=True
+ )
+ return replace_sampler(dataloader, dist)
+
+ # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用;
+ # trainer, evaluator
+ if dist is None:
+ if reproducible:
+ raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.")
+ else:
+ args = self.get_dataloader_args(dataloader)
+ if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+ batch_sampler = re_instantiate_sampler(args.batch_sampler)
+ return replace_batch_sampler(dataloader, batch_sampler)
+ if isinstance(args.sampler, ReproducibleSampler):
+ sampler = re_instantiate_sampler(args.sampler)
+ return replace_sampler(dataloader, sampler)
+ return dataloader
+ # trainer
+ elif dist == "dist":
+ args = self.get_dataloader_args(dataloader)
+ # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为;
+ if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+ batch_sampler = re_instantiate_sampler(args.batch_sampler)
+ batch_sampler.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank,
+ pad=True
+ )
+ return replace_batch_sampler(dataloader, batch_sampler)
+ elif isinstance(args.sampler, ReproducibleSampler):
+ sampler = re_instantiate_sampler(args.sampler)
+ sampler.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank,
+ pad=True
+ )
+ return replace_sampler(dataloader, sampler)
+ else:
+ _check_dataloader_args_for_distributed(args, controller='Trainer')
+ sampler = RandomSampler(
+ dataset=args.dataset,
+ shuffle=args.shuffle,
+ seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
+ )
+ sampler.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank,
+ pad=True
+ )
+ return replace_sampler(dataloader, sampler)
+ # evaluator
+ elif dist == "unrepeatdist":
+ args = self.get_dataloader_args(dataloader)
+ if type(args.batch_sampler) != BatchSampler:
+ # TODO 这里的目的是判断用户的 batch_sampler 是定制的,可能需要完善
+                logger.warning("Note that you are using customized ``batch_sampler`` in evaluate dataloader or " \
+                               "train dataloader while testing ``overfit_batches``, which may cause that " \
+                               "the data for distributed evaluation is not unrepeated.")
+ if isinstance(args.sampler, ReproducibleSampler):
+ sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
+ elif not isinstance(args.sampler, UnrepeatedSampler):
+ _check_dataloader_args_for_distributed(args, controller='Evaluator')
+ sampler = UnrepeatedSequentialSampler(
+ dataset=args.dataset
+ )
+ else:
+ sampler = re_instantiate_sampler(args.sampler)
+ sampler.set_distributed(
+ num_replicas=self.world_size,
+ rank=self.global_rank
+ )
+ # TODO 这里暂时统一替换为 BatchSampler
+ batch_sampler = BatchSampler(dataset=args.dataset, batch_size=args.batch_size, drop_last=False)
+ batch_sampler.sampler = sampler
+ return replace_batch_sampler(dataloader, batch_sampler)
+ else:
+ raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
+
+ def is_global_zero(self) -> bool:
+ r"""
+ :return: 当前的进程是否在全局上是进程 0
+ """
+ return self.global_rank == 0
+
+ def get_model_no_sync_context(self):
+ r"""
+ :return: 一个 ``context`` 上下文环境,用于关闭各个进程之间的同步。
+ """
+ return self.model.no_sync
+
+ def unwrap_model(self) -> "paddle.nn.Layer":
+ """
+ 获得 driver 最原始的模型。该函数可以取出被 :class:`paddle.DataParallel` 包裹起来的模型。
+ """
+ _layers = self.model._layers
+ if isinstance(_layers, _FleetWrappingModel):
+ return _layers.model
+ else:
+ return _layers
+
+ def get_local_rank(self) -> int:
+ r"""
+ :return: 当前进程局部的进程编号。
+ """
+ return self.local_rank
+
+ def is_distributed(self) -> bool:
+ """
+ :return: 当前使用的 driver 是否是分布式的 driver,在 ``PaddleFleetDriver`` 中,返回 ``True``。
+ """
+ return True
+
+ @staticmethod
+ def _check_optimizer_legality(optimizers):
+ # paddle 存在设置分布式 optimizers 的函数,返回值为 fleet.meta_optimizers.HybridParallelOptimizer
+ DistribuedOptimizer = fleet.meta_optimizers.HybridParallelOptimizer
+ for each_optimizer in optimizers:
+ if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)):
+ raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
+ f"not {type(each_optimizer)}.")
+
+ def broadcast_object(self, obj, src:int=0, group=None, **kwargs):
+ r"""
+ 从 ``src`` 端将 ``obj`` 对象(可能是 tensor ,可能是 object )广播到其它进程。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行
+        传输,然后在接收处再加载回来。仅在分布式的 driver 中有实际意义。
+
+ :param obj: obj,可能是 Tensor 或 嵌套类型的数据
+ :param src: 发送方的 ``global_rank``
+ :param group: 进程所在的通信组
+ :return: 如果当前 rank 是接收端,则返回接收到的参数;如果是 source 端则返回发送的内容。如果环境变量 ``FASTNLP_NO_SYNC`` 为 **2** 则
+ 返回 ``None``
+ """
+ # 因为设置了CUDA_VISIBLE_DEVICES,可能会引起错误
+ device = _convert_data_device(self.data_device)
+ return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group)
+
+ def all_gather(self, obj, group=None) -> List:
+ r"""
+ 将 ``obj`` 互相传送到其它所有的 rank 上,其中 ``obj`` 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,将会尝试通过
+ pickle 进行序列化,接收到之后再反序列化。
+
+ example::
+
+ >>> # rank 0
+ >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 1}}
+ >>> # rank 1
+ >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+ >>> # after all_gather():
+ >>> result = [
+ {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
+ {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+ ]
+
+ :param obj: 需要传输的对象,在每个 rank 上都应该保持相同的结构。
+ :param group: 进程所在的通信组。
+ :return: 所有 rank 发送的 ``obj`` 聚合在一起的内容;如果环境变量 ``FASTNLP_NO_SYNC`` 为 **2** 则不会执行,直接返回 ``[obj]`` 。
+ """
+ return fastnlp_paddle_all_gather(obj, group=group)
diff --git a/fastNLP/core/drivers/paddle_driver/fleet_launcher.py b/fastNLP/core/drivers/paddle_driver/fleet_launcher.py
new file mode 100644
index 00000000..4a42dcff
--- /dev/null
+++ b/fastNLP/core/drivers/paddle_driver/fleet_launcher.py
@@ -0,0 +1,197 @@
+import os
+import sys
+import __main__
+import tempfile
+import copy
+from typing import List
+
+from fastNLP.core.drivers.utils import distributed_open_proc
+from fastNLP.envs.env import (
+ FASTNLP_DISTRIBUTED_CHECK,
+ FASTNLP_LOG_LEVEL,
+ FASTNLP_GLOBAL_SEED,
+ FASTNLP_GLOBAL_RANK
+)
+from fastNLP.core.utils import get_paddle_device_id
+from .utils import (
+ find_free_ports,
+)
+
+__all__ = []
+
+# 记录各个进程信息
+class SubTrainer(object):
+ """
+    用于统计节点内不同训练进程的信息,和 fastnlp 的 Trainer 没有关系
+ """
+ def __init__(self, endpoint=None, rank=None):
+ self.devices = []
+ self.endpoint = endpoint
+ self.rank = rank
+
+
+class FleetLauncher:
+ """
+ 复原了 paddle 的 launch_collective 函数,将其简化后集成到一个类里
+ 仅支持每个机器单卡的情况。
+ """
+ def __init__(
+ self,
+ devices: List[str],
+ output_from_new_proc: str = "only_error"
+ ):
+
+ self.devices = [ get_paddle_device_id(g) for g in devices]
+ self.output_from_new_proc = output_from_new_proc
+
+ self.setup()
+
+ def setup(self):
+ """
+ 进行初始化设置的函数,根据传入的设备找到分布式训练使用的端口号
+ """
+ self.set_endpoints()
+ self.sub_trainers = self.get_process_info()
+
+ def launch(self):
+ """
+ 用于启动分布式进程。
+ 首先设置 PaddlePaddle 分布式训练需要设置的环境变量,然后建立新的子进程
+ """
+ # 设置环境变量
+ self.global_envs = self.get_global_env()
+ self.open_subprocess()
+
+ def open_subprocess(self):
+ """
+ 从 sub_trainers 中获取各个 rank 的信息,并且使用 subprocess.Popen 建立新的子进程。
+ """
+
+ if __main__.__spec__ is None:
+ # Script called as `python a/b/c.py`
+ # when user is using hydra find the absolute path
+ path_lib = os.path.abspath
+
+ # pull out the commands used to run the script and resolve the abs file path
+ command = sys.argv
+ try:
+ full_path = path_lib(command[0])
+ except Exception:
+ full_path = os.path.abspath(command[0])
+
+ command[0] = full_path
+ # use the same python interpreter and actually running
+ command = [sys.executable] + command
+ else: # Script called as `python -m a.b.c`
+ command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]
+
+ current_env = copy.copy(self.global_envs)
+ for idx, t in enumerate(self.sub_trainers):
+ # 根据不同的 rank 设置环境变量
+ proc_env = {
+ # global_rank
+ "PADDLE_TRAINER_ID": f"{t.rank}",
+ "PADDLE_CURRENT_ENDPOINT": f"{t.endpoint}",
+ # rank
+ "PADDLE_RANK_IN_NODE": f"{idx}",
+ "PADDLE_LOCAL_DEVICE_IDS":
+ ",".join([str(g) for g in t.devices]),
+ }
+
+ if len(t.devices) > 0:
+ proc_env["FLAGS_selected_gpus"] = "%s" % ",".join(
+ [str(g) for g in t.devices])
+ proc_env["FLAGS_selected_devices"] = "%s" % ",".join(
+ [str(g) for g in t.devices])
+
+ current_env.update(proc_env)
+
+ if os.environ.get(FASTNLP_GLOBAL_SEED) is None and FASTNLP_GLOBAL_SEED in current_env:
+ del current_env[FASTNLP_GLOBAL_SEED]
+
+ if idx != 0:
+ # 子进程
+ if os.environ.get(FASTNLP_LOG_LEVEL, None) is None:
+ current_env[FASTNLP_LOG_LEVEL] = "warning"
+ proc = distributed_open_proc(self.output_from_new_proc, command, current_env, t.rank)
+ else:
+ # 更新当前的环境变量
+ os.environ.update(current_env)
+
+ def get_global_env(self):
+ """
+ 设置分布式训练需要的全局变量,包括:
+ 1、GLOO 相关的设置
+ 2、`PADDLE_TRAINERS_NUM` :所有的进程数目
+ 3、`PADDLE_TRAINER_ENDPOINTS` :使用的所有地址及其端口
+ 4、`PADDLE_WORLD_DEVICE_IDS` :使用的所有设备
+ 5、FASTNLP_DISTRIBUTED_CHECK:通过 fastNLP 建立子进程的标志,保存分布式训练使用的设备
+ """
+
+ global_envs = copy.copy(os.environ.copy())
+ os.environ[FASTNLP_GLOBAL_RANK] = "0"
+
+ self.gloo_rendezvous_dir = tempfile.mkdtemp()
+ # launch中涉及的gloo环境
+ global_envs["PADDLE_WITH_GLOO"] = str(os.getenv("PADDLE_WITH_GLOO", "0"))
+ global_envs["PADDLE_GLOO_RENDEZVOUS"] = "3"
+ global_envs["PADDLE_GLOO_FS_PATH"] = self.gloo_rendezvous_dir
+ global_envs["PADDLE_DISTRI_BACKEND"] = "nccl"
+
+ # 通过FNLP初始化的标志
+ global_envs[FASTNLP_DISTRIBUTED_CHECK] = f"{','.join([str(g) for g in self.devices])}"
+
+ # 统计全局信息
+ device_ids = []
+ for t in self.sub_trainers:
+ device_ids.append([str(acc) for acc in t.devices])
+ world_device_ids = [':'.join(ele) for ele in device_ids]
+ # 全局环境变量
+ global_envs.update({
+ # world_size
+ "PADDLE_TRAINERS_NUM": f"{len(self.sub_trainers)}",
+ "PADDLE_TRAINER_ENDPOINTS": ",".join(self.endpoints),
+ "PADDLE_WORLD_DEVICE_IDS": ",".join(world_device_ids),
+ })
+
+ return global_envs
+
+ def set_endpoints(self):
+ """
+ 寻找用户设置的端口或是空闲端口用于分布式训练,参考了 PaddlePaddle 中的 `get_cluster_from_args` 函数
+ """
+ self.node_ip = "127.0.0.1"
+
+ free_ports = None
+ if os.environ.get("FLAGS_START_PORT") is None:
+ free_ports = find_free_ports(len(self.devices))
+ if free_ports is not None:
+ free_ports = list(free_ports)
+ else:
+ start_port = int(os.getenv("FLAGS_START_PORT", "6070"))
+
+ free_ports = [
+ x for x in range(start_port, start_port + len(self.devices))
+ ]
+
+ self.endpoints = ["%s:%d" % (self.node_ip, port) for port in free_ports]
+
+ def get_process_info(self):
+ """
+ 获取各个训练进程的设备、rank 和端口信息,参考 PaddlePaddle 的 `get_cluster` 函数。
+ """
+ sub_trainers = []
+ assert len(self.endpoints) >= len(
+ self.devices
+        ), "current trainer_endpoints size should be greater than or equal to accelerators size."
+
+ for i in range(len(self.devices)):
+ sub_trainer = SubTrainer(f"{self.endpoints[i]}", i)
+ if isinstance(self.devices[i], (list, tuple)):
+ sub_trainer.devices.extend(self.devices[i])
+ else:
+ sub_trainer.devices.append(self.devices[i])
+
+ sub_trainers.append(sub_trainer)
+
+ return sub_trainers
diff --git a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
new file mode 100644
index 00000000..807ef166
--- /dev/null
+++ b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
@@ -0,0 +1,80 @@
+import os
+
+from typing import Optional, List, Sequence, Union
+
+from .paddle_driver import PaddleDriver
+from .single_device import PaddleSingleDriver
+from .fleet import PaddleFleetDriver
+
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
+from fastNLP.core.utils import is_in_paddle_launch_dist, get_paddle_gpu_str
+from fastNLP.core.log import logger
+
+if _NEED_IMPORT_PADDLE:
+ import paddle
+
+__all__ = []
+
+def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[int]]],
+ model: "paddle.nn.Layer", **kwargs) -> PaddleDriver:
+ r"""
+ 用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例。
+
+ 1. 如果检测到当前进程为用户通过 ``python -m paddle.distributed.launch xxx.py`` 方式拉起的,则将
+ 设备自动设置为用户指定的设备(由于我们要求分布式训练必须进行 ``backend`` 的设置,因此可以通过 ``CUDA_VISIBLE_DEVICES`` 获取)
+
+ 2. 如果 ``device`` 包含了多个设备,则返回一个 :class:`~fastNLP.core.PaddleFleetDriver` 实例,否则返回
+ 单卡的 :class:`~fastNLP.core.PaddleSingleDriver` 实例
+
+ :param driver: 使用的 ``driver`` 类型,在这个函数中仅支持 ``paddle``
+ :param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致
+ :param model: 训练或者评测的具体的模型;
+
+ :return: 一个 :class:`~fastNLP.core.PaddleSingleDriver` 或 :class:`~fastNLP.core.PaddleFleetDriver` 实例;
+ """
+ if driver != "paddle":
+ raise ValueError("When initialize PaddleDriver, parameter `driver` must be 'paddle'.")
+ user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
+ if is_in_paddle_launch_dist():
+ if user_visible_devices is None:
+ raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using fastNLP.")
+ if device is not None:
+ logger.rank_zero_warning("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
+ "up your script. And we will directly get the local device via environment variables.", once=True)
+ _visible_list = user_visible_devices.split(",")
+ device = [ f"gpu:{_visible_list.index(g) }" for g in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
+ # TODO 目前一个进程仅对应一个卡,所以暂时传入单个
+ return PaddleFleetDriver(model, device[0], True, **kwargs)
+
+ if user_visible_devices is None:
+ _could_use_device_num = paddle.device.cuda.device_count()
+ else:
+ _could_use_device_num = len(user_visible_devices.split(","))
+
+ if isinstance(device, int):
+ if device < 0 and device != -1:
+ raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
+ if device >= _could_use_device_num:
+ raise ValueError("The gpu device that parameter `device` specifies is not existed.")
+ if device == -1:
+ device = [ get_paddle_gpu_str(g) for g in range(_could_use_device_num)]
+ elif isinstance(device, Sequence) and not isinstance(device, str):
+ device = list(set(device))
+ for each in device:
+ if not isinstance(each, int):
+ raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.")
+ elif each < 0:
+ raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.")
+ elif each >= _could_use_device_num:
+ raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than"
+ " the available gpu number.")
+ device = [get_paddle_gpu_str(g) for g in device]
+ elif device is not None and not isinstance(device, str):
+ raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
+
+ if isinstance(device, List):
+ return PaddleFleetDriver(model, device, **kwargs)
+ else:
+ return PaddleSingleDriver(model, device, **kwargs)
+
diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
new file mode 100644
index 00000000..cacff229
--- /dev/null
+++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
@@ -0,0 +1,508 @@
+import os
+import random
+from typing import Union, Optional, Dict, Any
+from pathlib import Path
+from functools import partial
+from dataclasses import dataclass
+
+import numpy as np
+
+from .utils import _build_fp16_env, optimizer_state_to_device, DummyGradScaler
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.core.drivers.driver import Driver
+from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
+from fastNLP.core.utils.paddle_utils import _convert_data_device
+from fastNLP.envs import (
+ FASTNLP_MODEL_FILENAME,
+ FASTNLP_CHECKPOINT_FILENAME,
+ FASTNLP_GLOBAL_RANK,
+ rank_zero_call,
+)
+from fastNLP.core.log import logger
+from fastNLP.core.dataloaders import OverfitDataLoader
+from fastNLP.core.samplers import (
+ ReproducibleBatchSampler,
+ ReproducibleSampler,
+ ReproduceBatchSampler,
+ RandomSampler,
+)
+
+if _NEED_IMPORT_PADDLE:
+ import paddle
+ from paddle.io import (
+ DataLoader,
+ Dataset,
+ Sampler,
+ BatchSampler,
+ RandomSampler as PaddleRandomSampler,
+ )
+ from paddle.optimizer import Optimizer
+
+ _reduces = {
+ "max": paddle.max,
+ "min": paddle.min,
+ "mean": paddle.mean,
+ "sum": paddle.sum
+ }
+
+class PaddleDriver(Driver):
+ r"""
+ 实现了 **PaddlePaddle** 框架训练功能的基本 Driver。
+
+ 这个类被以下子类继承:
+
+ 1. :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver`:实现了使用单卡和 ``cpu`` 训练的具体功能;
+ 2. :class:`~fastNLP.core.drivers.paddle_driver.PaddleFleetDriver`:实现了使用 ``fleet`` 分布式训练 API 进行集群式分布式训练的具体功能;
+
+ .. warning::
+
+        您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``PaddleSingleDriver`` 和 ``PaddleFleetDriver``,而不是
+ 该类本身。
+
+ .. note::
+
+ 您可以在使用 ``PaddleSingleDriver`` 和 ``PaddleFleetDriver`` 时使用 ``PaddleDriver`` 提供的接口。
+
+ :param model: 训练时使用的 **PaddlePaddle** 模型
+ :param fp16: 是否开启混合精度训练
+ :param paddle_kwargs:
+ """
+ def __init__(self, model: "paddle.nn.Layer", fp16: Optional[bool] = False, paddle_kwargs: Dict = None, **kwargs):
+ if not isinstance(model, paddle.nn.Layer):
+ raise ValueError(f"Parameter `model` can not be `{type(model)}` in `PaddleDriver`, it should be exactly "
+ f"`paddle.nn.Layer` type.")
+
+ super(PaddleDriver, self).__init__(model)
+ self.fp16 = fp16
+ self._paddle_kwargs = paddle_kwargs if paddle_kwargs is not None else {}
+
+ # scaler的参数
+ self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
+ self.grad_scaler = _grad_scaler(**self._paddle_kwargs.get("gradscaler_kwargs", {}))
+
+ # 用来设置是否关闭 auto_param_call 中的参数匹配问题;
+ self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
+
+ def zero_grad(self):
+ """
+ 实现梯度置零的过程
+ """
+ for optimizer in self.optimizers:
+ optimizer.clear_grad()
+
+ def backward(self, loss):
+ """
+ 对 ``loss`` 进行反向传播
+ """
+ self.grad_scaler.scale(loss).backward()
+
+ def step(self):
+ r"""
+ 实现参数的优化更新过程
+ """
+ for optimizer in self.optimizers:
+ self.grad_scaler.step(optimizer)
+ self.grad_scaler.update()
+
+ def check_dataloader_legality(self, dataloader):
+ """
+ 检测 DataLoader 是否合法。支持的类型包括 :class:`~fastNLP.core.dataloaders.PaddleDataLoader`、 :class:`paddle.io.DataLoader` 。
+
+        :param dataloader:
+ """
+ if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
+ raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
+ if dataloader.batch_size is None and dataloader.batch_sampler is None:
+ raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler"
+ "is not None")
+ if len(dataloader) == 0:
+ logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
+ "may cause some unexpected exceptions.", once=True)
+
+ @staticmethod
+ def _check_optimizer_legality(optimizers):
+ r"""
+        对于用户传入 trainer 的每一个 optimizer 检测其合法性,必须为`paddle.optimizer.Optimizer`类型。
+
+ :param optimizers: 需要检测的 `optimizers`。
+ """
+ for each_optimizer in optimizers:
+ if not isinstance(each_optimizer, Optimizer):
+ raise TypeError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
+ f"not {type(each_optimizer)}.")
+
+ @staticmethod
+ def tensor_to_numeric(tensor, reduce=None):
+ r"""
+ 将一个 :class:`paddle.Tensor` 对象转换为 转换成 python 中的数值类型。
+
+ :param tensor: :class:`paddle.Tensor` 类型的对象。
+ :param reduce: 当 tensor 是一个多数值的张量时,应当使用何种归一化操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``。
+ :return: 一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等。
+ """
+ if tensor is None:
+ return None
+
+ def _translate(_data):
+ # 如果只含有一个元素,则返回元素本身,而非list
+ if _data.numel().item() == 1:
+ return _data.item()
+ if reduce is None:
+ return _data.tolist()
+ else:
+ return _reduces[reduce](_data).item()
+
+ return apply_to_collection(
+ data=tensor,
+ dtype=paddle.Tensor,
+ function=_translate
+ )
+
+ def set_model_mode(self, mode: str):
+ r"""
+ 设置模型为 ``train`` 或 ``eval`` 的模式;目的是为切换模型的训练和推理(会关闭 dropout 等)模式。
+
+ :param mode: 应为二者之一:``["train", "eval"]``
+ """
+ assert mode in {"train", "eval"}
+ getattr(self.model, mode)()
+
+ @rank_zero_call
+ def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
+ r"""
+ 将模型保存到 ``filepath`` 中。
+
+ :param filepath: 保存文件的文件位置(需要包括文件名)。
+ :param only_state_dict: 是否只保存模型的 ``state_dict``;如果为 ``False``,则会调用 ``paddle.jit.save``
+ 函数保存整个模型的参数,此时需要传入 ``input_spec`` 参数。
+ :kwargs:
+ * *input_spec* -- 描述存储模型 ``forward`` 方法的输入;
+ 当 ``only_state_dict`` 为 ``False`` 时必须传入,否则加载时会报错。您可以通过 ``InputSpec`` 或者示例 ``Tensor``
+ 进行描述。详细的使用方法可以参考 **PaddlePaddle** `关于 paddle.jit.save 函数的文档 `_。
+ """
+ model = self.unwrap_model()
+ if isinstance(filepath, Path):
+ filepath = str(filepath)
+ if only_state_dict:
+ states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
+ paddle.save(states, filepath)
+ else:
+ # paddle 在保存整个模型时需要传入额外参数
+ input_spec = kwargs.get("input_spec", None)
+ if input_spec is None:
+ raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
+ paddle.jit.save(model, filepath, input_spec)
+
+ def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs):
+ """
+ 加载模型的函数;将 ``filepath`` 中的模型加载并赋值给当前 ``model`` 。
+
+ :param filepath: 保存文件的文件位置
+        :param only_state_dict: 保存的内容是否只是权重。
+ """
+ model = self.unwrap_model()
+ if isinstance(filepath, Path):
+ filepath = str(filepath)
+ # paddle 中,通过 paddle.jit.save 函数保存的模型也可以通过 paddle.load 加载为相应的 state dict
+ # 但是此时对输入的 path 有要求,必须是 dir/filename 的形式,否则会报错。
+ dirname, filename = os.path.split(filepath)
+ if not only_state_dict and dirname == "":
+ # 如果传入的是单个文件,则加上相对路径
+ filepath = os.path.join(".", filepath)
+ model.load_dict(paddle.load(filepath))
+
+ @rank_zero_call
+ def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+ r"""
+ 断点重训的保存函数,该函数会负责保存 **优化器** 、 **sampler** 和 **fp16** 的状态,以及 **模型** (若 ``should_save_model`` 为 ``True``)
+
+ :param folder: 保存断点重训的状态的文件夹;:meth:`save_checkpoint` 函数应该在该路径下面下面新增名为 ``FASTNLP_CHECKPOINT_FILENAME`` 与
+ ``FASTNLP_MODEL_FILENAME`` (如果 ``should_save_model`` 为 ``True`` )的文件。把 model 相关的内容放入到 ``FASTNLP_MODEL_FILENAME`` 文件
+ 中,将传入的 ``states`` 以及自身产生的其它状态一并保存在 ``FASTNLP_CHECKPOINT_FILENAME`` 里面。
+ :param states: 由 :class:`~fastNLP.core.controllers.Trainer` 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态。
+ :param dataloader: 正在使用的 dataloader。
+ :param only_state_dict: 是否只保存模型的参数,当 ``should_save_model`` 为 ``False`` ,该参数无效。
+ :param should_save_model: 是否应该保存模型,如果为 ``False`` ,Driver 将不负责 model 的保存。
+ """
+ # 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变
+ # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境;
+
+ # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch;
+ # paddle 的 DataLoader 在初始化之后 batch_sampler 可能为 None,也可能为用户设置的 batch_sampler
+ dataloader_args = self.get_dataloader_args(dataloader)
+ if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+ sampler = dataloader_args.batch_sampler
+ elif dataloader_args.sampler:
+ sampler = dataloader_args.sampler
+ else:
+ raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
+
+ num_consumed_batches = states.pop("num_consumed_batches")
+ if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
+ sampler_states = sampler.state_dict()
+ if dataloader_args.batch_size is not None:
+ sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+ * num_consumed_batches
+ else:
+ logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+ "it may cause missing some samples when reload.")
+ else:
+ raise RuntimeError(
+ "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
+
+ states['sampler_states'] = sampler_states
+
+ # 2. 保存模型的状态;
+ if should_save_model:
+ self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
+
+ # 3. 保存 optimizers 的状态;
+ states["optimizers_state_dict"] = self.get_optimizer_state()
+ logger.debug("Save optimizer state dict.")
+
+ # 4.保存fp16的状态
+ if not isinstance(self.grad_scaler, DummyGradScaler):
+ grad_scaler_state_dict = self.grad_scaler.state_dict()
+ states['grad_scaler_state_dict'] = grad_scaler_state_dict
+
+ paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))
+
+ def get_optimizer_state(self):
+ optimizers_state_dict = {}
+ for i in range(len(self.optimizers)):
+ optimizer: Optimizer = self.optimizers[i]
+ optimizers_state_dict[f"optimizer{i}"] = optimizer_state_to_device(optimizer.state_dict(), "cpu")
+
+ return optimizers_state_dict
+
+ def load_optimizer_state(self, states):
+ assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
+ f"checkpoint it is:{len(states)}"
+ for i in range(len(self.optimizers)):
+ optimizer: Optimizer = self.optimizers[i]
+ optimizer.set_state_dict(states[f"optimizer{i}"])
+ logger.debug("Load optimizer state dict.")
+
+ def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
+ r"""
+ 断点重训的加载函数,该函数会负责读取数据,并且恢复 **优化器** 、**sampler** 、 **fp16** 的状态和 **模型** (如果 ``should_load_model`` 为 True)以及其它
+ 在 :meth:`save_checkpoint` 函数中执行的保存操作,然后将一个 state 字典返回给 :class:`~fastNLP.core.controllers.Trainer` ( 内容为 :meth:`save_checkpoint`
+ 接受到的 ``states`` )。
+
+ 该函数应该在所有 rank 上执行。
+
+ :param folder: 读取该 folder 下的 ``FASTNLP_CHECKPOINT_FILENAME`` 文件与 ``FASTNLP_MODEL_FILENAME``
+ (如果 should_load_model 为True)。
+ :param dataloader: 当前给定 dataloader,需要根据保存的 dataloader 状态合理设置。若该值为 ``None`` ,则不需要返回 ``'dataloader'``
+ 以及 ``'batch_idx_in_epoch'`` 这两个值。
+ :param only_state_dict: 是否仅读取模型的 state_dict ,当 ``should_save_model`` 为 ``False`` ,该参数无效。如果为 ``True`` ,说明保存的内容为权重;如果为
+ False 说明保存的是模型,但也是通过当前 Driver 的模型去加载保存的模型的权重,而不是使用保存的模型替换当前模型。
+ :param should_load_model: 是否应该加载模型,如果为 ``False`` ,Driver 将不负责加载模型。若该参数为 ``True`` ,但在保存的状态中没有
+ 找到对应的模型状态,则报错。
+ :return: :meth:`save_checkpoint` 函数输入的 ``states`` 内容。除此之外,还返回的内容有:
+
+ * *dataloader* -- 根据传入的 ``dataloader`` 与读取出的状态设置为合理状态的 dataloader。在当前 ``dataloader`` 样本数与读取出的 sampler 样本数
+ 不一致时报错。
+ * *batch_idx_in_epoch* -- :class:`int` 类型的数据,表明当前 epoch 进行到了第几个 batch 。请注意,该值不能仅通过保存的数据中读取的,因为前后两次运行的
+ ``batch_size`` 可能有变化,而应该符合以下等式::
+
+ 返回的 dataloader 还会产生的 batch 数量 + batch_idx_in_epoch = 原来不断点训练时的 batch 的总数
+
+ 由于 ``返回的 dataloader 还会产生的batch数`` 在 ``batch_size`` 与 ``drop_last`` 参数给定的情况下,无法改变,因此只能通过调整 ``batch_idx_in_epoch``
+ 这个值来使等式成立。一个简单的计算原则如下:
+
+ * drop_last 为 ``True`` 时,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
+ * drop_last 为 ``False`` 时,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。
+ """
+ states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))
+
+ # 1. 加载 optimizers 的状态;
+ optimizers_state_dict = states.pop("optimizers_state_dict")
+ self.load_optimizer_state(optimizers_state_dict)
+
+ # 2. 加载模型状态;
+ if should_load_model:
+ self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict)
+
+ # 3. 加载fp16的状态;
+ if "grad_scaler_state_dict" in states:
+ grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
+ if not isinstance(self.grad_scaler, DummyGradScaler):
+ self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+ logger.debug("Load grad_scaler state dict...")
+ elif not isinstance(self.grad_scaler, DummyGradScaler):
+ logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
+ f"the training process may be unstable.")
+
+ # 4. 恢复 sampler 的状态;
+ dataloader_args = self.get_dataloader_args(dataloader)
+ if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+ sampler = dataloader_args.batch_sampler
+ elif isinstance(dataloader_args.sampler, ReproducibleSampler):
+ sampler = dataloader_args.sampler
+ elif isinstance(dataloader_args.sampler, PaddleRandomSampler):
+ sampler = RandomSampler(dataloader_args.sampler.data_source)
+ logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
+ elif self.is_distributed():
+ raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
+ "`ReproducibleSampler`.")
+ else:
+ sampler = ReproduceBatchSampler(
+ batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
+ batch_size=dataloader_args.batch_size,
+ drop_last=dataloader_args.drop_last
+ )
+ sampler.load_state_dict(states.pop("sampler_states"))
+ states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
+
+ # 5. 修改 trainer_state.batch_idx_in_epoch
+ # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler;
+ if not isinstance(sampler, ReproducibleBatchSampler):
+ if dataloader_args.drop_last:
+ batch_idx_in_epoch = len(
+ sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
+ else:
+ batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
+ (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
+ # sampler 是 batch_sampler;
+ else:
+ batch_idx_in_epoch = sampler.batch_idx_in_epoch
+
+ states["batch_idx_in_epoch"] = batch_idx_in_epoch
+
+ return states
+
+ def get_evaluate_context(self):
+ r"""
+ 返回一个不计算梯度的环境用来对模型进行评测。
+
+ :return: 上下文对象 ``paddle.no_grad``;
+ """
+ return paddle.no_grad
+
+ @staticmethod
+ def move_model_to_device(model: "paddle.nn.Layer", device: Union[str, int, "paddle.CUDAPlace", "paddle.CPUPlace"]):
+ r"""
+ 用来将模型 ``model`` 转移到指定的设备上。
+
+ .. note::
+
+ 在 **Paddle** 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。
+
+ :param model: 需要进行转移的模型。
+ :param device: 目标设备。
+ """
+ if device is not None:
+ model.to(device)
+
+ def move_data_to_device(self, batch: Any) -> Any:
+ r"""
+ 将数据集合 ``batch`` 迁移到指定的机器上。
+
+ .. note::
+
+ 在 **Paddle** 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。
+
+ :param batch: 包含 :class:`paddle.Tensor` 的数据集合,可以是 **List**、**Dict** 等嵌套类型。
+ :return: 移动到指定机器后的 ``batch``。
+ """
+ device = _convert_data_device(self.data_device)
+ return paddle_move_data_to_device(batch, device)
+
+ @staticmethod
+ def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
+ # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
+ global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))
+ # TODO gpu
+ process_seed = paddle.fluid.core.default_cpu_generator().initial_seed()
+ # back out the base seed so we can use all the bits
+ base_seed = process_seed - worker_id
+ ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
+ # use 128 bits (4 x 32-bit words)
+ np.random.seed(ss.generate_state(4))
+ # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
+ paddle_ss, stdlib_ss = ss.spawn(2)
+ paddle.seed(paddle_ss.generate_state(1, dtype=np.uint64)[0])
+ # use 128 bits expressed as an integer
+ stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
+ random.seed(stdlib_seed)
+
+ def set_deterministic_dataloader(self, dataloader):
+ """
+ 为了确定性训练要对 ``dataloader`` 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的。
+ """
+ if dataloader.worker_init_fn is None:
+ dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank)
+
+ def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx):
+ r"""
+ 对于分布式的 ``sampler``,需要在每一个 ``epoch`` 前设置随机数种子,来保证每一个进程上的 ``shuffle`` 是一样的。
+
+ :param dataloader: 需要设置 ``epoch`` 的 ``dataloader``
+ :param cur_epoch_idx: 当前是第几个 ``epoch``
+ """
+ if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
+ dataloader.batch_sampler.set_epoch(cur_epoch_idx)
+ elif callable(getattr(dataloader.batch_sampler.sampler, "set_epoch", None)):
+ dataloader.batch_sampler.sampler.set_epoch(cur_epoch_idx)
+
+ @staticmethod
+ def get_dataloader_args(dataloader: "DataLoader"):
+ """
+ 从 ``dataloader`` 中获取参数 ``dataset``, ``batch_sampler``, ``sampler``, ``batch_size``, ``shuffle``
+ 和 ``drop_last`` 。
+ """
+ @dataclass
+ class Res:
+ dataset: Optional[Dataset] = None
+ batch_sampler: Optional[BatchSampler] = None
+ sampler: Optional[Sampler] = None
+ batch_size: Optional[int] = None
+ shuffle: Optional[bool] = None
+ drop_last: Optional[bool] = None
+
+ res = Res()
+
+ # paddle 的 DataLoader 一定会有 dataset 属性;
+ res.dataset = dataloader.dataset
+
+ if dataloader.batch_sampler is not None:
+ # 不过在 paddle 中,我们限定了 batch_sampler 不能为 None
+ res.batch_sampler = dataloader.batch_sampler
+ if hasattr(dataloader.batch_sampler, "batch_size"):
+ res.batch_size = getattr(dataloader.batch_sampler, "batch_size")
+ # 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性;
+ else:
+ dataloader_iter = iter(dataloader)
+ pre_sample = next(dataloader_iter)
+ res.batch_size = pre_sample.shape[0]
+
+ if hasattr(dataloader.batch_sampler, "sampler"):
+ res.sampler = dataloader.batch_sampler.sampler
+ if hasattr(dataloader.batch_sampler.sampler, "shuffle"):
+ res.shuffle = dataloader.batch_sampler.sampler.shuffle
+ elif isinstance(dataloader.batch_sampler.sampler, PaddleRandomSampler):
+ res.shuffle = True
+ else:
+ res.shuffle = False
+ # ReproduceBatchSampler 的情况
+ elif hasattr(dataloader.batch_sampler, "batch_sampler"):
+ batch_sampler = dataloader.batch_sampler.batch_sampler
+ res.sampler = batch_sampler.sampler
+ if hasattr(batch_sampler.sampler, "shuffle"):
+ res.shuffle = dataloader.batch_sampler.sampler.shuffle
+ elif isinstance(batch_sampler.sampler, PaddleRandomSampler):
+ res.shuffle = True
+ else:
+ res.shuffle = False
+ else:
+ res.sampler = None
+ res.shuffle = False
+
+ if hasattr(dataloader.batch_sampler, "drop_last"):
+ res.drop_last = getattr(dataloader.batch_sampler, "drop_last")
+ # 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性;
+ else:
+ res.drop_last = False
+
+ return res
diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py
new file mode 100644
index 00000000..e035f03c
--- /dev/null
+++ b/fastNLP/core/drivers/paddle_driver/single_device.py
@@ -0,0 +1,173 @@
+import os
+import contextlib
+from typing import Optional, Dict, Union, Callable, Tuple
+
+from .paddle_driver import PaddleDriver
+from .utils import replace_batch_sampler, replace_sampler
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
+from fastNLP.core.utils import (
+ auto_param_call,
+ get_paddle_gpu_str,
+ get_paddle_device_id,
+)
+from fastNLP.core.utils.paddle_utils import _convert_data_device
+from fastNLP.core.utils.utils import _get_fun_msg
+from fastNLP.core.samplers import (
+ ReproducibleBatchSampler,
+ ReproduceBatchSampler,
+ ReproducibleSampler,
+ RandomSampler,
+ re_instantiate_sampler,
+)
+from fastNLP.core.log import logger
+
+if _NEED_IMPORT_PADDLE:
+ import paddle
+ from paddle import DataParallel
+ from paddle.fluid.reader import _DatasetKind
+ from paddle.io import (
+ RandomSampler as PaddleRandomSampler,
+ SequenceSampler as PaddleSequenialSampler,
+ BatchSampler as PaddleBatchSampler,
+ )
+
+__all__ = [
+ "PaddleSingleDriver",
+]
+
+class PaddleSingleDriver(PaddleDriver):
+ """
+ 实现了 **PaddlePaddle** 框架下在单卡或 ``cpu`` 环境下训练功能的 **Driver**。
+
+ :param model: 训练时使用的 **PaddlePaddle** 模型
+ :param device: 训练使用的设备
+ :param fp16: 是否开启混合精度训练
+ :param paddle_kwargs:
+ * *gradscaler_kwargs* -- 用于 ``fp16=True`` 时,提供给 :class:`paddle.amp.GradScaler` 的参数。
+ :kwargs:
+ * *model_wo_auto_param_call* (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为。
+
+ .. note::
+
+ 关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。
+
+ """
+ def __init__(self, model: "paddle.nn.Layer", device: Union[str, int], fp16: Optional[bool] = False, paddle_kwargs: Dict = None, **kwargs):
+ if isinstance(model, DataParallel):
+ raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`")
+
+ cuda_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
+ if cuda_visible_devices == "":
+ device = "cpu"
+ logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to"
+ "use `cpu` instead of `gpu` device.")
+
+ super(PaddleSingleDriver, self).__init__(model, fp16=fp16, paddle_kwargs=paddle_kwargs, **kwargs)
+
+ if device is None:
+ raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.")
+
+ if device != "cpu":
+ device_id = get_paddle_device_id(device)
+ if cuda_visible_devices is not None:
+ os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices.split(",")[device_id]
+ self.model_device = get_paddle_gpu_str(device)
+
+ self.local_rank = 0
+ self.global_rank = 0
+ self.world_size = 1
+
+ def setup(self):
+ r"""
+ 初始化训练环境;设置当前训练的设备,并将模型迁移到对应设备上。
+ """
+ device = _convert_data_device(self.data_device)
+
+ paddle.device.set_device(device)
+ with contextlib.redirect_stdout(None):
+ self.model.to(device)
+
+ def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+ if isinstance(batch, Dict) and not self.wo_auto_param_call:
+ return auto_param_call(fn, batch, signature_fn=signature_fn)
+ else:
+ return fn(batch)
+
+ def get_model_call_fn(self, fn: str) -> Tuple:
+ if hasattr(self.model, fn):
+ fn = getattr(self.model, fn)
+ if not callable(fn):
+ raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+ logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+ return fn, None
+ elif fn in {"train_step", "evaluate_step"}:
+ logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
+ return self.model, self.model.forward
+ else:
+ raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
+
+ def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
+ reproducible: bool = False):
+
+ # 暂时不支持iterableDataset
+ assert dataloader.dataset_kind != _DatasetKind.ITER, \
+ "FastNLP does not support `IteratorDataset` now."
+ # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用;
+ if isinstance(dist, ReproducibleBatchSampler):
+ return replace_batch_sampler(dataloader, dist)
+ elif isinstance(dist, ReproducibleSampler):
+ return replace_sampler(dataloader, dist)
+
+        # 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用;
+ args = self.get_dataloader_args(dataloader)
+ if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+ batch_sampler = re_instantiate_sampler(args.batch_sampler)
+ return replace_batch_sampler(dataloader, batch_sampler)
+ elif isinstance(args.sampler, ReproducibleSampler):
+ sampler = re_instantiate_sampler(args.sampler)
+ return replace_sampler(dataloader, sampler)
+
+ if reproducible:
+ if type(args.batch_sampler) is PaddleBatchSampler:
+ if type(args.sampler) is PaddleRandomSampler:
+ if isinstance(args.sampler, PaddleRandomSampler):
+ if getattr(args.sampler, '_num_samples', None) is None \
+ and getattr(args.sampler, 'replacements', False) is False \
+ and getattr(args.sampler, 'generator', None) is None:
+ # 如果本来就是随机的,并且没有定制,直接替换掉。
+ sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+ logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
+ return replace_sampler(dataloader, sampler)
+ elif type(args.sampler) is PaddleSequenialSampler:
+ # 需要替换为不要 shuffle 的。
+ sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+ logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.")
+ return replace_sampler(dataloader, sampler)
+ batch_sampler = ReproduceBatchSampler(
+ batch_sampler=args.batch_sampler,
+ batch_size=args.batch_size,
+ drop_last=args.drop_last
+ )
+ return replace_batch_sampler(dataloader, batch_sampler)
+ else:
+ return dataloader
+
+ def unwrap_model(self):
+ """
+ :return: 训练使用的模型。
+ """
+ return self.model
+
+ @property
+ def data_device(self) -> str:
+ """
+ :return: 数据和模型所在的设备。
+ """
+ return self.model_device
+
+ def is_distributed(self) -> bool:
+ """
+        :return: 是否为分布式的 **Driver** ,在 ``PaddleSingleDriver`` 中,返回 ``False``。
+ """
+ return False
diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py
new file mode 100644
index 00000000..296bcebe
--- /dev/null
+++ b/fastNLP/core/drivers/paddle_driver/utils.py
@@ -0,0 +1,288 @@
+import socket
+import os
+import struct
+import random
+import inspect
+import numpy as np
+from copy import deepcopy
+from contextlib import ExitStack, closing
+from typing import Dict, Optional
+
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.envs.utils import get_global_seed
+from fastNLP.envs import (
+ get_global_rank,
+ FASTNLP_BACKEND_LAUNCH,
+ FASTNLP_GLOBAL_SEED,
+)
+from fastNLP.core.samplers import ReproducibleBatchSampler
+from fastNLP.core.utils import auto_param_call, paddle_to
+from fastNLP.core.log import logger
+
+
+if _NEED_IMPORT_PADDLE:
+ import paddle
+ from paddle import nn
+ from paddle.nn import Layer
+ from paddle.io import DataLoader, BatchSampler, RandomSampler, SequenceSampler
+ from paddle.amp import auto_cast, GradScaler
+else:
+ from fastNLP.core.utils.dummy_class import DummyClass as Layer
+
+
+__all__ = [
+ "paddle_seed_everything",
+ "optimizer_state_to_device",
+]
+
+def paddle_seed_everything(seed: int = None, add_global_rank_to_seed: bool = True) -> int:
+ r"""
+ 为 **paddle**、**numpy**、**python.random** 伪随机数生成器设置种子。
+
+ :param seed: 全局随机状态的整数值种子。如果为 ``None`` 则会根据时间戳生成一个种子。
+ :param add_global_rank_to_seed: 在分布式训练中,是否在不同 **rank** 中使用不同的随机数。
+ 当设置为 ``True`` 时,**fastNLP** 会将种子加上当前的 ``global_rank``。
+ """
+ max_seed_value = np.iinfo(np.uint32).max
+ min_seed_value = np.iinfo(np.uint32).min
+
+ if seed is None:
+ if os.getenv(FASTNLP_BACKEND_LAUNCH) == "1":
+ seed = 42
+ else:
+ seed = get_global_seed()
+ logger.info(f"'FASTNLP_GLOBAL_SEED' is set to {seed} automatically.")
+ if not isinstance(seed, int):
+ seed = int(seed)
+
+ if not (min_seed_value <= seed <= max_seed_value):
+ logger.rank_zero_warning("Your seed value is too big or too small for numpy, we will choose a random seed for you.")
+ seed %= max_seed_value
+
+ os.environ[FASTNLP_GLOBAL_SEED] = f"{seed}"
+ if add_global_rank_to_seed:
+ seed += get_global_rank()
+
+ random.seed(seed)
+ np.random.seed(seed)
+ # paddle的seed函数会自行判断是否在gpu环境,如果在的话会设置gpu的种子
+ paddle.seed(seed)
+ return seed
+
+class _FleetWrappingModel(Layer):
+ """
+ 参考 :class:`fastNLP.core.drivers.torch_driver.utils._DDPWrappingModel` , **PaddlePaddle** 的分布式训练也需要用 :class:`paddle.nn.DataParallel` 进行包装,采用和
+ **pytorch** 相似的处理方式
+ """
+ def __init__(self, model: 'nn.Layer'):
+ super(_FleetWrappingModel, self).__init__()
+ self.model = model
+
+ def forward(self, batch, **kwargs) -> Dict:
+
+ fn = kwargs.pop("fastnlp_fn")
+ signature_fn = kwargs.pop("fastnlp_signature_fn")
+ wo_auto_param_call = kwargs.pop("wo_auto_param_call")
+
+ if isinstance(batch, Dict) and not wo_auto_param_call:
+ return auto_param_call(fn, batch, signature_fn=signature_fn)
+ else:
+ return fn(batch)
+
+class DummyGradScaler:
+ """
+ 用于仿造的 **GradScaler** 对象,防止重复写大量的if判断
+ """
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def get_scale(self):
+ return 1.0
+
+ def is_enabled(self):
+ return False
+
+ def scale(self, outputs):
+ return outputs
+
+ def step(self, optimizer, *args, **kwargs):
+ optimizer.step(*args, **kwargs)
+
+ def update(self, new_scale=None):
+ pass
+
+ def unscale_(self, optimizer):
+ pass
+
+ def load_state_dict(self, state_dict):
+ pass
+
+ def state_dict(self):
+ return {}
+
+def _build_fp16_env(dummy=False):
+ if dummy:
+ return ExitStack, DummyGradScaler
+ else:
+ if not paddle.device.is_compiled_with_cuda():
+ raise RuntimeError("No cuda")
+ if paddle.device.cuda.get_device_capability(0)[0] < 7:
+ logger.warning(
+ "NOTE: your device does NOT support faster training with fp16, "
+ "please switch to FP32 which is likely to be faster"
+ )
+ return auto_cast, GradScaler
+
+def find_free_ports(num):
+ """
+ 在空闲的端口中找到 ``num`` 个端口
+ """
+ def __free_port():
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
+ struct.pack('ii', 1, 0))
+ s.bind(('', 0))
+ return s.getsockname()[1]
+
+ port_set = set()
+ step = 0
+ while True:
+ port = __free_port()
+ if port not in port_set:
+ port_set.add(port)
+
+ if len(port_set) >= num:
+ return port_set
+
+ step += 1
+ if step > 400:
+ logger.error(
+ "can't find avilable port and use the specified static port now!"
+ )
+ return None
+
+ return None
+
+def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"):
+ """
+ 利用 ``batch_sampler`` 重新构建一个 ``DataLoader``,起到替换 ``batch_sampler`` 又不影响原 ``dataloader`` 的作用。
+ 考虑了用户自己定制了 ``DataLoader`` 的情形。
+ """
+ # 拿到非下划线开头的实例属性;
+ instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}
+
+ # 拿到 dataloader '__init__' 函数的默认函数签名;可以获取参数名和参数的默认值以及类型
+ init_params = dict(inspect.signature(dataloader.__init__).parameters)
+
+ # 这里为什么要单独弄的原因在于,用户在定制自己的 dataloader 的同时可能为了方便只设定一些参数,而后面直接使用 **kwargs 的方式,这时如果
+ # 其在初始化自己的 dataloader 实例的时候加入了一些其它的新的参数(首先这一步是必要的,因为我们只能通过这样加 sampler;另一方面,用户
+ # 可能确实通过 **kwargs 加入了一些新的参数),如果假设用户是这样使用的: "super().__init__(**kwargs)",那么我们就只能去 DataLoader
+ # 中寻找;VAR_KEYWORD 代表 **kwargs
+ has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
+ if has_variadic_kwargs:
+ for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
+ if key not in init_params and key != 'self':
+ init_params[key] = value
+
+ # 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置;
+ non_default_params = {name for name, p in init_params.items() if
+ name in instance_attrs and p.default != instance_attrs[name]}
+ # add `dataset` as it might have been replaced with `*args`
+ non_default_params.add("dataset")
+
+ reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
+ if isinstance(dataloader, DataLoader):
+ reconstruct_args.update({
+ "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
+ "persistent_workers": dataloader._persistent_workers,
+ })
+
+ # POSITIONAL_OR_KEYWORD 代表一般的参数
+ # 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数
+ # 也即它们没有在初始化函数和实例成员中同时出现
+ required_args = {
+ p.name
+ for p in init_params.values()
+ if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
+ and p.default is p.empty
+ and p.name not in reconstruct_args
+ }
+
+ # 这种错误针对的是 __init__ 中的参数没有用同样名字的 self 挂上;
+ if required_args:
+ required_args = sorted(required_args)
+ dataloader_self_name = dataloader.__class__.__name__
+ raise Exception(
+ f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
+ f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its "
+ f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be "
+ f"`{dataloader_self_name}`'s attribute."
+ )
+
+ # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs;
+ if not has_variadic_kwargs:
+
+ # the dataloader signature does not allow keyword arguments that need to be passed
+ missing_kwargs = reconstruct_args.keys() - init_params.keys()
+ if missing_kwargs:
+ missing_kwargs = sorted(missing_kwargs)
+ dataloader_self_name = dataloader.__class__.__name__
+ raise Exception(
+ f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found."
+ )
+ # 如果没有kwargs,则保证一下只传入需要的参数
+ if not isinstance(dataloader, DataLoader):
+ reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}
+
+ return type(dataloader)(**reconstruct_args)
+
+def replace_sampler(dataloader, new_sampler):
+ """
+ 使用 ``new_sampler`` 重新构建一个 ``BatchSampler``,并替换到 ``dataloader`` 中
+ """
+ batch_sampler = getattr(dataloader, "batch_sampler")
+ if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
+ raise RuntimeError("It should not be running here, please report a bug to us.")
+ new_batch_sampler = deepcopy(dataloader.batch_sampler)
+ new_batch_sampler.sampler = new_sampler
+ return replace_batch_sampler(dataloader, new_batch_sampler)
+
+def optimizer_state_to_device(state, device):
+ r"""
+ 将一个 ``optimizer`` 的 ``state_dict`` 迁移到对应的设备。
+
+    :param state: :func:`optimizer.state_dict` 获取的 state_dict
+ :param device: 要迁移到的目的设备。
+ :return: 迁移后的新的 state_dict。
+ """
+ new_state = {}
+ for name, param in state.items():
+ if isinstance(param, dict):
+ new_state[name] = optimizer_state_to_device(param, device)
+ elif isinstance(param, paddle.Tensor):
+ new_state[name] = paddle_to(param, device).clone()
+ else:
+ new_state[name] = param
+ return new_state
+
+def _check_dataloader_args_for_distributed(args, controller='Trainer'):
+ """
+ 检查 dataloader 的 sampler 情况,如果用户替换了自己定制的 sampler ,为了防止
+ 在分布式训练中出现错误会报错。
+ """
+ error_flag = (type(args.sampler) not in {RandomSampler, SequenceSampler})
+ if controller == 'Trainer':
+ mode = 'training'
+ substitution = 'fastNLP.RandomSampler'
+ error_flag = (type(args.batch_sampler) != BatchSampler) or error_flag
+ else: # Evaluator
+ mode = 'evaluation'
+ substitution = 'fastNLP.UnrepeatedSequentialSampler'
+ if error_flag:
+ raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
+ f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
+ f"``{substitution}``. The customized sampler should set for distributed running "
+ f"before initializing ``{controller}`` , and then set the "
+ f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
+ f"\n Current batch_sampler: {type(args.batch_sampler)}"
+ f"\n Current sampler: {type(args.sampler)}")
diff --git a/fastNLP/core/drivers/torch_driver/__init__.py b/fastNLP/core/drivers/torch_driver/__init__.py
new file mode 100644
index 00000000..deb2effe
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/__init__.py
@@ -0,0 +1,25 @@
+__all__ = [
+ 'TorchDriver',
+ 'TorchSingleDriver',
+ 'TorchDDPDriver',
+ 'airScaleDriver',
+ 'DeepSpeedDriver',
+ 'TorchFSDPDriver',
+ 'torch_seed_everything',
+ 'optimizer_state_to_device'
+]
+
+from .ddp import TorchDDPDriver
+# todo 实现 fairscale 后再将 fairscale 导入到这里;
+from .fairscale import FairScaleDriver
+from .single_device import TorchSingleDriver
+from .torch_driver import TorchDriver
+from .deepspeed import DeepSpeedDriver
+from .torch_fsdp import TorchFSDPDriver
+from .utils import torch_seed_everything, optimizer_state_to_device
+
+
+
+
+
+
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
new file mode 100644
index 00000000..d9f08a2d
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -0,0 +1,734 @@
+r"""
+"""
+
+r"""
+`TorchDDPDriver` 目前支持的三种启动方式:
+1. 用户自己不进行 ddp 的任何操作,直接使用我们的 Trainer,这时是由我们自己使用 `open_subprocesses` 拉起多个进程,
+ 然后 `TorchDDPDriver` 自己通过调用 `dist.init_process_group` 来初始化 ddp 的通信组;(情况 A)
+2. 用户同样不在 Trainer 之外初始化 ddp,但是用户自己使用 python -m torch.distributed.launch 拉起来创建多个进程,这时我们仍旧
+ 会通过调用 `dist.init_process_group` 来初始化 ddp 的通信组;(情况 B)
+3. 用户自己在外面初始化 DDP,并且通过 python -m torch.distributed.launch 拉起,这时无论是多个进程的拉起和 ddp 的通信组的建立
+ 都由用户自己操作,我们只会在 driver.setup 的时候对 `TorchDDPDriver` 设置一些必要的属性值;(情况 C)
+
+注意多机的启动强制要求用户在每一台机器上使用 python -m torch.distributed.launch 启动;因此我们不会在 `TorchDDPDriver` 中保存
+ 任何当前有多少台机器的信息(num_nodes,不是 gpu 的数量);
+
+Part 1:三种启动方式的具体分析:
+(1)对于用户运行的脚本中,如果 `driver.setup` 只会被调用一次(意味着用户的启动脚本中只初始化了一个 trainer/evaluator)时,
+ `TorchDDPDriver` 在初始化以及 `setup` 函数中会做的事情分别如下所示:
+ -> 情况 A:这种情况下用户传入的 model 在一定是普通的 model(没有经 `DistributedDataParallel` 包裹的model),
+ 因为 `DistributedDataParallel` 的使用一定要求 init_process_group 已经被调用用来建立当前的 ddp 通信组;但是这意味着如果
+ 用户需要使用 2 张以上的显卡,那么其必然需要使用 torch.distributed.launch 来启动,意味着就不是情况 A 了;
+ 这时我们首先会调用 `TorchDDPDriver.open_subprocess` 函数来拉起多个进程,其中进程的数量等于用户传入给 trainer 的使用的 gpu
+ 的数量(例如 `Trainer` 中的参数是 device=[0, 1, 6, 7],那么我们就会使用第 0、1、6、7 张 gpu 来拉起 4 个进程);
+ 接着我们会调用 `dist.init_process_group` 来初始化各个进程之间的通信组;
+ 这里需要注意拉起的新的进程会从前到后完整地运行一遍用户的启动脚本(例如 main.py),因此也都会运行这两个函数,但是需要注意只有进程 0
+ 才会去真正地运行 `TorchDDPDriver.open_subprocess`;进程 0 运行到 `dist.init_process_group`,pytorch 会阻塞进程 0 继续
+ 向前运行,直到其它进程也运行到这里;
+ 最后我们会设置这个进程对应的 device,然后将模型迁移到对应的机器上,再使用 `DistributedDataParallel` 将模型包裹;
+ 至此,ddp 的环境配置过程全部完成;
+
+ -> 情况 B:注意这种情况我们直接限定了用户是通过 torch.distributed.launch 拉起,并且没有自己建立 ddp 的通信组。这时在
+ `TorchDDPDriver` 的初始化和 setup 函数的调用过程中,与情况 A 首要的不同就在于用户在 trainer 中输入的参数 device 不再有效,
+ 这时每个进程所使用的 gpu 是我们直接通过 `torch.device("cuda:{local_rank}")` 来配置的;因此,如果用户想要实现使用特定 gpu
+ 设备的目的,可以通过自己设置环境变量实现(例如 os.environ["CUDA_VISIBLE_DEVICE"] 来实现);剩下的操作和情况 A 类似;
+
+ -> 情况 C:注意这种情况我们限定了用户是通过 torch.distributed.launch 拉起,并且 ddp 的通信组也是由自己建立。这时基本上所有的
+ 与操作相关的操作都应当由用户自己完成,包括迁移模型到对应 gpu 上以及将模型用 `DistributedDataParallel` 包裹等。
+(2)如果 `driver.setup` 函数在脚本中会被调用两次及以上(意味着用户的启动脚本初始化了两个及以上的 trainer/evaluator)时:
+注意这种情况下我们是会保证前后两个 trainer/evaluator 使用的 `TorchDDPDriver` 以及其初始化方式的一致性,换句话说,如果 trainer1
+ 检测到的启动方式是 '情况 A',那么我们会保证 trainer2 检测到的启动方式同样是 '情况A'(即使这需要一些额外的处理);因此这里我们主要讨论
+ 我们是通过怎样的操作来保证 trainer2/3/... 检测到的启动方式是和 trainer1 一致的;简单来说,我们是通过使用环境变量来标记每一种不同的
+ 启动方式来实现这一点的:
+我们会使用 `FASTNLP_DISTRIBUTED_CHECK` 来标记 '情况 A',使用 `fastnlp_torch_launch_not_ddp` 来标记 '情况 B',意味着我们在
+ 使用 '情况 A' 来启动 `TorchDDPDriver` 时,我们会将 `FASTNLP_DISTRIBUTED_CHECK` 这一字符串注入到环境变量中,而 '情况 B' 时则
+ 会将 `fastnlp_torch_launch_not_ddp` 这一字符串注入到环境变量中。因此在 trainer2 的 `TorchDDPDriver` 的初始化和 setup 过程中,
+ 如果检测到这些特殊的环境变量,我们就会将启动方式变更为其对应的启动方式,即使其它的参数特征属于另外的启动方式。
+
+Part 2:对应的代码细节:
+ 1. 如何判断当前的各进程之间的通信组已经被建立(ddp 已经被初始化);
+ dist.is_initialized();
+ 2. 如何判断不同的进程是否是由 `python -m torch.distributed.launch` 拉起还是由我们的 `TorchDDPDriver.open_subprocess`
+ 函数拉起;
+ 我们会在用户脚本 `import fastNLP` 的时候检测当前的环境变量中是否有 'LOCAL_RANK'、'WORLD_SIZE' 以及没有 `FASTNLP_DISTRIBUTED_CHECK`,
+ 如果满足条件,则我们会向环境变量中注入特殊的值 'FASTNLP_BACKEND_LAUNCH' 来标记用户是否使用了 `python -m torch.distributed.launch`
+ 来拉起多个进程;
+ 3. 整体的处理判断流程:
+ ___________________________________
+ |进入 TorchDDPDriver 的 __init__ 函数|
+ ———————————————————————————————————
+ ↓
+ ___________________________________________________
+ | 判断不同的进程是否是由 torch.distributed.launch 拉起 |
+ |(或者我们自己的 open_subprocess 函数拉起) | -------------->
+ ——————————————————————————————————————————————————— |
+ ↓ 是由 torch.distributed.launch 拉起 | 我们自己的 open_subprocess 函数拉起多个进程
+ ___________________________ |
+ ←←←←← | 检测用户是否自己初始化了 ddp | |
+ ↓ ——————————————————————————— ↓
+ ↓ ↓ 是 ________
+ ↓ ______ | 情况 A |
+ ↓ 否 |情况 C| —————————
+ ↓ ———————
+ ↓
+ ↓ ______
+ ↓ -----------> |情况 B|
+ ———————
+ 4. 为了完成全部的建立 ddp 所需要的操作,三种情况都需要做的事情,以及每件事情的职责归属:
+
+ 情况 A | 情况 B | 情况 C
+ ________________________________________________________________________________________________________
+ 配置 ddp 所 | TorchDDPDriver.open_subprocess | torch.distributed.launch| torch.distributed.launch
+ 需要的环境变量 | | |
+ ————————————————————————————————————————————————————————————————————————————————————————————————————————
+ 开启多个进程 | TorchDDPDriver.open_subprocess | torch.distributed.launch| torch.distributed.launch
+ ————————————————————————————————————————————————————————————————————————————————————————————————————————
+ 调用 dist. | | |
+ init_process\ | TorchDDPDriver.setup | TorchDDPDriver.setup | 用户自己调用
+ _group 函数 | | |
+ ————————————————————————————————————————————————————————————————————————————————————————————————————————
+ 设置 TorchDDPDriver | | |
+ 的 world_size 和 | TorchDDPDriver.setup | TorchDDPDriver.setup | TorchDDPDriver.setup
+ global_rank 属性 | | |
+ ————————————————————————————————————————————————————————————————————————————————————————————————————————
+
+Part 3:其它的处理细节:
+ 1. 环境变量;
+ fastNLP 的 `TorchDDPDriver` 运行时所需要的环境变量分为两种,一种是 torch 的 ddp 运行所需要的环境变量;另一种是 fastNLP 自己
+ 的环境变量。前者的配置情况如上表所示;而后者中的大多数环境变量则是在用户 import fastNLP 时就设置好了;
+ 2. parallel_device, model_device 和 data_device 的关系;
+ parallel_device 为 `TorchDDPDriver` 的参数,model_device 和 data_device 都为 driver 的属性;
+ 其中 data_device 仅当情况 C 时由用户自己指定;如果其不为 None,那么在模型 forward 的时候,我们就会将数据迁移到 data_device 上;
+ model_device 永远都为单独的一个 torch.device;
+
+ 情况 A | 情况 B | 情况 C
+ ________________________________________________________________________________________________________
+ parallel_device | 由用户传入trainer的参数 | 为 torch.device( | 为 torch.device(
+ | device 决定,必须是一个list, | "cuda:{local_rank}") | "cuda:{local_rank}")
+ | 其中每一个对象都是 torch.device | |
+ ————————————————————————————————————————————————————————————————————————————————————————————————————————
+ model_device | parallel_device[local_rank] | parallel_device | None
+ ————————————————————————————————————————————————————————————————————————————————————————————————————————
+ data_device | model_device | model_device | 由用户传入 trainer 的参数
+ | | | data_device 决定
+ ————————————————————————————————————————————————————————————————————————————————————————————————————————
+
+ 3. _DDPWrappingModel 的作用;
+ 因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的
+ forward 函数来帮助我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DistributedDataParallel`
+ 的 forward 方法,然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的
+ forward 函数,还是 `train_step`、`evaluate_step`、`test_step` 方法。
+
+ 4. 当某一个进程出现 exception 后,`TorchDDPDriver` 的处理;
+
+ 不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后,
+ driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉;
+"""
+
+import os
+import sys
+import __main__
+import socket
+import numpy as np
+from time import sleep
+from typing import List, Optional, Union, Dict, Tuple, Callable
+
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+ import torch
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel
+ from torch.utils.data import BatchSampler
+
+__all__ = [
+ 'TorchDDPDriver'
+]
+
+from .torch_driver import TorchDriver
+from fastNLP.core.drivers.torch_driver.utils import (
+ _DDPWrappingModel,
+ replace_sampler,
+ replace_batch_sampler
+)
+from fastNLP.core.drivers.utils import distributed_open_proc
+from fastNLP.core.utils import auto_param_call, check_user_specific_params
+from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, \
+ ReproducibleBatchSampler, \
+ re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
+from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
+from fastNLP.core.log import logger
+from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
+from .utils import _check_dataloader_args_for_distributed
+
+
+class TorchDDPDriver(TorchDriver):
+ r"""
+ ``TorchDDPDriver`` 通过开启多个进程,让每个进程单独使用一个 gpu 设备来实现分布式训练。
+
+ .. note::
+
+ 您在绝大多数情况下不需要自己使用到该类,通过向 ``Trainer`` 传入正确的参数,您可以方便快速地部署您的分布式训练。
+
+ ``TorchDDPDriver`` 目前支持的三种启动方式:
+
+ 1. 用户自己不进行 ``ddp`` 的任何操作,直接使用我们的 ``Trainer``,这时是由我们自己使用 ``open_subprocesses`` 拉起多个进程,
+ 然后 ``TorchDDPDriver`` 自己通过调用 ``dist.init_process_group`` 来初始化 ddp 的通信组;(情况 A)
+
+ .. code-block::
+
+ trainer = Trainer(
+ ...
+ driver='torch',
+ device=[0, 1]
+ )
+ trainer.run()
+
+ 通过运行 ``python train.py`` 启动;
+
+ 2. 用户同样不在 ``Trainer`` 之外初始化 ``ddp``,但是用户自己使用 ``python -m torch.distributed.launch`` 拉起来创建多个进程,这时我们仍旧
+ 会通过调用 ``dist.init_process_group`` 来初始化 ``ddp`` 的通信组;(情况 B)
+
+ .. code-block::
+
+ trainer = Trainer(
+ ...
+ driver='torch',
+ device=None, # fastNLP 会忽略传入的 device,并根据 local_rank 自动分配
+ )
+ trainer.run()
+
+ 通过运行 ``python -m torch.distributed.launch --nproc_per_node 2 train.py`` 启动;
+
+ 3. 用户自己在外面初始化 ``DDP``,并且通过 ``python -m torch.distributed.launch`` 拉起,这时无论是多个进程的拉起和 ddp 的通信组的建立
+ 都由用户自己操作,我们只会在 ``driver.setup`` 的时候对 ``TorchDDPDriver`` 设置一些必要的属性值;(情况 C)
+
+ .. code-block::
+
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel
+
+ # 获取当前的进程信息;
+ ...
+
+ # 初始化 ddp 不同进程间的通信组;
+ dist.init_process_group(...)
+
+ # 初始化模型使用 DistributedDataParallel 包裹;
+ model = Model()
+ model = DistributedDataParallel(model, ...)
+
+            # 注意此时仍旧不需要您主动地将 dataloader 的 sampler 替换为 DistributedSampler;
+ trainer = Trainer(
+ ...
+ driver='torch',
+ device=None, # fastNLP 会忽略传入的 device,并根据 local_rank 自动分配
+ )
+ trainer.run()
+
+ 通过运行 ``python -m torch.distributed.launch --nproc_per_node 2 train.py`` 启动;
+
+ 注意多机的启动强制要求用户在每一台机器上使用 ``python -m torch.distributed.launch`` 启动;因此我们不会在 ``TorchDDPDriver`` 中保存
+ 任何当前有多少台机器的信息。
+
+ :param model: 传入给 ``Trainer`` 的 ``model`` 参数
+ :param parallel_device: 用于分布式训练的 ``gpu`` 设备
+ :param is_pull_by_torch_run: 标志当前的脚本的启动是否由 ``python -m torch.distributed.launch`` 启动的
+ :param fp16: 是否开启 fp16 训练
+ :param torch_kwargs:
+ * *ddp_kwargs* -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入
+ ``{'find_unused_parameters': True}`` 来解决有参数不参与前向运算导致的报错等
+ * *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 ``None``
+ * *non_blocking* -- 表示用于 :meth:`torch.Tensor.to` 方法的参数 non_blocking
+        * *gradscaler_kwargs* -- 用于 ``fp16=True`` 时,提供给 :class:`torch.cuda.amp.GradScaler` 的参数
+ :kwargs:
+ * *wo_auto_param_call* (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为
+
+ .. note::
+
+ 关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。
+ """
+
+    def __init__(
+        self,
+        model,
+        parallel_device: Optional[Union[List["torch.device"], "torch.device"]],
+        is_pull_by_torch_run: bool = False,
+        fp16: bool = False,
+        torch_kwargs: Dict = None,
+        **kwargs
+    ):
+
+        # Mind the position of this super().__init__ call if more logic is added here later;
+        super(TorchDDPDriver, self).__init__(model, fp16=fp16, torch_kwargs=torch_kwargs, **kwargs)
+
+        if isinstance(model, torch.nn.DataParallel):
+            raise ValueError(f"Parameter `model` can not be `DataParallel` in `TorchDDPDriver`, it should be "
+                             f"`torch.nn.Module` or `torch.nn.parallel.DistributedDataParallel` type.")
+
+        # If the user initialized DDP outside, the script must have been launched via python -m torch.distributed.launch;
+        self.is_pull_by_torch_run = is_pull_by_torch_run
+        self.parallel_device = parallel_device
+        if not is_pull_by_torch_run and parallel_device is None:
+            raise ValueError(
+                "Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused "
+                "when your value of parameter `device` is `None` in your `Trainer` instance.")
+
+        # initialize_torch_driver already set parallel_device to this process's gpu when is_pull_by_torch_run is True;
+        if is_pull_by_torch_run:
+            self.model_device = parallel_device
+        else:
+            # model_device is always a single torch.device, never a list;
+            self.model_device = parallel_device[self.local_rank]
+
+        # Whether the user initialized DDP themselves outside of fastNLP (launch case C);
+        self.outside_ddp = False
+        if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
+                "fastnlp_torch_launch_not_ddp" not in os.environ:
+            # When the user set up DDP outside, the model must already be wrapped in DistributedDataParallel;
+            if not isinstance(model, DistributedDataParallel):
+                raise RuntimeError(
+                    "It is not allowed to input a normal model instead of `DistributedDataParallel` when"
+                    "you initialize the ddp process out of our control.")
+
+            self.outside_ddp = True
+            # The user can only wrap the model with DistributedDataParallel after moving it to the right
+            # device themselves, so in this case model_device is simply set to None;
+            self.model_device = None
+
+        # With model_device being None (outside ddp), `data_device` lets the user choose where batches go;
+        self._data_device = kwargs.get("data_device", None)
+        if isinstance(self._data_device, int):
+            if self._data_device < 0:
+                raise ValueError("Parameter `data_device` can not be smaller than 0.")
+            _could_use_device_num = torch.cuda.device_count()
+            if self._data_device >= _could_use_device_num:
+                raise ValueError("The gpu device that parameter `device` specifies is not existed.")
+            self._data_device = torch.device(f"cuda:{self._data_device}")
+        elif isinstance(self._data_device, str):
+            self._data_device = torch.device(self._data_device)
+        elif self._data_device is not None and not isinstance(self._data_device, torch.device):
+            raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
+
+        self._master_port = None
+        # world_size is the total number of gpus/processes across all nodes;
+        self.world_size = None  # filled in setup(): int(os.environ["WORLD_SIZE"]) or len(self.parallel_device)
+        self.global_rank = 0
+
+        # NOTE(review): despite the `_fsdp_kwargs` name, this holds the DistributedDataParallel kwargs
+        # ("ddp_kwargs"); a rename would have to be coordinated with configure_ddp, which reads it.
+        self._fsdp_kwargs = self._torch_kwargs.get("ddp_kwargs", {})
+        check_user_specific_params(self._fsdp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__)
+        if len(self.model._buffers) != 0 and self._fsdp_kwargs.get("broadcast_buffers", None) is None:
+            logger.info("Notice your model has buffers and you are using `TorchDDPDriver`, but you do not set "
+                        "'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set"
+                        " to 'False' to avoid redundant data communication between different processes.")
+
+        self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
+        assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
+        if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
+            # any other string is treated as a directory path to redirect subprocess output into;
+            os.makedirs(name=self.output_from_new_proc, exist_ok=True)
+            self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
+
+        self._has_setup = False  # the evaluator also triggers setup(), but it must only run once per driver;
+        self._has_ddpwrapped = False  # whether the model has been wrapped by _DDPWrappingModel + DDP;
+
+    def setup(self):
+        r"""
+        Prepare the distributed environment. This does two things:
+
+        1. spawn one process per gpu device (unless the processes were launched externally);
+        2. let every process move the model to its own gpu and wrap it with
+           ``DistributedDataParallel``.
+        """
+        if self._has_setup:
+            return
+        self._has_setup = True
+        # Multi-node usage always enters this branch;
+        if self.is_pull_by_torch_run:
+
+            if self.outside_ddp:
+                self.world_size = dist.get_world_size()
+                self.global_rank = dist.get_rank()
+            else:
+                # dist.get_world_size() may only be called after dist.init_process_group;
+                self.world_size = int(os.environ.get("WORLD_SIZE"))
+                self.global_rank = int(os.environ.get("RANK"))
+                logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
+
+                if not dist.is_initialized():
+                    dist.init_process_group(
+                        backend="nccl", rank=self.global_rank, world_size=self.world_size
+                    )
+
+                    os.environ["fastnlp_torch_launch_not_ddp"] = "yes"
+
+        # Reaching this branch implies:
+        #  dist.is_initialized is certainly False;
+        #  we are on a single node;
+        #  self.parallel_device is certainly a List[torch.device];
+        else:
+            if not dist.is_initialized():
+                # The main point here is to distinguish rank0 from the other ranks;
+                self.world_size = len(self.parallel_device)
+                self.open_subprocess()
+                self.global_rank = self.local_rank  # rank is always obtained from the environment variable;
+                dist.init_process_group(
+                    backend="nccl", rank=self.global_rank, world_size=self.world_size
+                )
+            # The user created another TorchDDPDriver trainer before this one;
+            else:
+                # dist.is_initialized() == True means a TorchDDPDriver was already set up once;
+                # the settings of this (later) TorchDDPDriver must match the first one exactly;
+                pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK])
+                if pre_num_processes != len(self.parallel_device):
+                    raise RuntimeError(
+                        "Notice you are using `TorchDDPDriver` after one instantiated `TorchDDPDriver`, it is not"
+                        "allowed that your second `TorchDDPDriver` has a new setting of parameters "
+                        "`num_nodes` and `num_processes`.")
+                self.world_size = dist.get_world_size()
+                self.global_rank = dist.get_rank()
+
+        if not self.outside_ddp:
+            self.configure_ddp()
+
+        self.barrier()
+        # Initialize self._pids so every process can take part in rank0's gather;
+        self._pids = [torch.tensor(0, dtype=torch.int).to(self.data_device) for _ in range(dist.get_world_size())]
+        dist.all_gather(self._pids, torch.tensor(os.getpid(), dtype=torch.int).to(self.data_device))
+        # Derive the per-node process count ourselves when torch did not export LOCAL_WORLD_SIZE;
+        local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
+        if local_world_size is None:
+            local_world_size = torch.tensor(int(os.environ.get("LOCAL_RANK")), dtype=torch.int).to(self.data_device)
+            dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
+            local_world_size = local_world_size.tolist() + 1
+
+        node_rank = self.global_rank // local_world_size
+        # keep only the pids of the processes living on this node, so on_exception can kill them;
+        self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size]
+        self._pids = self.tensor_to_numeric(self._pids)
+
+    def configure_ddp(self):
+        """
+        Move the model to this process's device and wrap it with
+        ``DistributedDataParallel`` (unless the user already wrapped it).
+        """
+        torch.cuda.set_device(self.model_device)
+        self.model.to(self.model_device)
+        if not isinstance(self.model, DistributedDataParallel):
+            self.model = DistributedDataParallel(
+                # self.model_device is a `torch.device`, hence self.model_device.index;
+                _DDPWrappingModel(self.model), device_ids=[self.model_device.index],
+                **self._fsdp_kwargs
+            )
+            self._has_ddpwrapped = True
+
+    def open_subprocess(self):
+        """
+        From rank 0, spawn the worker processes (ranks 1..n-1) and export the
+        environment variables ddp needs (``MASTER_ADDR``/``MASTER_PORT``/``RANK``/
+        ``LOCAL_RANK``/``WORLD_SIZE``). Each child re-runs the user's launch
+        script from the start with its own rank set in its environment.
+        """
+        if self.local_rank == 0:
+            # Script called as `python a/b/c.py`
+            if __main__.__spec__ is None: # pragma: no-cover
+                # pull out the commands used to run the script and resolve the abs file path
+                command = sys.argv
+                command[0] = os.path.abspath(command[0])
+                # use the same python interpreter and actually running
+                command = [sys.executable] + command
+            # Script called as `python -m a.b.c`
+            else:
+                command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]
+
+            os.environ['MASTER_ADDR'] = self.master_address
+            os.environ['MASTER_PORT'] = self.master_port
+
+            os.environ["RANK"] = "0"
+            os.environ["LOCAL_RANK"] = str(self.local_rank)
+            os.environ["WORLD_SIZE"] = f"{self.world_size}"
+
+            os.environ[FASTNLP_DISTRIBUTED_CHECK] = f"{len(self.parallel_device)}"
+            os.environ[FASTNLP_GLOBAL_RANK] = "0"
+            logger._set_distributed()
+
+            interactive_ddp_procs = []
+
+            for rank in range(1, len(self.parallel_device)):
+                env_copy = os.environ.copy()
+                env_copy["LOCAL_RANK"] = f"{rank}"
+                env_copy["RANK"] = f"{rank}"
+
+                # Multi-node runs must be launched by the user themselves, so the FASTNLP_GLOBAL_RANK
+                # of the processes spawned here is always the LOCAL_RANK;
+                env_copy[FASTNLP_GLOBAL_RANK] = str(rank)
+
+                proc = distributed_open_proc(self.output_from_new_proc, command, env_copy, self.global_rank)
+
+                interactive_ddp_procs.append(proc)
+                # stagger the spawns a little to avoid racing on shared start-up resources;
+                delay = np.random.uniform(1, 5, 1)[0]
+                sleep(delay)
+
+    @property
+    def master_address(self) -> str:
+        """
+        The address of the master node used by distributed training (``MASTER_ADDR``).
+        """
+        return os.environ.get("MASTER_ADDR", "127.0.0.1")
+
+    @property
+    def master_port(self) -> str:
+        """
+        The port used by distributed training (``MASTER_PORT``); picked lazily
+        (and cached) from a free local port when not set in the environment.
+        """
+        if self.outside_ddp:
+            return os.environ.get("MASTER_PORT")
+        if self._master_port is None:
+            self._master_port = os.environ.get("MASTER_PORT", find_free_network_port())
+        return self._master_port
+
+    @property
+    def world_size(self) -> int:
+        """
+        The total number of processes taking part in training (``WORLD_SIZE``).
+        """
+        return self._world_size
+
+    @world_size.setter
+    def world_size(self, size: int):
+        self._world_size = size
+
+    @property
+    def global_rank(self) -> int:
+        """
+        The global rank of the current process.
+        """
+        return self._global_rank
+
+    @global_rank.setter
+    def global_rank(self, rank: int) -> None:
+        self._global_rank = rank
+
+    @property
+    def local_rank(self) -> int:  # not affected by all_rank_call_context
+        """
+        The local (per-node) rank of the current process.
+        """
+        return int(os.environ.get("LOCAL_RANK", 0))
+
+    @property
+    def data_device(self):
+        # When the user set up ddp themselves (case C) batches go to the user-chosen
+        # data_device; otherwise data lives on the same device as the model.
+        if self.outside_ddp:
+            return self._data_device
+        return self.model_device
+
+ def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+ if self._has_ddpwrapped:
+ return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
+ wo_auto_param_call=self.wo_auto_param_call)
+ else:
+ if isinstance(batch, Dict) and not self.wo_auto_param_call:
+ return auto_param_call(fn, batch, signature_fn=signature_fn)
+ else:
+ return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        """
+        Resolve the callable to use for the step named ``fn`` (e.g. ``train_step``)
+        together with the signature function for auto parameter matching.
+
+        When fastNLP wrapped the model itself (``_has_ddpwrapped``), the step
+        method of the unwrapped model can be dispatched; when the user passed an
+        already ``DistributedDataParallel``-wrapped model, only ``model.forward``
+        can be routed through DDP.
+        """
+        model = self.unwrap_model()
+        if self._has_ddpwrapped:
+            if hasattr(model, fn):
+                fn = getattr(model, fn)
+                if not callable(fn):
+                    raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
+                return fn, None
+            elif fn in {"train_step", "evaluate_step"}:
+                return model, model.forward
+            else:
+                raise RuntimeError(f"There is no `{fn}` method in your model.")
+        else:
+            if hasattr(model, fn):
+                logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
+                               f"the `{fn}` method, which we can not call actually, we will"
+                               " call `forward` function instead of `train_step` and you should note that.")
+            elif fn not in {"train_step", "evaluate_step"}:
+                raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
+                                   "`DistributedDataParallel` model, which means that we will only call model.forward "
+                                   "function when we are in forward propagation.")
+
+            return self.model, model.forward
+
+    def set_dist_repro_dataloader(self, dataloader,
+                                  dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]] = None,
+                                  reproducible: bool = False):
+        """
+        Return a dataloader whose sampler is suitable for distributed and/or
+        reproducible iteration.
+
+        :param dataloader: the dataloader to adapt.
+        :param dist: ``None`` / ``"dist"`` / ``"unrepeatdist"`` when called during
+            Trainer/Evaluator initialization, or an already-built reproducible
+            (batch) sampler when restoring from a checkpoint.
+        :param reproducible: whether checkpoint support is required.
+        """
+        # A ReproducibleBatchSampler / ReproducibleSampler here means we were called from
+        # driver.load_checkpoint during a fault-tolerant restart;
+        # no need to call dist_sampler.set_distributed again in that path: with TorchDDPDriver
+        # the Trainer already called it during its initialization;
+        if isinstance(dist, ReproducibleBatchSampler):
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
+            return replace_batch_sampler(dataloader, dist)
+        if isinstance(dist, ReproducibleSampler):
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
+            return replace_sampler(dataloader, dist)
+
+        # If dist is a str or None, we are called during trainer/evaluator initialization;
+        # trainer, evaluator
+        if dist is None:
+            if reproducible:
+                raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.")
+            else:
+                args = self.get_dataloader_args(dataloader)
+                if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+                    return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
+                if isinstance(args.sampler, ReproducibleSampler):
+                    return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))
+                return dataloader
+        # trainer
+        elif dist == "dist":
+            args = self.get_dataloader_args(dataloader)
+            # When trainer.use_dist_sampler is True, this branch behaves the same whether or
+            # not we are restarting from a checkpoint;
+            if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+                batch_sampler = re_instantiate_sampler(args.batch_sampler)
+                batch_sampler.set_distributed(
+                    num_replicas=self.world_size,
+                    rank=self.global_rank,
+                    pad=True
+                )
+                return replace_batch_sampler(dataloader, batch_sampler)
+            elif isinstance(args.sampler, ReproducibleSampler):
+                sampler = re_instantiate_sampler(args.sampler)
+                sampler.set_distributed(
+                    num_replicas=self.world_size,
+                    rank=self.global_rank,
+                    pad=True
+                )
+                return replace_sampler(dataloader, sampler)
+            else:
+                _check_dataloader_args_for_distributed(args, controller='Trainer')
+                sampler = RandomSampler(
+                    dataset=args.dataset,
+                    shuffle=args.shuffle,
+                    seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
+                )
+                sampler.set_distributed(
+                    num_replicas=self.world_size,
+                    rank=self.global_rank,
+                    pad=True
+                )
+                return replace_sampler(dataloader, sampler)
+        # evaluator
+        elif dist == "unrepeatdist":
+            args = self.get_dataloader_args(dataloader)
+            if type(args.batch_sampler) != BatchSampler:
+                # TODO exact-type check used to spot a customized batch_sampler; may need refining
+                logger.warning("Note that you are using customized ``batch_sampler`` in evaluate dataloader or" \
+                               "train dataloader while testing ``overfit_batches``, which may cause that" \
+                               "the data for distributed evaluation is not unrepeated.")
+            if isinstance(args.sampler, ReproducibleSampler):
+                sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
+            elif not isinstance(args.sampler, UnrepeatedSampler):
+                _check_dataloader_args_for_distributed(args, controller='Evaluator')
+                sampler = UnrepeatedSequentialSampler(
+                    dataset=args.dataset
+                )
+            else:
+                sampler = re_instantiate_sampler(args.sampler)
+            sampler.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank
+            )
+            # TODO uniformly wrap with a stock BatchSampler for now
+            batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
+            return replace_batch_sampler(dataloader, batch_sampler)
+        else:
+            raise ValueError(
+                "Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
+
+    def is_global_zero(self):
+        r"""
+        :return: whether the current process is globally process 0.
+        """
+        return self.global_rank == 0
+
+    def get_model_no_sync_context(self):
+        r"""
+        :return: a context manager that disables gradient synchronisation
+            between the processes.
+        """
+        # at this point self.model is a DistributedDataParallel object;
+        return self.model.no_sync
+
+    def unwrap_model(self):
+        r"""
+        :return: the original model, without the ``DistributedDataParallel``
+            (and possibly ``_DDPWrappingModel``) wrappers.
+        """
+        _module = self.model.module
+        if isinstance(_module, _DDPWrappingModel):
+            return _module.model
+        else:
+            return _module
+
+    def get_local_rank(self) -> int:
+        r"""
+        :return: the local rank of the current process.
+        """
+        return self.local_rank
+
+    def barrier(self):
+        r"""
+        Synchronise all processes at this point.
+        """
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1 and dist.is_initialized():  # only sync when FASTNLP_NO_SYNC < 1
+            torch.distributed.barrier(async_op=False)
+
+    def is_distributed(self):
+        r"""
+        :return: whether this driver is distributed; always ``True`` for
+            ``TorchDDPDriver``.
+        """
+        return True
+
+    def broadcast_object(self, obj, src: int = 0, group=None, **kwargs):
+        r"""
+        Broadcast ``obj`` (a tensor or a picklable python object) from rank
+        ``src`` to every other process. Non-tensor objects are pickled for the
+        transfer and unpickled on the receiving side. Only meaningful for
+        distributed drivers.
+
+        :param obj: the object to broadcast; may be a tensor or nested data
+        :param src: ``global_rank`` of the sending process
+        :param group: the communication group of the process
+        :return: the received object on receiving ranks, the sent object on the
+            source rank; ``None`` when the env var ``FASTNLP_NO_SYNC`` is **2**
+        """
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # skip entirely when FASTNLP_NO_SYNC == 2
+            return
+        return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group)
+
+    def all_gather(self, obj, group) -> List:
+        r"""
+        Exchange ``obj`` with every other rank; ``obj`` may be a tensor or a
+        nested python structure. Non-basic data is pickled before sending and
+        unpickled after receiving.
+
+        example::
+
+            >>> # rank 0
+            >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 1}}
+            >>> # rank 1
+            >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+            >>> # after all_gather():
+            >>> result = [
+                    {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
+                    {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+                ]
+
+        :param obj: the object to send; must have the same structure on every rank.
+        :param group: the communication group of the process.
+        :return: a list with the ``obj`` sent by every rank; ``[obj]`` without any
+            communication when the env var ``FASTNLP_NO_SYNC`` is **2**.
+        """
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # FASTNLP_NO_SYNC == 2 means no communication
+            return [obj]
+        return fastnlp_torch_all_gather(obj, group=group)
+
+    def on_exception(self):
+        super().on_exception()
+        dist.destroy_process_group()  # prevent dead-locks in later barrier calls.
+
+
+def find_free_network_port() -> str:
+ """
+ 在 localhost 上找到一个空闲端口;
+ 当我们不想连接到真正的主节点但必须设置“MASTER_PORT”环境变量时在单节点训练中很有用。
+ """
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.bind(("", 0))
+ s.listen(1)
+ port = s.getsockname()[1]
+ s.close()
+ return str(port)
diff --git a/fastNLP/core/drivers/torch_driver/deepspeed.py b/fastNLP/core/drivers/torch_driver/deepspeed.py
new file mode 100644
index 00000000..51497dbf
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/deepspeed.py
@@ -0,0 +1,509 @@
+import os
+import argparse
+import logging
+from pathlib import Path
+
+from typing import Union, Dict, List
+from .torch_driver import TorchDriver
+from .ddp import TorchDDPDriver
+from .utils import _create_default_config, _DeepSpeedWrappingModel
+from fastNLP.core.utils import nullcontext
+from fastNLP.core.log import logger
+from fastNLP.envs import(
+ FASTNLP_DISTRIBUTED_CHECK,
+ FASTNLP_CHECKPOINT_FILENAME
+)
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED
+
+if _NEED_IMPORT_TORCH:
+ import torch
+ import torch.distributed as dist
+ from torch.optim import Optimizer
+
+if _NEED_IMPORT_DEEPSPEED:
+ import deepspeed
+ from deepspeed import DeepSpeedEngine, DeepSpeedOptimizer
+
+__all__ = [
+ "DeepSpeedDriver",
+]
+
+class DeepSpeedDriver(TorchDDPDriver):
+ """
+ 实现 ``deepspeed`` 分布式训练的 ``Driver``。
+
+ .. note::
+
+ 您在绝大多数情况下不需要自己使用到该类,通过向 ``Trainer`` 传入正确的参数,您可以方便快速地部署您的分布式训练;
+
+ ``DeepSpeedDriver`` 目前支持的三种启动方式:
+
+ 1. 用户自己不进行任何操作,直接使用我们的 ``Trainer``,这时是由我们自己使用 ``open_subprocesses`` 拉起多个进程,
+ 然后 ``DeepSpeedDriver`` 自己通过调用 ``deepspeed.initialize`` 来初始化模型和通信组;(情况 A)
+
+ .. code-block::
+
+ trainer = Trainer(
+ ...
+ driver='deepspeed',
+ device=[0, 1]
+ )
+ trainer.run()
+
+ 通过运行 ``python train.py`` 启动;
+
+ 2. 用户同样不在 ``Trainer`` 之外初始化 ``deepspeed``,但是用户自己使用 ``python -m torch.distributed.launch`` 拉起来创建多个进程,这时我们仍旧
+        会通过调用 ``deepspeed.initialize`` 来初始化 ``ddp`` 的通信组;(情况 B)
+
+ .. code-block::
+
+ trainer = Trainer(
+ ...
+ driver='deepspeed',
+ device=None, # fastNLP 会忽略传入的 device,并根据 local_rank 自动分配
+ )
+ trainer.run()
+
+ 通过运行 ``deepspeed train.py`` 启动;
+
+ 3. 用户自己在外面初始化 ``deepspeed``,并且通过 ``deepspeed train.py`` 拉起,这时无论是多个进程的拉起和通信组的建立
+ 都由用户自己操作,我们只会在 ``driver.setup`` 的时候对 ``DeepSpeedDriver`` 设置一些必要的属性值;(情况 C)
+
+ .. code-block::
+
+ import deepspeed
+
+ # 初始化
+ model, _, _, _ = deepspeed.initialize(model, ...)
+
+ trainer = Trainer(
+ ...
+ driver='deepspeed',
+ device=None, # fastNLP 会忽略传入的 device,并根据 local_rank 自动分配
+ )
+ trainer.run()
+
+ 通过运行 ``deepspeed train.py`` 启动。
+
+ :param model: 传入给 ``Trainer`` 的 ``model`` 参数。
+ :param parallel_device: 用于分布式训练的 ``gpu`` 设备。
+ :param is_pull_by_torch_run: 标志当前的脚本的启动是否由 ``python -m torch.distributed.launch`` 启动的。
+ :param fp16: 是否开启 fp16 训练。
+ :param deepspeed_kwargs:
+ * *strategy* -- 使用 ZeRO 优化的策略,默认为 ``deepspeed``;目前仅支持以下值:
+
+ * ``deepspeed`` -- 使用 ZeRO 的第二阶段,等同于 ``deepspeed_stage_2``;
+ * ``deepspeed_stage_1`` -- 使用 ZeRO 的第一阶段,仅将 ``optimizer`` 的状态分散到不同设备上;
+ * ``deepspeed_stage_2`` -- 使用 ZeRO 的第二阶段,将 ``optimizer`` 和 **梯度** 分散到不同设备上;
+ * ``deepspeed_stage_2_offload`` -- 使用 ZeRO 的第二阶段,并且借助 cpu 的内存来进一步节约显存;
+ * ``deepspeed_stage_3`` -- 使用 ZeRO 的第三阶段,将 ``optimizer`` 、**梯度** 和 **模型** 分散到不同设备上;
+ * ``deepspeed_stage_3_offload`` -- 使用 ZeRO 的第三阶段,并且借助 cpu 的内存来进一步节约显存;
+ * ``deepspeed_stage_3_offload_nvme`` -- 使用 ZeRO 的第三阶段,并且借助 NVMe 硬盘来进一步节约显存;
+ * *logging_level* -- ``deepspeed`` 库的日志等级,默认为 **logging.ERROR**。
+ * *config* -- ``deepspeed`` 的各项设置;**FastNLP** 允许用户传入自己的设置以增强灵活性,但这会使参数
+ 中的 ``optimizer`` 、``strategy`` 、 ``fp16`` 等失效,即当这个参数存在时,**FastNLP** 会用该参数覆盖
+ 其它的设置。
+ :kwargs:
+ * *accumulation_steps* -- 即在 :class:`~fastNLP.core.controllers.Trainer` 传入的 ``accumulation_steps`` 。 deepspeed 会将 ``config`` 的
+ ``gradient_accumulation_steps`` 设置为该值。
+ * *train_dataloader* -- 即在 :class:`~fastNLP.core.controllers.Trainer` 传入的 ``train_dataloader`` 。 ``deepspeed`` 需要通过它来获取
+ 数据的 ``batch_size`` 用于设置 ``train_micro_batch_size_per_gpu`` 。如果没有传入的话,则会设置为 **1** 。
+ * *wo_auto_param_call* (``bool``) -- 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为
+
+ .. note::
+
+ 关于该参数的详细说明,请参见 :class:`~fastNLP.core.controllers.Trainer` 中的描述;函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`。
+ """
+ # TODO fp16 load_config
+ def __init__(
+ self,
+ model,
+ parallel_device: Union[List["torch.device"], "torch.device"],
+ is_pull_by_torch_run = False,
+ fp16: bool = False,
+ deepspeed_kwargs: Dict = None,
+ **kwargs
+ ):
+ assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported."
+ kwargs.pop("torch_kwargs", None)
+ self._ds_kwargs = deepspeed_kwargs
+ TorchDriver.__init__(self, model=model, fp16=False, torch_kwargs=deepspeed_kwargs, **kwargs)
+ self.fp16 = fp16
+
+ # 如果用户自己在外面初始化 DDP,那么其一定是通过 python -m torch.distributed.launch 拉起的;
+ self.is_pull_by_torch_run = is_pull_by_torch_run
+ self.parallel_device = parallel_device
+ if not is_pull_by_torch_run and parallel_device is None:
+ raise ValueError(
+ "Parameter `parallel_device` can not be None when using `TorchDeepSpeedDriver`. This error is caused "
+ "when your value of parameter `device` is `None` in your `Trainer` instance.")
+
+ # 注意我们在 initialize_torch_driver 中的逻辑就是如果是 is_pull_by_torch_run,那么我们就直接把 parallel_device 置为当前进程的gpu;
+ if is_pull_by_torch_run:
+ self.model_device = parallel_device
+ else:
+ # 我们的 model_device 一定是 torch.device,而不是一个 list;
+ self.model_device = parallel_device[self.local_rank]
+
+ # 如果用户自己在外面初始化了 deepspeed;
+ self.outside_ddp = False
+ if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
+ "fastnlp_torch_launch_not_ddp" not in os.environ:
+ # 如果用户自己在外面初始化了 deepspeed,那么我们要求用户传入的模型一定是已经由 DeepSpeedEngine 包裹后的模型;
+ if not isinstance(model, DeepSpeedEngine):
+ raise RuntimeError(
+ "It is not allowed to input a normal model instead of `DeepSpeedEngine` when"
+ "you initialize the ddp process out of our control.")
+
+ self.outside_ddp = True
+ self.config = model.config
+ self.model_device = None
+
+ self._data_device = kwargs.get("data_device", None)
+ if isinstance(self._data_device, int):
+ if self._data_device < 0:
+ raise ValueError("Parameter `data_device` can not be smaller than 0.")
+ _could_use_device_num = torch.cuda.device_count()
+ if self._data_device >= _could_use_device_num:
+ raise ValueError("The gpu device that parameter `device` specifies is not existed.")
+ self._data_device = torch.device(f"cuda:{self._data_device}")
+ elif isinstance(self._data_device, str):
+ self._data_device = torch.device(self._data_device)
+ elif self._data_device is not None and not isinstance(self._data_device, torch.device):
+ raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
+
+ self._master_port = None
+ # world_size 表示的就是全局的显卡的数量;
+ self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
+ self.global_rank = 0
+
+ self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
+ assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
+ if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
+ os.makedirs(name=self.output_from_new_proc, exist_ok=True)
+ self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
+
+ self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的;
+ self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹;
+ self.accumulation_steps = kwargs.get("accumulation_steps", 1)
+ # 获取 batch_size 以设置 train_micro_batch_size_per_gpu 参数
+ train_dl = kwargs.get("train_dataloader", None)
+ if train_dl is not None:
+ self.train_micro_batch_size = self.get_dataloader_args(train_dl).batch_size
+ else:
+ logger.warning("No `train_dataloader` found, and we will set `train_micro_batch_size_per_gpu`"
+ "to 1 for deepspeed configuration.")
+ self.train_micro_batch_size = 1
+
+ self.strategy = self._ds_kwargs.get("strategy", "deepspeed")
+ deepspeed_logging_level = self._ds_kwargs.get("logging_level", logging.ERROR)
+ deepspeed.utils.logging.logger.setLevel(deepspeed_logging_level)
+
+ @staticmethod
+ def _check_optimizer_legality(optimizers):
+ for each_optimizer in optimizers:
+ if not isinstance(each_optimizer, (Optimizer, DeepSpeedOptimizer)):
+ raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' or "
+ f"'DeepSpeedOptimizer'type, not {type(each_optimizer)}.")
+
+ def setup(self):
+ r"""
+ 准备分布式环境,该函数主要做以下两件事情:
+
+ 1. 开启多进程,每个 gpu 设备对应单独的一个进程;
+ 2. 使用 ``deepspeed.initialize`` 包裹模型;
+ """
+ if len(self.optimizers) != 1:
+ raise ValueError("Multi optimizers is not supported for `DeepSpeedDriver` right now.")
+ if self._has_setup:
+ return
+ self._has_setup = True
+ self.setup_config()
+ # 如果用户需要使用多机模式,那么一定进入到这里;
+ if self.is_pull_by_torch_run:
+ if self.outside_ddp:
+ self.world_size = dist.get_world_size()
+ self.global_rank = dist.get_rank()
+ else:
+ # dist.get_world_size() 只能在 dist.init_process_group 初始化之后进行调用;
+ self.world_size = int(os.environ.get("WORLD_SIZE"))
+ self.global_rank = int(os.environ.get("RANK"))
+ logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
+
+ if not dist.is_initialized():
+ deepspeed.init_distributed("nccl", distributed_port=self.master_port)
+
+ os.environ["fastnlp_torch_launch_not_ddp"] = "yes"
+
+ # 进入到这里的情况时:
+ # dist.is_initialized 一定为 False;
+ # 一定是单机;
+ # self.parallel_device 一定是 List[torch.device];
+ else:
+ if not dist.is_initialized():
+ # 这里主要的问题在于要区分 rank0 和其它 rank 的情况;
+ self.world_size = len(self.parallel_device)
+ self.open_subprocess()
+ self.global_rank = self.local_rank # rank 一定是通过环境变量去获取的;
+ deepspeed.init_distributed("nccl", distributed_port=self.master_port)
+ # 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 DeepSpeedDriver;
+ else:
+ # 如果 `dist.is_initialized() == True`,那么说明 DeepSpeedDriver 在之前已经初始化并且已经 setup 过一次,那么我们需要保证现在
+ # 使用的(即之后的)DeepSpeedDriver 的设置和第一个 DeepSpeedDriver 是完全一样的;
+ pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK])
+ if pre_num_processes != len(self.parallel_device):
+ raise RuntimeError(
+ "Notice you are using `DeepSpeedDriver` after one instantiated `DeepSpeedDriver`, it is not"
+ "allowed that your second `DeepSpeedDriver` has a new setting of parameters "
+ "`num_nodes` and `num_processes`.")
+ self.world_size = dist.get_world_size()
+ self.global_rank = dist.get_rank()
+
+ if not self.outside_ddp:
+ torch.cuda.set_device(self.model_device)
+ # 不加 dist.broadcast_object_list 会发生设备在 4,5 但是模型会同步到 0,1 的情况
+ # 原因未知
+ dist.broadcast_object_list(["test"], 0, None)
+ self.configure_ddp()
+
+ self.barrier()
+ # 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作;
+ self._pids = [torch.tensor(0, dtype=torch.int).to(self.data_device) for _ in range(dist.get_world_size())]
+ dist.all_gather(self._pids, torch.tensor(os.getpid(), dtype=torch.int).to(self.data_device))
+ local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
+ if local_world_size is None:
+ local_world_size = torch.tensor(int(os.environ.get("LOCAL_RANK")), dtype=torch.int).to(self.data_device)
+ dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
+ local_world_size = local_world_size.tolist() + 1
+
+ node_rank = self.global_rank // local_world_size
+ self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size]
+ self._pids = self.tensor_to_numeric(self._pids)
+
+ def configure_ddp(self):
+
+ # 设置 deepspeed
+ if not isinstance(self.model, DeepSpeedEngine):
+ model=_DeepSpeedWrappingModel(self.model, self.fp16)
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+ self.model, ds_optimizer, _, _ = deepspeed.initialize(
+ args=argparse.Namespace(device_rank=self.model_device.index),
+ model=model,
+ optimizer=self.optimizers[0],
+ model_parameters=model_parameters,
+ config=self.config,
+ dist_init_required=False
+ )
+ self._optimizers = [ds_optimizer]
+
+ if self.config.get("activation_checkpointing"):
+ checkpoint_config = self.config["activation_checkpointing"]
+ deepspeed.checkpointing.configure(
+ mpu_=None,
+ partition_activations=checkpoint_config.get("partition_activations"),
+ contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
+ checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
+ profile=checkpoint_config.get("profile"),
+ )
+
+ self._has_ddpwrapped = True
+
+ def setup_config(self):
+
+ self.config = self._ds_kwargs.get("config")
+ if self.config is not None:
+ logger.warning("Notice that you have defined a configuration for deepspeed and parameters like"
+ "`optimizers`, `strategy` and `fp16` may not take effects.")
+ return
+
+ if self.strategy == "deepspeed":
+ self.config = _create_default_config(stage=2)
+ elif self.strategy == "deepspeed_stage_1":
+ self.config = _create_default_config(stage=1)
+ elif self.strategy == "deepspeed_stage_2":
+ self.config = _create_default_config(stage=2)
+ elif self.strategy == "deepspeed_stage_2_offload":
+ self.config = _create_default_config(stage=2, offload_optimizer=True)
+ elif self.strategy == "deepspeed_stage_3":
+ self.config = _create_default_config(stage=3)
+ elif self.strategy == "deepspeed_stage_3_offload":
+ self.config = _create_default_config(
+ stage=3,
+ offload_optimizer=True,
+ offload_parameters=True,
+ )
+ elif self.strategy == "deepspeed_stage_3_offload_nvme":
+ self.config = _create_default_config(
+ stage=3,
+ offload_optimizer=True,
+ offload_parameters=True,
+ remote_device="nvme",
+ offload_params_device="nvme",
+ offload_optimizer_device="nvme",
+ )
+ else:
+ raise ValueError(f"Unknown deepspeed strategy {self.strategy}.")
+
+ # 设置成 max_int 防止 deepspeed 的输出干扰 fastnlp 的输出
+ self.config.setdefault("steps_per_print", 2147483647)
+ self.config["gradient_accumulation_steps"] = self.accumulation_steps
+ self.config.setdefault("train_micro_batch_size_per_gpu", self.train_micro_batch_size)
+
+ if self.fp16:
+ if "fp16" not in self.config:
+ # FP16 is a DeepSpeed standalone AMP implementation
+ logger.debug("Enabling DeepSpeed FP16.")
+ # TODO 这部分是否可以像 pytorch-lightning 那样给用户定制
+ self.config["fp16"] = {
+ "enabled": True,
+ "loss_scale": 0,
+ "initial_scale_power": True,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1,
+ }
+ elif "amp" not in self.config:
+ logger.debug("Enabling DeepSpeed APEX Implementation.")
+ self.config["amp"] = {"enabled": True, "opt_level": "O1"}
+
+ def zero_grad(self):
+ """
+        进行梯度置零操作;由于 :meth:`DeepSpeedEngine.step` 包含了 :meth:`zero_grad` 的功能,因此该接口实际无意义。
+ """
+ # DeepSpeedEngine.step 包含了 zero_grad 功能
+ pass
+
+ def backward(self, loss):
+ """
+ 对 ``loss`` 进行反向传播
+ """
+ self.model.backward(loss)
+
+ def step(self):
+ """
+ 更新模型的参数
+ """
+ self.model.step()
+
+ def get_model_no_sync_context(self):
+ r"""
+ :return: 一个 ``context`` 上下文环境,用于关闭各个进程之间的同步;在 ``deepspeed`` 中,返回一个空的上下文
+ """
+        # 注意此时的 model 是 "DeepSpeedEngine" 对象;
+ return nullcontext
+
+ def save_model(self, filepath: Union[str, Path], only_state_dict: bool = False, **kwargs):
+ """
+ 保存的模型到 ``filepath`` 中。
+
+ :param filepath: 文件路径
+ :param only_state_dict: 是否只保存权重;在 ``DeepSpeedDriver`` 中该参数无效。
+ :param kwargs: 需要传入 **deepspeed** 模型 :meth:`save_checkpoint` 的其它参数。
+ :return:
+ """
+ # deepspeed engine 要求在每个 rank 都调用 save_checkpoint,故去掉了 rank_zero_call 装饰器
+ if self.stage_3:
+ logger.rank_zero_warning(
+ "When saving the DeepSpeed Stage 3 checkpoint, "
+ "each worker will save a shard of the checkpoint within a directory. "
+ # TODO check一下
+ # "If a single file is required after training, "
+ # "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#"
+ # "deepspeed-zero-stage-3-single-file for instructions."
+ )
+ if not only_state_dict:
+ logger.rank_zero_warning("Only saving state dict is not allowed for `DeepSpeedDriver`. We will save its "
+ "checkpoint for you instead.")
+ self.model.save_checkpoint(filepath, **kwargs)
+
+ def load_model(self, filepath: Union[Path, str], only_state_dict: bool = False, **kwargs):
+ """
+ 从 ``filepath`` 中加载权重并赋值到当前 driver 的模型上。
+
+ :param filepath: 加载权重或模型的路径
+        :param only_state_dict: 保存的内容是否只是权重;在 ``DeepSpeedDriver`` 中该参数无效。
+ :param kwargs: 需要传入 **deepspeed** 模型 :meth:`load_checkpoint` 的其它参数。
+ :return:
+ """
+ if not only_state_dict:
+ logger.warning("Only loading state dict is not allowed for `DeepSpeedDriver`. We will load its "
+ "checkpoint for you instead.")
+ self.model.load_checkpoint(filepath, **kwargs)
+
+ def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+ r"""
+ 断点重训的保存函数,该函数会负责保存 **优化器** 、 **sampler** 和 **fp16** 的状态,以及 **模型** (若 ``should_save_model`` 为 ``True``)
+
+        :param folder: 保存断点重训的状态的文件夹;:meth:`save_checkpoint` 函数应该在该路径下面新增名为 ``FASTNLP_CHECKPOINT_FILENAME`` 与
+ ``FASTNLP_MODEL_FILENAME`` (如果 ``should_save_model`` 为 ``True`` )的文件。把 model 相关的内容放入到 ``FASTNLP_MODEL_FILENAME`` 文件
+ 中,将传入的 ``states`` 以及自身产生其它状态一并保存在 ``FASTNLP_CHECKPOINT_FILENAME`` 里面。
+ :param states: 由 :class:`~fastNLP.core.controllers.Trainer` 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态。
+ :param dataloader: 正在使用的 dataloader。
+ :param only_state_dict: 是否只保存模型的参数,当 ``should_save_model`` 为 ``False`` ,该参数无效。
+ :param should_save_model: 是否应该保存模型,如果为 ``False`` ,Driver 将不负责 model 的保存。
+ """
+ # deepspeed engine 要求在每个 rank 都调用 save_checkpoint,故去掉了 rank_zero_call 装饰器
+ # 1. 保存 sampler 的状态
+ num_consumed_batches = states.pop('num_consumed_batches')
+ states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches)
+
+ # 2. 保存模型的状态;
+ if not should_save_model:
+ logger.rank_zero_warning("Saving checkpoint without model is not allowed for `DeepSpeedDriver`, "
+ "so we will still save the model for you.")
+
+ self.model.save_checkpoint(Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME),
+ client_state=states)
+
+ def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
+ r"""
+ 断点重训的加载函数,该函数会负责读取数据,并且恢复 **优化器** 、**sampler** 、 **fp16** 的状态和 **模型** (如果 ``should_load_model`` 为 True)以及其它
+ 在 :meth:`save_checkpoint` 函数中执行的保存操作,然后将一个 state 字典返回给 :class:`~fastNLP.core.controllers.Trainer` ( 内容为 :meth:`save_checkpoint`
+ 接受到的 ``states`` )。
+
+ 该函数应该在所有 rank 上执行。
+
+ :param folder: 读取该 folder 下的 ``FASTNLP_CHECKPOINT_FILENAME`` 文件与 ``FASTNLP_MODEL_FILENAME``
+ (如果 should_load_model 为True)。
+ :param dataloader: 当前给定 dataloader,需要根据保存的 dataloader 状态合理设置。若该值为 ``None`` ,则不需要返回 ``'dataloader'``
+ 以及 ``'batch_idx_in_epoch'`` 这两个值。
+ :param only_state_dict: 是否仅读取模型的 state_dict ,当 ``should_save_model`` 为 ``False`` ,该参数无效。如果为 ``True`` ,说明保存的内容为权重;如果为
+ False 说明保存的是模型,但也是通过当前 Driver 的模型去加载保存的模型的权重,而不是使用保存的模型替换当前模型。
+ :param should_load_model: 是否应该加载模型,如果为 ``False`` ,Driver 将不负责加载模型。若该参数为 ``True`` ,但在保存的状态中没有
+ 找到对应的模型状态,则报错。
+ :return: :meth:`save_checkpoint` 函数输入的 ``states`` 内容。除此之外,还返回的内容有:
+
+ * *dataloader* -- 根据传入的 ``dataloader`` 与读取出的状态设置为合理状态的 dataloader。在当前 ``dataloader`` 样本数与读取出的 sampler 样本数
+ 不一致时报错。
+ * *batch_idx_in_epoch* -- :class:`int` 类型的数据,表明当前 epoch 进行到了第几个 batch 。请注意,该值不能仅通过保存的数据中读取的,因为前后两次运行的
+ ``batch_size`` 可能有变化,而应该符合以下等式::
+
+ 返回的 dataloader 还会产生的 batch 数量 + batch_idx_in_epoch = 原来不断点训练时的 batch 的总数
+
+ 由于 ``返回的 dataloader 还会产生的batch数`` 在 ``batch_size`` 与 ``drop_last`` 参数给定的情况下,无法改变,因此只能通过调整 ``batch_idx_in_epoch``
+ 这个值来使等式成立。一个简单的计算原则如下:
+
+ * drop_last 为 ``True`` 时,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
+ * drop_last 为 ``False`` 时,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。
+ """
+ # 1. 加载模型状态;
+ if not should_load_model:
+ logger.rank_zero_warning("Loading checkpoint without model is not allowed for `DeepSpeedDriver`, "
+ "so we will still load the model for you.")
+ load_path, states = self.model.load_checkpoint(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
+ if load_path is None:
+ raise RuntimeError(f"Failed to load checkpoint from path: {str(folder)}")
+
+ # 2.恢复 sampler 的状态
+ sampler_states = states.pop('sampler_states')
+ states_ret = self.load_sampler_state(dataloader, sampler_states)
+ states.update(states_ret)
+
+ return states
+
+ @property
+ def stage_3(self) -> bool:
+ """
+ 判断是否为第三阶段的 ZeRO 优化
+ """
+ return self.config.get("zero_optimization") and self.config.get("zero_optimization").get("stage") == 3
\ No newline at end of file
diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py
new file mode 100644
index 00000000..3e2fbea0
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/dist_utils.py
@@ -0,0 +1,366 @@
+import io
+import pickle
+import os
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+from typing import Any, List
+
+from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
+from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.env import FASTNLP_NO_SYNC
+if _NEED_IMPORT_TORCH:
+ import torch
+ from torch import distributed as dist
+ if _TORCH_GREATER_EQUAL_1_8:
+ try:
+ from torch._C._distributed_c10d import ProcessGroupGloo
+ from torch._C._distributed_c10d import _ProcessGroupWrapper
+ except ImportError:
+ pass
+
+
+from fastNLP.core.utils import apply_to_collection
+
+__all__ = []
+
+def _validate_output_list_for_rank(my_rank, dst, gather_list):
+ if dst == my_rank:
+ if not gather_list:
+ raise ValueError(
+ "Argument ``gather_list`` must be specified on destination rank."
+ )
+ elif gather_list:
+ raise ValueError(
+ "Argument ``gather_list`` must NOT be specified "
+ "on non-destination ranks."
+ )
+
+
+def fastnlp_torch_gather_object(obj, dst=0, group=DEFAULT_TORCH_GROUP):
+ """
+ 从其它 rank gather 东西到 dst rank 。
+
+ Example::
+ >>> # Assumes world_size of 3.
+ >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
+ >>> output = [None for _ in gather_objects]
+ >>> fastnlp_torch_gather_object(
+ gather_objects[dist.get_rank()],
+ output if dist.get_rank() == 0 else None,
+ dst=0
+ )
+ >>> # On rank 0
+ >>> output
+ ['foo', 12, {1: 2}]
+
+    :param obj: 需要发送的 obj 对象,需要是可以 picklable 的对象
+ :param dst: 目标的 rank 。
+ :param group: 在哪个 group 执行该函数。
+ :return: 在 dst 上面返回 world_size 的 list,依次为 rank 0;rank 1...上 obj
+ """
+ if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+ return [obj]
+
+ if dist.get_rank() == dst:
+ object_gather_list = [None for _ in range(dist.get_world_size(group))]
+ else:
+ object_gather_list = None
+
+ if group is None:
+ group = DEFAULT_TORCH_GROUP
+
+ if dist.distributed_c10d._rank_not_in_group(group):
+ return
+
+    # Ensure object_gather_list is specified appropriately.
+ my_rank = dist.get_rank()
+ _validate_output_list_for_rank(my_rank, dst, object_gather_list)
+ # 防止 unpickle 的时候出现在了发送的 gpu 上。
+ obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+ input_tensor, local_size = _object_to_tensor(obj)
+ group_backend = dist.get_backend(group)
+ current_device = torch.device("cpu")
+ is_nccl_backend = group_backend == dist.Backend.NCCL
+ if is_nccl_backend:
+ current_device = torch.device('cuda', torch.cuda.current_device())
+ input_tensor = input_tensor.to(current_device)
+ local_size = local_size.to(current_device)
+ # Gather all local sizes. This is so that we can find the max size, and index
+ # until the correct size when deserializing the tensors.
+ group_size = dist.get_world_size(group=group)
+ object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device)
+ object_size_list = [
+ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
+ ]
+ # Allgather tensor sizes. An all-gather is needed here despite this being a
+ # gather, since each rank needs to broadcast a tensor of the same (maximal)
+ # size.
+ dist.all_gather(object_size_list, local_size, group=group)
+ max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
+ # Resize tensor to max size across all ranks.
+ input_tensor.resize_(max_object_size)
+ # Avoid populating output tensors if the result won't be gathered on this rank.
+ if my_rank == dst:
+ coalesced_output_tensor = torch.empty(
+ max_object_size * group_size, dtype=torch.uint8, device=current_device
+ )
+ # Output tensors are nonoverlapping views of coalesced_output_tensor
+ output_tensors = [
+ coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+ for i in range(group_size)
+ ]
+ # All ranks call gather with equal-sized tensors.
+ dist.gather(
+ input_tensor,
+ gather_list=output_tensors if my_rank == dst else None,
+ dst=dst,
+ group=group,
+ )
+ if my_rank != dst:
+ return
+ for i, tensor in enumerate(output_tensors):
+ tensor = tensor.type(torch.uint8) # type: ignore[call-overload]
+ tensor_size = object_size_list[i]
+ object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
+
+
+def _object_to_tensor(obj, device=None):
+ f = io.BytesIO()
+ _pickler(f).dump(obj)
+ byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
+ # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
+    # Otherwise, it will cause 100X slowdown.
+ # See: https://github.com/pytorch/pytorch/issues/65696
+ byte_tensor = torch.ByteTensor(byte_storage)
+ local_size = torch.LongTensor([byte_tensor.numel()])
+ if device is not None:
+ byte_tensor = byte_tensor.to(device)
+ local_size = local_size.to(device)
+ return byte_tensor, local_size
+
+
+def _tensor_to_object(tensor, tensor_size):
+ buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size]
+ return _unpickler(io.BytesIO(buf)).load()
+
+
+def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
+ r"""
+ pytorch 中的单点对多点的分发函数;
+
+ 例如将进程 0 上的对象 object 分发到其它进程上。
+
+ Example::
+
+ cur_rank = int(os.environ.get('LOCAL_RANK', 0))
+
+ # 拿到 local_device
+
+ send_recv_object(object, 0, cur_rank, local_device)
+
+ :param obj: 一个可以序列化的 python 对象。
+ :param src: 从哪一个 rank 上发送到其它 rank。
+ :param cur_rank: 当前的进程的 rank 序号。
+ :param device: 当前的进程所在的设备。
+ :param group: 通信组,默认为 None。
+ :param tag: 将发送与远程接收匹配的标记;
+ :return:
+ """
+ # src rank send to all other ranks
+ size = torch.LongTensor([0]).to(device)
+
+ if cur_rank == src:
+ world_size = dist.get_world_size(group=group)
+ tensor, size = _object_to_tensor(obj)
+ tensor = tensor.to(device)
+ size = size.to(device)
+
+ # 首先同步 obj 的 size 的信息;
+ dist.broadcast(size, src, group=group)
+ for subrank in range(world_size):
+ if subrank != src:
+ dist.send(tensor=tensor, dst=subrank, group=group, tag=tag)
+ else:
+ dist.broadcast(size, src, group=group)
+ tensor = torch.ByteTensor([0] * size).to(device)
+ dist.recv(tensor=tensor, src=src, group=group, tag=tag)
+
+ return _tensor_to_object(tensor.cpu(), size)
+
+
+def _to_device(tensor, device):
+ return tensor.contiguous().to(device)
+
+
+def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List:
+ """
+ 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。
+
+ example::
+
+ >>> # rank 0
+ >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 1}}
+ >>> # rank 1
+ >>> obj = {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+ >>> # after all_gather():
+ >>> result = [
+ {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
+ {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+ ]
+
+ :param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行
+ 序列化之后进行传输。
+ :param device: 当前该参数无意义。
+ :param group:
+ :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。
+ """
+ if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+ return [obj]
+
+ if group is None:
+ group = DEFAULT_TORCH_GROUP
+ if isinstance(obj, torch.Tensor):
+ objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
+ dist.all_gather(objs, obj, group=group)
+ else:
+ objs = [None for _ in range(dist.get_world_size(group))]
+ # 防止 unpickle 的时候弄到发送的 gpu 上了
+ obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+ if _TORCH_GREATER_EQUAL_1_8:
+ dist.all_gather_object(objs, obj, group=group)
+ else:
+ objs = all_gather_object(objs, obj, group=group)
+ return objs
+
+
+def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
+ """
+ 将 src 上的 obj 对象广播到其它 rank 上。
+
+ :param obj: 需要发送的对象
+ :param src: 从哪里发出。
+ :param device:
+ :param group: 属于哪个通信 group
+ :return:
+ """
+ if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+ if src == dist.get_rank(group):
+ return obj
+ else:
+ return None
+
+ if group is None:
+ group = DEFAULT_TORCH_GROUP
+ cur_rank = dist.get_rank(group)
+ if cur_rank == src:
+ # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里
+ obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+ if _TORCH_GREATER_EQUAL_1_8:
+ if cur_rank!=src:
+ get_obj = [None]
+ dist.broadcast_object_list(get_obj, src=src, group=group)
+ return get_obj[0]
+ else:
+ dist.broadcast_object_list([obj], src=src, group=group)
+ return obj
+ if device is None:
+ device = torch.cuda.current_device()
+
+ if cur_rank == src:
+ tensor, size = _object_to_tensor(obj, device=device)
+ else:
+ size = torch.LongTensor([0]).to(device)
+
+ dist.broadcast(size, src=src, group=group)
+ if cur_rank != src:
+ tensor = torch.empty(
+ size.int().item(), # type: ignore[arg-type]
+ dtype=torch.uint8,
+ device=device
+ )
+ dist.broadcast(tensor, src=src, group=group)
+
+ return _tensor_to_object(tensor, tensor_size=size.item())
+
+
+def _check_for_nccl_backend(group):
+ pg = group or dist.distributed_c10d._get_default_group()
+ # It is not expected for PG to be wrapped many times, but support it just
+ # in case
+ while isinstance(pg, _ProcessGroupWrapper):
+ pg = pg.wrapped_pg
+
+ return (
+ dist.is_nccl_available() and
+ isinstance(pg, dist.ProcessGroupNCCL)
+ )
+
+
+def all_gather_object(object_list, obj, group=None):
+ """
+ 复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。
+
+ Example::
+ >>> # Note: Process group initialization omitted on each rank.
+ >>> # Assumes world_size of 3.
+ >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
+ >>> output = [None for _ in gather_objects]
+ >>> all_gather_object(output, gather_objects[dist.get_rank()])
+ >>> output
+ ['foo', 12, {1: 2}]
+
+ :param object_list:
+ :param obj:
+ :param group:
+ :return:
+ """
+ if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+ return [obj]
+
+ if dist.distributed_c10d._rank_not_in_group(group):
+ return
+ if _TORCH_GREATER_EQUAL_1_8:
+ current_device = torch.device("cpu")
+ is_nccl_backend = _check_for_nccl_backend(group)
+ if is_nccl_backend:
+ # See note about using torch.cuda.current_device() here in docstring.
+ # We cannot simply use my_rank since rank == device is not necessarily
+ # true.
+ current_device = torch.device("cuda", torch.cuda.current_device())
+ else:
+ current_device = torch.cuda.current_device()
+
+ input_tensor, local_size = _object_to_tensor(obj, device=current_device)
+
+ # Gather all local sizes. This is so that we can find the max size, and index
+ # until the correct size when deserializing the tensors.
+ group_size = dist.get_world_size(group=group)
+ object_sizes_tensor = torch.zeros(
+ group_size, dtype=torch.long, device=current_device
+ )
+ object_size_list = [
+ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
+ ]
+ # Allgather tensor sizes
+ dist.all_gather(object_size_list, local_size, group=group)
+ max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
+ # Resize tensor to max size across all ranks.
+ input_tensor.resize_(max_object_size)
+ coalesced_output_tensor = torch.empty(
+ max_object_size * group_size, dtype=torch.uint8, device=current_device
+ )
+ # Output tensors are nonoverlapping views of coalesced_output_tensor
+ output_tensors = [
+ coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+ for i in range(group_size)
+ ]
+ dist.all_gather(output_tensors, input_tensor, group=group)
+ # Deserialize outputs back to object.
+ for i, tensor in enumerate(output_tensors):
+ tensor = tensor.type(torch.uint8)
+ if tensor.device != torch.device("cpu"):
+ tensor = tensor.cpu()
+ tensor_size = object_size_list[i]
+ object_list[i] = _tensor_to_object(tensor, tensor_size)
+ return object_list
\ No newline at end of file
diff --git a/fastNLP/core/drivers/torch_driver/fairscale.py b/fastNLP/core/drivers/torch_driver/fairscale.py
new file mode 100644
index 00000000..83f464f1
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/fairscale.py
@@ -0,0 +1,329 @@
+__all__ = [
+ 'FairScaleDriver'
+]
+from typing import List, Sequence, Union, Dict, Mapping
+from pathlib import Path
+import os
+import functools
+
+from fastNLP.envs.imports import _NEED_IMPORT_FAIRSCALE
+if _NEED_IMPORT_FAIRSCALE:
+ import torch
+ import torch.distributed as dist
+ from fairscale.optim import OSS
+ from fairscale.nn import ShardedDataParallel
+ from fairscale.nn import FullyShardedDataParallel
+ from fairscale.optim.grad_scaler import ShardedGradScaler
+ from torch.nn.parallel import DistributedDataParallel
+ from fairscale.nn.wrap import auto_wrap, enable_wrap, default_auto_wrap_policy
+
+from ...log import logger
+from .utils import _DDPWrappingModel
+
+from .ddp import TorchDDPDriver
+from .torch_driver import TorchDriver
+from .utils import _build_fp16_env
+from ....envs.distributed import all_rank_call_context
+from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
+from .utils import optimizer_state_to_device
+
+
+class FairScaleDriver(TorchDDPDriver):
+    """
+    ``Driver`` implementing distributed training through ``fairscale``.
+
+    :param model: the ``model`` argument passed to ``Trainer``.
+    :param parallel_device: gpu device(s) used for distributed training.
+    :param is_pull_by_torch_run: whether the current script was launched by
+        ``python -m torch.distributed.launch``.
+    :param fp16: whether to enable fp16 mixed-precision training.
+    :param fairscale_kwargs:
+
+        * *fs_type* -- one of ``'ddp'``, ``'sdp'`` or ``'fsdp'``; defaults to ``'sdp'``.
+        * *oss_kwargs* -- kwargs for :class:`fairscale.optim.OSS`; only used for ``'ddp'`` and ``'sdp'``.
+        * *sdp_kwargs* -- kwargs for :class:`fairscale.nn.ShardedDataParallel`.
+        * *fsdp_kwargs* -- kwargs for :class:`fairscale.nn.FullyShardedDataParallel`.
+        * *ddp_kwargs* -- kwargs for :class:`torch.nn.parallel.DistributedDataParallel`.
+        * *set_grad_to_none* -- whether grads are set to ``None`` after each optimizer update.
+        * *non_blocking* -- the ``non_blocking`` argument used by :meth:`torch.Tensor.to`.
+        * *gradscaler_kwargs* -- kwargs for the gradient scaler when ``fp16=True``.
+    :kwargs:
+        * *wo_auto_param_call* (``bool``) -- whether to disable calling ``auto_param_call`` to
+          automatically match batch contents against the forward function's parameters.
+
+    .. note::
+
+        See :class:`~fastNLP.core.controllers.Trainer` for a detailed description of this
+        parameter; ``auto_param_call`` is documented at :func:`fastNLP.core.utils.auto_param_call`.
+    """
+    def __init__(
+            self,
+            model,
+            parallel_device: Union[List["torch.device"], "torch.device"],
+            is_pull_by_torch_run = False,
+            fp16: bool = False,
+            fairscale_kwargs: Dict = None,
+            **kwargs
+    ):
+        assert _NEED_IMPORT_FAIRSCALE, "fairscale is not imported."
+        assert not dist.is_initialized(), "FairScaleDriver does not support initialize distributed by user."
+        # Guard the default: `fairscale_kwargs` is documented as optional, so `.get`
+        # calls below must not crash when it is None.
+        self._fairscale_kwargs = fairscale_kwargs if fairscale_kwargs is not None else {}
+        self.fs_type = self._fairscale_kwargs.get('fs_type', 'sdp')  # ddp, sdp, fsdp
+        if self.fs_type == 'fsdp':
+            self._fairscale_kwargs['set_grad_to_none'] = self._fairscale_kwargs.get('set_grad_to_none', True)
+        # Initialize the top-level TorchDriver directly (we deliberately skip
+        # TorchDDPDriver.__init__; fp16 is handled separately below).
+        kwargs.pop('torch_kwargs', None)
+        TorchDriver.__init__(self, model=model, fp16=False, torch_kwargs=self._fairscale_kwargs, **kwargs)
+        self.is_pull_by_torch_run = is_pull_by_torch_run
+        assert self.fs_type in ['ddp', 'sdp', 'fsdp']
+        self._oss_kwargs = self._fairscale_kwargs.get('oss_kwargs', {})  # only used for ddp and sdp
+        self._sdp_kwargs = self._fairscale_kwargs.get('sdp_kwargs', {})
+        self._fdsp_kwargs = self._fairscale_kwargs.get('fsdp_kwargs', {})
+        self._ddp_kwargs = self._fairscale_kwargs.get('ddp_kwargs', {})
+
+        if self.fs_type == 'ddp' or fp16 is False:
+            # Plain DDP (or no fp16) can use the regular torch GradScaler setup.
+            self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
+            self.grad_scaler = _grad_scaler(**self._fairscale_kwargs.get('gradscaler_kwargs', {}))
+        else:
+            # Sharded training needs fairscale's ShardedGradScaler to sync inf/nan state.
+            self.auto_cast, self.grad_scaler = torch.cuda.amp.autocast, \
+                                               ShardedGradScaler(**self._fairscale_kwargs.get('gradscaler_kwargs', {}))
+
+        self.parallel_device = parallel_device
+        if is_pull_by_torch_run:
+            self.model_device = parallel_device
+        else:
+            self.model_device = parallel_device[self.local_rank]
+
+        self.outside_ddp = False  # externally initialized distributed state is not allowed
+        self._data_device = kwargs.get("data_device", None)
+        if isinstance(self._data_device, int):
+            if self._data_device < 0:
+                raise ValueError("Parameter `data_device` can not be smaller than 0.")
+            _could_use_device_num = torch.cuda.device_count()
+            if self._data_device >= _could_use_device_num:
+                raise ValueError("The gpu device that parameter `device` specifies is not existed.")
+            self._data_device = torch.device(f"cuda:{self._data_device}")
+        elif isinstance(self._data_device, str):
+            self._data_device = torch.device(self._data_device)
+        elif self._data_device is not None and not isinstance(self._data_device, torch.device):
+            raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
+
+        self._master_port = None
+        # world_size is the total number of gpus across all nodes;
+        self.world_size = None  # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
+        self.global_rank = 0
+
+        if self.fs_type == 'ddp':
+            if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None:
+                logger.info("Notice your model has buffers and you are using `FairScaleDriver`, but you do not set "
+                            "'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set"
+                            " to 'False' to avoid redundant data communication between different processes.")
+
+        self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
+        assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
+        if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
+            # Any other value is treated as a directory where child-process output is dumped.
+            os.makedirs(self.output_from_new_proc, exist_ok=True)
+            self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
+
+        self._has_setup = False  # evaluator would trigger setup again, which is unnecessary and must be avoided;
+        self._has_ddpwrapped = False  # whether the model passed in has already been wrapped;
+
+    def setup(self):
+        r"""
+        Prepare the distributed environment. This mainly does two things:
+
+        1. spawn one process per gpu device;
+        2. each process moves the model onto its own gpu and wraps it
+           according to ``fs_type``.
+        """
+        if self._has_setup:
+            return
+        self._has_setup = True
+        if self.is_pull_by_torch_run:
+            # dist.get_world_size() can only be called after dist.init_process_group;
+            self.world_size = int(os.environ.get("WORLD_SIZE"))
+            self.global_rank = int(os.environ.get("RANK"))
+            logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
+
+            if not dist.is_initialized():
+                dist.init_process_group(
+                    backend="nccl", rank=self.global_rank, world_size=self.world_size
+                )
+
+            os.environ["fastnlp_torch_launch_not_ddp"] = "yes"
+        else:
+            if not dist.is_initialized():
+                # The main issue here is distinguishing rank0 from the other ranks;
+                self.world_size = len(self.parallel_device)
+                self.open_subprocess()
+                self.global_rank = self.local_rank  # rank must always come from the environment variable;
+                dist.init_process_group(
+                    backend="nccl", rank=self.global_rank, world_size=self.world_size
+                )
+            # the user created another trainer (TorchDDPDriver) before this one;
+            else:
+                # If `dist.is_initialized() == True`, a driver already initialized and set up the
+                # process group once; the current driver must use exactly the same settings.
+                pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK])
+                if pre_num_processes != len(self.parallel_device):
+                    raise RuntimeError(
+                        "Notice you are using `TorchDDPDriver` after one instantiated `TorchDDPDriver`, it is not"
+                        "allowed that your second `TorchDDPDriver` has a new setting of parameters "
+                        "`num_nodes` and `num_processes`.")
+                self.world_size = dist.get_world_size()
+                self.global_rank = dist.get_rank()
+
+        torch.cuda.set_device(self.model_device)
+        if self.fs_type != 'fsdp':
+            # FSDP places shards itself via `compute_device`; moving the whole model here
+            # would defeat its memory saving.
+            self.model.to(self.model_device)
+        self.configure_ddp()
+
+        self.barrier()
+        # Initialize self._pids so every process can receive rank0's send operations;
+        self._pids = [torch.tensor(0, dtype=torch.int).to(self.data_device) for _ in range(dist.get_world_size())]
+        dist.all_gather(self._pids, torch.tensor(os.getpid(), dtype=torch.int).to(self.data_device))
+        local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
+        if local_world_size is None:
+            # Fall back to max(LOCAL_RANK) + 1 across all ranks as the node-local world size.
+            local_world_size = torch.tensor(int(os.environ.get("LOCAL_RANK")), dtype=torch.int).to(self.data_device)
+            dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
+            local_world_size = local_world_size.tolist() + 1
+
+        node_rank = self.global_rank // local_world_size
+        self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size]
+        self._pids = self.tensor_to_numeric(self._pids)
+
+    def configure_ddp(self):
+        """Wrap ``self.model`` (and, for sdp/fsdp, the optimizers) according to ``fs_type``."""
+        model = _DDPWrappingModel(self.model)
+        if self.fs_type == 'ddp':
+            self.model = DistributedDataParallel(
+                # note that self.model_device is a `torch.device`, hence self.model_device.index;
+                model, device_ids=[self.model_device.index],
+                **self._ddp_kwargs
+            )
+        elif self.fs_type == 'sdp':
+            sdp_kwargs = self._sdp_kwargs
+            sdp_kwargs = {**sdp_kwargs, 'module': model}
+            sdp_kwargs['reduce_fp16'] = sdp_kwargs.get('reduce_fp16', self.fp16)
+            oss_lst = []
+            for optimizer in self.optimizers:
+                # Shard each optimizer's state across ranks with fairscale's OSS wrapper.
+                oss = OSS(optimizer.param_groups, optim=type(optimizer), **optimizer.defaults)
+                oss_lst.append(oss)
+            sdp_kwargs['sharded_optimizer'] = oss_lst
+            sdp_kwargs['warn_on_trainable_params_changed'] = sdp_kwargs.get('warn_on_trainable_params_changed', False)
+            self.model = ShardedDataParallel(**sdp_kwargs)
+            self.optimizers = oss_lst
+        else:
+            assert len(self.optimizers) == 1, "When fs_type='fsdp', only one optimizer is allowed."
+            optimizer = self.optimizers[0]
+            assert len(optimizer.param_groups) == 1, "Cannot assign parameter specific optimizer parameter for 'fsdp'."
+            fsdp_kwargs = self._fdsp_kwargs
+            fsdp_kwargs['mixed_precision'] = self.fp16
+            fsdp_kwargs['state_dict_on_rank_0_only'] = fsdp_kwargs.get('state_dict_on_rank_0_only', True)
+            fsdp_kwargs['state_dict_device'] = fsdp_kwargs.get('state_dict_device', torch.device('cpu'))
+            fsdp_kwargs['compute_device'] = fsdp_kwargs.get('compute_device', self.model_device)
+            fsdp_kwargs = {**fsdp_kwargs, 'module': model}
+            self.model = None  # release the reference before re-wrapping
+            self.model = FullyShardedDataParallel(**fsdp_kwargs).to(self.model_device)
+            # Recreate the optimizer over the flattened FSDP parameters. Keep it inside a
+            # list so that `get_optimizer_state`/`load_optimizer_state`, which iterate
+            # `self.optimizers`, continue to work.
+            self.optimizers = [type(optimizer)(self.model.parameters(), **optimizer.defaults)]
+
+        self._has_ddpwrapped = True
+
+    def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs):
+        """
+        Save the current driver's model under ``filepath``.
+
+        :param filepath: where to save the model;
+        :param only_state_dict: whether to save only the weights;
+        :return:
+        """
+        if self.fs_type in ('ddp', 'sdp'):
+            model = self.model.module.model
+
+        if only_state_dict:
+            if self.fs_type != 'fsdp':
+                if self.local_rank == 0:
+                    states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
+            else:
+                # For fsdp, state_dict() must be called on every rank to gather the shards.
+                states = self.model.state_dict()
+                if self.local_rank == 0:
+                    # strip the `model.` prefix introduced by the _DDPWrappingModel wrapper
+                    states = {key[len('model.'):]: value for key, value in states.items()}
+            if self.local_rank == 0:
+                torch.save(states, filepath)
+        elif self.fs_type == 'fsdp':
+            raise RuntimeError("When fs_type='fsdp', only `only_state_dict=True` is allowed.")
+        else:
+            if self.local_rank == 0:
+                torch.save(model, filepath)
+
+    def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
+        """
+        Load weights from ``filepath`` and assign them to the current driver's model.
+
+        :param filepath: path of the saved weights or model
+        :param only_state_dict: whether the saved content is only the weights.
+        :param kwargs:
+        :return:
+        """
+        states = torch.load(filepath, map_location='cpu')
+        if isinstance(states, dict) and only_state_dict is False:
+            logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use "
+                                     f"`only_state_dict=True`")
+        elif not isinstance(states, dict) and only_state_dict is True:
+            logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use "
+                                     f"`only_state_dict=False`")
+        if not isinstance(states, Mapping):
+            states = states.state_dict()
+
+        if self.fs_type in ('ddp', 'sdp'):
+            model = self.model.module.model
+        else:
+            # fsdp keys include the `model.` wrapper prefix, so add it back before loading.
+            model = self.model
+            states = {f'model.{k}': v for k, v in states.items()}
+
+        model.load_state_dict(states)
+
+    def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+        """Save a resumable checkpoint; for fsdp all ranks must participate in gathering."""
+        if self.fs_type == 'fsdp':
+            if should_save_model is False:
+                logger.warning("When save model using fs_type='fsdp', please make sure use "
+                               "`with trainer.driver.model.summon_full_params():` context to gather all parameters.")
+            with all_rank_call_context():
+                super().save_checkpoint(folder=folder, states=states, dataloader=dataloader, only_state_dict=only_state_dict,
+                                        should_save_model=should_save_model, **kwargs)
+        else:
+            super().save_checkpoint(folder=folder, states=states, dataloader=dataloader,
+                                    only_state_dict=only_state_dict, should_save_model=should_save_model, **kwargs)
+
+    def get_optimizer_state(self):
+        """Collect a cpu copy of every optimizer's state (populated on rank 0)."""
+        optimizers_state_dict = {}
+        for i in range(len(self.optimizers)):
+            optimizer: torch.optim.Optimizer = self.optimizers[i]
+            if self.fs_type == 'fsdp':
+                optimizer_state = self.model.gather_full_optim_state_dict(optimizer)
+            elif self.fs_type == 'sdp':
+                # OSS keeps only a shard on each rank; consolidate the full state onto
+                # rank 0 first, then read it — otherwise `optimizer_state` is unbound below.
+                optimizer.consolidate_state_dict(recipient_rank=0)
+                optimizer_state = optimizer.state_dict()
+            else:
+                optimizer_state = optimizer.state_dict()
+            if self.local_rank == 0:
+                optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
+                optimizers_state_dict[f"optimizer{i}"] = optimizer_state  # no deepcopy here; tests show it is unnecessary;
+        return optimizers_state_dict
+
+    def load_optimizer_state(self, states):
+        """Restore optimizer state saved by :meth:`get_optimizer_state`."""
+        assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
+                                                    f"checkpoint it is:{len(states)}"
+        for i in range(len(self.optimizers)):
+            optimizer: torch.optim.Optimizer = self.optimizers[i]
+            state = states[f'optimizer{i}']
+            if self.fs_type == 'fsdp':
+                # shard the gathered full state back down to this rank
+                state = self.model.get_shard_from_optim_state_dict(state)
+            optimizer.load_state_dict(state)
+
+        logger.debug("Load optimizer state dict.")
+
+    def unwrap_model(self):
+        r"""
+        :return: the original model, e.g. without the ``DataParallel``-style wrapper;
+        """
+        return self.model.module.model
diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
new file mode 100644
index 00000000..9352e315
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
@@ -0,0 +1,111 @@
+import os
+from typing import Optional, Union, List, Sequence
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+ import torch
+
+from .torch_driver import TorchDriver
+from .single_device import TorchSingleDriver
+from .ddp import TorchDDPDriver
+from .fairscale import FairScaleDriver
+from .deepspeed import DeepSpeedDriver
+from .torch_fsdp import TorchFSDPDriver
+from fastNLP.core.log import logger
+from fastNLP.envs import FASTNLP_BACKEND_LAUNCH
+from pkg_resources import parse_version
+
+__all__ = []
+
+
+def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.device", int, List[int]]],
+                            model: "torch.nn.Module", **kwargs) -> TorchDriver:
+    r"""
+    Determine and initialize a concrete ``Driver`` instance from ``driver`` and ``device``.
+
+    :param driver: must be one of ``["torch", "fairscale", "deepspeed", "torch_fsdp"]``
+    :param device: same format as the ``device`` parameter of ``Trainer``
+    :param model: the model used for training or evaluation
+
+    :return: an instance of one of the following:
+        * :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver`
+        * :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`
+        * :class:`~fastNLP.core.drivers.torch_driver.DeepSpeedDriver`
+        * :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver`
+        * :class:`~fastNLP.core.drivers.torch_driver.TorchFSDPDriver`
+    """
+    if parse_version(torch.__version__) < parse_version('1.6'):
+        # The check requires torch >= 1.6; the message must say so (not "older than").
+        raise RuntimeError(f"Pytorch(current version:{torch.__version__}) needs to be at least 1.6.")
+    # world_size and rank
+    if FASTNLP_BACKEND_LAUNCH in os.environ:
+        # Launched via `torch.distributed.run`: the local device comes from the environment.
+        if device is not None:
+            logger.rank_zero_warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
+                                     "up your script. And we will directly get the local device via "
+                                     "`os.environ['LOCAL_RANK']`.", once=True)
+        if driver == 'fairscale':
+            return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
+                                   is_pull_by_torch_run=True, **kwargs)
+        elif driver == 'deepspeed':
+            return DeepSpeedDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
+                                   is_pull_by_torch_run=True, **kwargs)
+        else:
+            return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
+                                  is_pull_by_torch_run=True, **kwargs)
+
+    if driver not in {"torch", "fairscale", "deepspeed", "torch_fsdp"}:
+        # Keep the message in sync with the set checked above.
+        raise ValueError("Parameter `driver` can only be one of these values: "
+                         "['torch', 'fairscale', 'deepspeed', 'torch_fsdp'].")
+
+    # Normalize `device` into a torch.device or a list of torch.device.
+    _could_use_device_num = torch.cuda.device_count()
+    if isinstance(device, str):
+        device = torch.device(device)
+    elif isinstance(device, int):
+        if device < 0:
+            if device != -1:
+                raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
+            # -1 means "use every visible gpu".
+            device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)]
+        elif device >= _could_use_device_num:
+            raise ValueError("The gpu device that parameter `device` specifies is not existed.")
+        else:
+            device = torch.device(f"cuda:{device}")
+    elif isinstance(device, Sequence):
+        device = list(set(device))
+        for each in device:
+            if not isinstance(each, int):
+                raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.")
+            elif each < 0:
+                raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.")
+            elif each >= _could_use_device_num:
+                raise ValueError(f"When parameter `device` is 'Sequence' type, the value in it should not be bigger than"
+                                 f" the available gpu number:{_could_use_device_num}.")
+        device = [torch.device(f"cuda:{w}") for w in device]
+    elif device is not None and not isinstance(device, torch.device):
+        raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
+
+    if driver == "torch":  # single device or ddp, launched directly.
+        if not isinstance(device, List):
+            return TorchSingleDriver(model, device, **kwargs)
+        else:
+            return TorchDDPDriver(model, device, **kwargs)
+    elif driver == "fairscale":
+        if not isinstance(device, List):
+            if device.type == 'cpu':
+                raise ValueError("You are using `fairscale` driver, but your chosen `device` is 'cpu'.")
+            logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.")
+            return FairScaleDriver(model, [device], **kwargs)
+        else:
+            return FairScaleDriver(model, device, **kwargs)
+    elif driver == "deepspeed":
+        if not isinstance(device, List):
+            if device.type == 'cpu':
+                raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.")
+            logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.")
+            return DeepSpeedDriver(model, [device], **kwargs)
+        else:
+            return DeepSpeedDriver(model, device, **kwargs)
+    elif driver == "torch_fsdp":
+        if not isinstance(device, List):
+            if device.type == 'cpu':
+                raise ValueError("You are using `torch_fsdp` driver, but your chosen `device` is 'cpu'.")
+            logger.warning_once("Notice you are using `torch_fsdp`, but the `device` is only one gpu.")
+            return TorchFSDPDriver(model, [device], **kwargs)
+        else:
+            return TorchFSDPDriver(model, device, **kwargs)
\ No newline at end of file
diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py
new file mode 100644
index 00000000..336b5420
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/single_device.py
@@ -0,0 +1,181 @@
+import os
+from typing import Dict, Union, Callable, Tuple, Optional
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+ import torch
+ from torch.nn import DataParallel
+ from torch.nn.parallel import DistributedDataParallel
+ from torch.utils.data import RandomSampler as TorchRandomSampler
+ from torch.utils.data import SequentialSampler as TorchSequentialSampler
+ from torch.utils.data import BatchSampler as TorchBatchSampler
+
+__all__ = [
+ 'TorchSingleDriver'
+]
+
+from .torch_driver import TorchDriver
+from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
+from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
+ ReproduceBatchSampler
+from fastNLP.core.samplers import RandomSampler
+from fastNLP.core.log import logger
+
+
+class TorchSingleDriver(TorchDriver):
+    r"""
+    ``TorchSingleDriver`` is the ``driver`` used for cpu and single-gpu computation.
+
+    .. note::
+
+        If you want to train your model with ``DataParallel``, you should wrap the model
+        with ``DataParallel`` yourself before the ``Trainer`` is initialized and pass the
+        wrapped model to the ``Trainer``.
+
+    :param model: the ``model`` argument passed to ``Trainer``
+    :param device: torch.device, the device used by the current process
+    :param fp16: whether to enable fp16
+    :param torch_kwargs:
+        * *set_grad_to_none* -- whether grads are set to ``None`` after each optimizer update
+        * *non_blocking* -- the ``non_blocking`` argument used by :meth:`torch.Tensor.to`
+        * *gradscaler_kwargs* -- kwargs for the grad scaler when ``fp16=True``
+    :kwargs:
+        * *wo_auto_param_call* (``bool``) -- whether to disable calling ``auto_param_call`` to
+          automatically match batch contents against the forward function's parameters
+
+    .. note::
+
+        See :class:`~fastNLP.core.controllers.Trainer` for a detailed description of this
+        parameter; ``auto_param_call`` is documented at :func:`fastNLP.core.utils.auto_param_call`.
+    """
+
+    def __init__(self, model, device: "torch.device", fp16: bool = False, torch_kwargs: Dict = None, **kwargs):
+        if isinstance(model, DistributedDataParallel):
+            raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`")
+
+        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        if cuda_visible_devices == "":
+            # An empty CUDA_VISIBLE_DEVICES hides every gpu, so fall back to cpu.
+            device = torch.device("cpu")
+            logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to"
+                        "use `cpu` instead of `gpu` device.")
+
+        super(TorchSingleDriver, self).__init__(model, fp16=fp16, torch_kwargs=torch_kwargs, **kwargs)
+
+        if device is None:
+            logger.debug("device is not set, fastNLP will try to automatically get it.")
+            try:
+                device = next(model.parameters()).device
+                assert isinstance(device, torch.device)
+            except:
+                raise ValueError("fastNLP cannot get device automatically, please set device explicitly.")
+
+        self.model_device = device
+
+        # Single-process driver: the distributed identity is trivially fixed.
+        self.local_rank = 0
+        self.global_rank = 0
+        self.world_size = 1
+
+    def setup(self):
+        r"""
+        Move the model onto the corresponding device.
+        """
+        if self.model_device is not None:
+            self.model.to(self.model_device)
+
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        """Call ``fn`` with ``batch``, auto-matching dict batches to parameters unless disabled."""
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        """Resolve the callable used for ``fn`` (e.g. ``train_step``) and its signature fn."""
+        if isinstance(self.model, DataParallel):
+            # DataParallel only dispatches `forward`, so a custom step method cannot be used.
+            model = self.unwrap_model()
+            if hasattr(model, fn):
+                logger.warning("Notice your model is a `DataParallel` model. And your model also implements the "
+                               f"`{fn}` method, which we can not call actually, we will"
+                               " call `forward` function instead of `train_step` and you should note that.")
+
+            elif fn not in {"train_step", "evaluate_step"}:
+                raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
+                                   f"`DataParallel` model, which means that we will only call model.forward function "
+                                   f"when we are in forward propagation.")
+
+            return self.model, model.forward
+        else:
+            # TODO calling a model method directly bypasses hooks; maybe warn users that
+            #  train_step will not trigger their hooks.
+            if hasattr(self.model, fn):
+                fn = getattr(self.model, fn)
+                if not callable(fn):
+                    raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+                logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+                return fn, None
+            elif fn in {"train_step", "evaluate_step"}:
+                logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
+                return self.model, self.model.forward
+            else:
+                raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
+
+    def set_dist_repro_dataloader(self, dataloader,
+                                  dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
+                                  reproducible: bool = False):
+        """Return a dataloader whose sampling can be reproduced / resumed when requested."""
+        # If dist is a ReproducibleBatchSampler/ReproducibleSampler, this is called from
+        # driver.load_checkpoint during checkpoint resumption;
+        if isinstance(dist, ReproducibleBatchSampler):
+            return replace_batch_sampler(dataloader, dist)
+        elif isinstance(dist, ReproducibleSampler):
+            return replace_sampler(dataloader, dist)
+
+        # If dist is a str or None, this is called during trainer initialization;
+        args = self.get_dataloader_args(dataloader)
+        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+            batch_sampler = re_instantiate_sampler(args.batch_sampler)
+            return replace_batch_sampler(dataloader, batch_sampler)
+        elif isinstance(args.sampler, ReproducibleSampler):
+            sampler = re_instantiate_sampler(args.sampler)
+            return replace_sampler(dataloader, sampler)
+
+        if reproducible:
+            if type(args.batch_sampler) is TorchBatchSampler:
+                if type(args.sampler) is TorchRandomSampler:
+                    # torch's RandomSampler stores the flag as `replacement` (singular);
+                    # reading a wrong attribute name would always yield the default.
+                    if getattr(args.sampler, '_num_samples', None) is None \
+                            and getattr(args.sampler, 'replacement', False) is False \
+                            and getattr(args.sampler, 'generator', None) is None:
+                        # A plain, uncustomized random sampler can simply be swapped out.
+                        sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+                        logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+                        return replace_sampler(dataloader, sampler)
+                elif type(args.sampler) is TorchSequentialSampler:
+                    # needs to be replaced with a non-shuffling sampler.
+                    sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+                    logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
+                    return replace_sampler(dataloader, sampler)
+            batch_sampler = ReproduceBatchSampler(
+                batch_sampler=args.batch_sampler,
+                batch_size=args.batch_size,
+                drop_last=args.drop_last
+            )
+            return replace_batch_sampler(dataloader, batch_sampler)
+        else:
+            return dataloader
+
+    def unwrap_model(self):
+        r"""
+        :return: the original model; this strips a ``DataParallel`` wrapper if present
+        """
+        if isinstance(self.model, torch.nn.DataParallel) or \
+                isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
+            return self.model.module
+        else:
+            return self.model
+
+    @property
+    def data_device(self):
+        r"""
+        The device where the data and the model reside
+        """
+        return self.model_device
+
+    def is_distributed(self):
+        r"""
+        :return: whether the current driver is distributed; for ``TorchSingleDriver`` this is always ``False``
+        """
+        return False
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
new file mode 100644
index 00000000..6ca33476
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -0,0 +1,519 @@
+import os
+from typing import Union, Dict, Optional, Callable
+from functools import partial
+import numpy as np
+import random
+from dataclasses import dataclass
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from pathlib import Path
+if _NEED_IMPORT_TORCH:
+ import torch
+ from torch.utils.data import DataLoader, IterableDataset, Sampler, BatchSampler, Dataset
+ from torch.optim import Optimizer
+ from torch.utils.data import RandomSampler as TorchRandomSampler
+ _reduces = {
+ 'sum': torch.sum,
+ 'min': torch.min,
+ 'max': torch.max,
+ 'mean': torch.mean
+ }
+
+
+__all__ = [
+ 'TorchDriver'
+]
+
+from .utils import optimizer_state_to_device
+from fastNLP.core.drivers.driver import Driver
+from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler
+from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
+from fastNLP.envs import rank_zero_call
+from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
+from fastNLP.core.log import logger
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
+from fastNLP.core.dataloaders import OverfitDataLoader
+
+
+class TorchDriver(Driver):
+ r"""
+ The ``driver`` dedicated to ``pytorch``; parent class of ``TorchSingleDriver`` and ``TorchDDPDriver``.
+
+ .. warning::
+
+ You should not initialize this class directly and pass it to ``Trainer``; in other words, use its subclasses
+ ``TorchSingleDriver`` and ``TorchDDPDriver`` instead of this class itself.
+
+ .. note::
+
+ The interfaces provided by ``TorchDriver`` can be used while working with ``TorchSingleDriver`` and ``TorchDDPDriver``.
+
+ :param model: the **pytorch** model used for training.
+ :param fp16: whether to enable mixed-precision training;
+ :param torch_kwargs:
+ """
+ def __init__(self, model, fp16: Optional[bool] = False, torch_kwargs: Dict = None, **kwargs):
+ super(TorchDriver, self).__init__(model)
+
+ """ set up fp16 """
+ self._torch_kwargs = torch_kwargs if torch_kwargs is not None else {}
+
+ # the mixed-precision setup is identical for ddp and single_device, so it is factored up into this class;
+ self.fp16 = fp16
+ self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not self.fp16)
+ self.grad_scaler = _grad_scaler(**self._torch_kwargs.get('gradscaler_kwargs', {}))
+ self.set_grad_to_none = self._torch_kwargs.get('set_grad_to_none')
+
+ # controls the `non_blocking` argument used inside `torch_move_data_to_device`;
+ self.non_blocking = self._torch_kwargs.get("non_blocking", True)
+
+ # controls whether parameter matching in auto_param_call is disabled;
+ self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
+
+ def zero_grad(self):
+ """
+ Zero the gradients of every optimizer.
+ """
+ for optimizer in self.optimizers:
+ self._clear_grad(optimizer, self.set_grad_to_none)
+
+ def _clear_grad(self, optimizer, set_to_none):  # drop (set_to_none) or zero in place the grad of every parameter
+ param_groups = optimizer.param_groups
+ for group in param_groups:
+ for p in group['params']:
+ if p.grad is not None:
+ if set_to_none:
+ p.grad = None
+ else:
+ if p.grad.grad_fn is not None:
+ p.grad.detach_()
+ else:
+ p.grad.requires_grad_(False)
+ p.grad.zero_()
+
+ def backward(self, loss):
+ """
+ Back-propagate ``loss`` (scaled by the grad scaler when fp16 is enabled).
+ """
+ self.grad_scaler.scale(loss).backward()
+
+ def step(self):
+ r"""
+ Perform the parameter-update step of every optimizer.
+ """
+ for optimizer in self.optimizers:
+ self.grad_scaler.step(optimizer)
+ self.grad_scaler.update()
+
+ def check_dataloader_legality(self, dataloader):
+ """
+ Check whether the DataLoader is valid. Supported types are :class:`~fastNLP.core.dataloaders.TorchDataLoader` and :class:`torch.utils.data.DataLoader` .
+
+ :param dataloader:
+ """
+ if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
+ raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
+ if len(dataloader) == 0:
+ logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
+ "may cause some unexpected exceptions.", once=True)
+
+ @staticmethod
+ def _check_optimizer_legality(optimizers):  # every optimizer must be a torch ``Optimizer``
+ for each_optimizer in optimizers:
+ if not isinstance(each_optimizer, Optimizer):
+ raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, "
+ f"not {type(each_optimizer)}.")
+
+ @staticmethod
+ def tensor_to_numeric(tensor, reduce: str = None):
+ r"""
+ Convert a ``torch.Tensor`` into plain python numeric values.
+
+ :param tensor: the ``torch.Tensor``.
+ :param reduce: when the tensor holds more than one value, the reduction used to collapse it into a single number; one of ``['max', 'min', 'sum', 'mean']``.
+ :return: a single value whose type is a basic python numeric type such as ``int`` or ``float``.
+ """
+
+ if tensor is None:
+ return None
+
+ def _translate(_data):
+ if _data.numel() == 1:
+ return _data.item()
+ if reduce is None:
+ return _data.tolist()
+ return _reduces[reduce](_data).item()
+
+ return apply_to_collection(
+ data=tensor,
+ dtype=torch.Tensor,
+ function=_translate
+ )
+
+ def set_model_mode(self, mode: str):
+ r"""
+ Put the model into ``train`` or ``eval`` mode, switching between training and inference (the latter disables dropout etc.).
+
+ :param mode: one of ``["train", "eval"]``
+ """
+ assert mode in {"train", "eval"}
+ getattr(self.model, mode)()
+
+ @rank_zero_call
+ def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs):
+ """
+ Save the current driver's model to ``filepath``.
+
+ :param filepath: location of the file to save to
+ :param only_state_dict: whether to save only the weights
+ :return:
+ """
+ model = self.unwrap_model()
+
+ if only_state_dict:
+ states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
+ torch.save(states, filepath)
+ else:
+ if self.model_device is not None:
+ if not self.is_distributed():
+ self.move_model_to_device(model, torch.device("cpu"))
+ torch.save(model, filepath)
+ if not self.is_distributed():
+ self.move_model_to_device(model, self.model_device)
+ else:
+ torch.save(model, filepath)
+
+ def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs):
+ """
+ Load a model from ``filepath`` and assign it to the current ``model``.
+
+ :param filepath: location of the file to load from
+ :param only_state_dict: whether the saved content is only the weights
+ """
+ model = self.unwrap_model()
+ # todo torch.load makes gpu 0 hold one (or more) extra copies of the model while loading, so multi-gpu checkpoint resumption may fail;
+ res = torch.load(filepath, map_location='cpu')
+ if isinstance(res, dict) and only_state_dict is False:
+ logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use "
+ f"`only_state_dict=True`")
+ elif not isinstance(res, dict) and only_state_dict is True:
+ logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use "
+ f"`only_state_dict=False`")
+ if not isinstance(res, dict):
+ res = res.state_dict()
+ _strict = kwargs.get("strict", True)
+ model.load_state_dict(res, _strict)
+
+ @rank_zero_call
+ def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+ r"""
+ Checkpoint-saving function for resumable training; it saves the state of the **optimizers**, the **sampler** and **fp16**, plus the **model** (when ``should_save_model`` is ``True``)
+
+ :param folder: folder that holds the resumable-training state; :meth:`save_checkpoint` should create files named ``FASTNLP_CHECKPOINT_FILENAME`` and
+ ``FASTNLP_MODEL_FILENAME`` (the latter only when ``should_save_model`` is ``True``) under this path. Model-related content goes into the ``FASTNLP_MODEL_FILENAME`` file,
+ while the given ``states`` together with any extra state produced here are stored in ``FASTNLP_CHECKPOINT_FILENAME``.
+ :param states: a dict passed in by :class:`~fastNLP.core.controllers.Trainer` that already contains the state of every other object needed for resumable training.
+ :param dataloader: the dataloader currently in use.
+ :param only_state_dict: whether to save only the model parameters; ignored when ``should_save_model`` is ``False``.
+ :param should_save_model: whether the model should be saved; when ``False``, the Driver is not responsible for saving the model.
+ """
+ # the ``dataloader`` argument is the trainer's dataloader attribute: the driver never mutates its own dataloaders; instead
+ # ``trainer.dataloader`` is reassigned to adapt the dataloader to the training or evaluation environment;
+
+ # 1. the sampler state;
+ num_consumed_batches = states.pop('num_consumed_batches')
+ states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches)
+
+ # 2. save the model state;
+ if should_save_model:
+ if not os.path.exists(folder):
+ os.mkdir(folder)
+ model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
+ self.save_model(model_path, only_state_dict=only_state_dict)
+
+ # 3. save the optimizers' state;
+ states["optimizers_state_dict"] = self.get_optimizer_state()
+ logger.debug("Save optimizer state dict.")
+
+ # 4. save the fp16 state
+ if not isinstance(self.grad_scaler, DummyGradScaler):
+ grad_scaler_state_dict = self.grad_scaler.state_dict()
+ states['grad_scaler_state_dict'] = grad_scaler_state_dict
+
+ torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
+
+ def get_sampler_state(self, dataloader, num_consumed_batches):
+ # we support resume training, i.e. restoring to an exact batch;
+ # a pytorch DataLoader always has a sampler; moreover, for resumable training we will have replaced the dataloader's
+ # sampler with a `ReproducibleSampler` in `set_`, or, in the single-card case, its batch_sampler with a `ReproducibleBatchSampler`;
+ dataloader_args = self.get_dataloader_args(dataloader)
+ if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+ sampler = dataloader_args.batch_sampler
+ elif dataloader_args.sampler:
+ sampler = dataloader_args.sampler
+ else:
+ raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
+
+ if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
+ sampler_states = sampler.state_dict()
+ if dataloader_args.batch_size is not None:
+ sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+ * num_consumed_batches
+ else:
+ logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's "
+ "`num_consumed_samples`, it may cause missing some samples when reload.")
+ else:
+ raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training '
+ 'state.')
+
+ return sampler_states
+
+ def load_sampler_state(self, dataloader, sampler_states):  # restore the sampler state and compute batch_idx_in_epoch
+ states = {}
+ dataloader_args = self.get_dataloader_args(dataloader)
+ if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+ sampler = dataloader_args.batch_sampler
+ elif isinstance(dataloader_args.sampler, ReproducibleSampler):
+ sampler = dataloader_args.sampler
+ elif isinstance(dataloader_args.sampler, TorchRandomSampler):
+ sampler = RandomSampler(dataloader_args.sampler.data_source)
+ logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+ elif self.is_distributed():
+ raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our"
+ "`ReproducibleSampler`.")
+ else:
+ sampler = ReproduceBatchSampler(
+ batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
+ batch_size=dataloader_args.batch_size,
+ drop_last=dataloader_args.drop_last
+ )
+ sampler.load_state_dict(sampler_states)
+ states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
+
+ # adjust trainer_state.batch_idx_in_epoch
+ # the sampler is a RandomSampler-like sampler, not a batch_sampler;
+ if not isinstance(sampler, ReproducibleBatchSampler):
+ if dataloader_args.drop_last:
+ batch_idx_in_epoch = len(
+ sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
+ else:
+ batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
+ (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
+ # the sampler is a batch_sampler;
+ else:
+ batch_idx_in_epoch = sampler.batch_idx_in_epoch
+
+ states["batch_idx_in_epoch"] = batch_idx_in_epoch
+ return states
+
+ def get_optimizer_state(self):  # collect each optimizer's state_dict with its tensors moved to cpu
+ optimizers_state_dict = {}
+ for i in range(len(self.optimizers)):
+ optimizer: torch.optim.Optimizer = self.optimizers[i]
+ optimizer_state = optimizer.state_dict()
+ optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
+ optimizers_state_dict[f"optimizer{i}"] = optimizer_state # no deepcopy here on purpose; testing shows it is unnecessary;
+ return optimizers_state_dict
+
+ def load_optimizer_state(self, states):  # restore every optimizer from the saved states
+ assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
+ f"checkpoint it is:{len(states)}"
+ for i in range(len(self.optimizers)):
+ optimizer: torch.optim.Optimizer = self.optimizers[i]
+ optimizer.load_state_dict(states[f"optimizer{i}"])
+ logger.debug("Load optimizer state dict.")
+
+ def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
+ r"""
+ Checkpoint-loading function for resumable training; it reads the saved data and restores the state of the **optimizers**, the **sampler**, **fp16** and the **model**
+ (when ``should_load_model`` is True), plus everything else saved by :meth:`save_checkpoint`, then returns a state dict to :class:`~fastNLP.core.controllers.Trainer`
+ (its content is the ``states`` that :meth:`save_checkpoint` received).
+
+ This function should be executed on every rank.
+
+ :param folder: reads the ``FASTNLP_CHECKPOINT_FILENAME`` file under this folder, plus ``FASTNLP_MODEL_FILENAME``
+ (when should_load_model is True).
+ :param dataloader: the current dataloader, to be configured sensibly from the saved dataloader state. When it is ``None``, the ``'dataloader'``
+ and ``'batch_idx_in_epoch'`` entries need not be returned.
+ :param only_state_dict: whether only the model's state_dict was saved; ignored when ``should_save_model`` is ``False``. ``True`` means the saved content is weights;
+ ``False`` means a whole model was saved, but its weights are still loaded into the current Driver's model rather than replacing the model itself.
+ :param should_load_model: whether the model should be loaded; when ``False``, the Driver is not responsible for loading it. If ``True`` but no
+ model state is found in the checkpoint, an error is raised.
+ :return: the ``states`` content that was passed to :meth:`save_checkpoint`. In addition, the returned dict contains:
+
+ * *dataloader* -- the given ``dataloader`` configured into the proper state from the loaded checkpoint. An error is raised when the current ``dataloader`` sample count
+ disagrees with the sample count read from the saved sampler.
+ * *batch_idx_in_epoch* -- an :class:`int` telling which batch of the current epoch we are at. Note this value cannot simply be read from the checkpoint, because the
+ ``batch_size`` may differ between the two runs; instead it must satisfy the following equation::
+
+ number of batches the returned dataloader will still produce + batch_idx_in_epoch = total batch count of the original uninterrupted run
+
+ Since ``the number of batches the returned dataloader will still produce`` is fixed once ``batch_size`` and ``drop_last`` are given, only ``batch_idx_in_epoch``
+ can be adjusted to make the equation hold. A simple rule of thumb:
+
+ * when drop_last is ``True``, it equals floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
+ * when drop_last is ``False``, it equals ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size).
+ """
+ states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
+
+ # 1. load the optimizers' state;
+ optimizers_state_dict = states.pop("optimizers_state_dict")
+ self.load_optimizer_state(optimizers_state_dict)
+
+ # 2. load the model state;
+ if should_load_model:
+ self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)
+
+ # 3. load the fp16 state
+ if "grad_scaler_state_dict" in states:
+ grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
+ if not isinstance(self.grad_scaler, DummyGradScaler):
+ self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+ logger.debug("Load grad_scaler state dict...")
+ elif not isinstance(self.grad_scaler, DummyGradScaler):
+ logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
+ f"the training process may be unstable.")
+
+ # 4. restore the sampler state;
+ sampler_states = states.pop('sampler_states')
+ states_ret = self.load_sampler_state(dataloader, sampler_states)
+ states.update(states_ret)
+
+ return states
+
+ def get_evaluate_context(self):
+ r"""
+ Return a no-gradient context used to evaluate the model.
+
+ :return: the context factory ``torch.no_grad``
+ """
+ return torch.no_grad
+
+ @staticmethod
+ def move_model_to_device(model: "torch.nn.Module", device: "torch.device"):
+ r"""
+ Move the model to the given device (no-op when ``device`` is None).
+ """
+ if device is not None:
+ model.to(device)
+
+ def move_data_to_device(self, batch):
+ """
+ Move one ``batch`` of data to the appropriate device.
+
+ :param batch: a collection holding :class:`torch.Tensor`, possibly nested inside **List**, **Dict**, etc.
+ :return: the ``batch`` after being moved to the target machine
+ """
+ return torch_move_data_to_device(batch, self.data_device, self.non_blocking)
+
+ @staticmethod
+ def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
+ """Deterministically seed a dataloader worker from the base seed, worker id and global rank.
+ """
+ """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed
+ with ``seed_everything(seed, workers=True)``.
+
+ See also the PyTorch documentation on
+ `randomness in DataLoaders `_.
+ """
+ # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
+ global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))
+ process_seed = torch.initial_seed()
+ # back out the base seed so we can use all the bits
+ base_seed = process_seed - worker_id
+ ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
+ # use 128 bits (4 x 32-bit words)
+ np.random.seed(ss.generate_state(4))
+ # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
+ torch_ss, stdlib_ss = ss.spawn(2)
+ torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
+ # use 128 bits expressed as an integer
+ stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
+ random.seed(stdlib_seed)
+
+ def set_deterministic_dataloader(self, dataloader: "DataLoader"):
+ """
+ Patch ``dataloader`` for deterministic training, so that once the random seed is fixed, every retraining reproduces identical results.
+ """
+ if dataloader.worker_init_fn is None:
+ dataloader.worker_init_fn = partial(self.worker_init_function,
+ rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)))
+
+ def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx: int):
+ r"""
+ For distributed ``sampler`` s, the random seed must be set before each ``epoch`` so that the ``shuffle`` order is the same on every process.
+
+ :param dataloader: the ``dataloader`` whose ``epoch`` is to be set
+ :param cur_epoch_idx: the index of the current ``epoch``
+ """
+ # keeps ddp training with shuffle=True correct: the sampler's shuffle seed must match across all processes;
+ if callable(getattr(dataloader.sampler, "set_epoch", None)):
+ dataloader.sampler.set_epoch(cur_epoch_idx)
+
+ @staticmethod
+ def get_dataloader_args(dataloader: "DataLoader"):
+ """
+ Extract ``dataset``, ``batch_sampler``, ``sampler``, ``batch_size``, ``shuffle``
+ and ``drop_last`` from the given ``dataloader``.
+ """
+ @dataclass
+ class Res:
+ dataset: Optional[Dataset] = None
+ batch_sampler: Optional[BatchSampler] = None
+ sampler: Optional[Sampler] = None
+ batch_size: Optional[int] = None
+ shuffle: Optional[bool] = None
+ drop_last: Optional[bool] = None
+
+ res = Res()
+
+ # a pytorch DataLoader always has a dataset attribute;
+ res.dataset = dataloader.dataset
+
+ # the dataloader uses a plain sampler;
+ if dataloader.batch_sampler is None:
+ res.sampler = dataloader.sampler
+ res.batch_size = 1
+ res.shuffle = True if isinstance(dataloader.sampler, RandomSampler) else False
+ res.drop_last = False
+ # the dataloader uses a batch_sampler;
+ else:
+ res.batch_sampler = dataloader.batch_sampler
+ if hasattr(dataloader.batch_sampler, "batch_size"):
+ res.batch_size = getattr(dataloader.batch_sampler, "batch_size")
+ # the user supplied a custom batch_sampler without a "batch_size" attribute;
+ else:
+ dataloader_iter = iter(dataloader)
+ pre_sample = next(dataloader_iter)
+ res.batch_size = pre_sample.shape[0]
+
+ if hasattr(dataloader.batch_sampler, "sampler"):
+ res.sampler = dataloader.batch_sampler.sampler
+ if hasattr(dataloader.batch_sampler.sampler, "shuffle"):
+ res.shuffle = dataloader.batch_sampler.sampler.shuffle
+ elif isinstance(dataloader.batch_sampler.sampler, TorchRandomSampler):
+ res.shuffle = True
+ else:
+ res.shuffle = False
+ # the ReproduceBatchSampler case: the wrapped inner batch_sampler carries the real sampler
+ elif hasattr(dataloader.batch_sampler, "batch_sampler"):
+ batch_sampler = dataloader.batch_sampler.batch_sampler
+ res.sampler = batch_sampler.sampler
+ if hasattr(batch_sampler.sampler, "shuffle"):
+ res.shuffle = batch_sampler.sampler.shuffle # fixed: was dataloader.batch_sampler.sampler.shuffle, which raises AttributeError in this branch (the outer batch_sampler has no `sampler`)
+ elif isinstance(batch_sampler.sampler, TorchRandomSampler):
+ res.shuffle = True
+ else:
+ res.shuffle = False
+ else:
+ # dataloader.batch_sampler has neither a "sampler" nor a "batch_sampler" attribute, i.e. a fully custom batch_sampler;
+ # in that case the DataLoader initializes a default sampler itself, which we attach to res here;
+ res.sampler = dataloader.sampler
+ res.shuffle = False
+
+ if hasattr(dataloader.batch_sampler, "drop_last"):
+ res.drop_last = getattr(dataloader.batch_sampler, "drop_last")
+ # the user supplied a custom batch_sampler without a "drop_last" attribute;
+ else:
+ res.drop_last = False
+
+ return res
diff --git a/fastNLP/core/drivers/torch_driver/torch_fsdp.py b/fastNLP/core/drivers/torch_driver/torch_fsdp.py
new file mode 100644
index 00000000..0b1948e8
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/torch_fsdp.py
@@ -0,0 +1,383 @@
+
+
+
+from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12, _NEED_IMPORT_TORCH
+
+if _TORCH_GREATER_EQUAL_1_12:
+ from torch.distributed.fsdp import FullyShardedDataParallel, StateDictType, FullStateDictConfig, OptimStateKeyType
+
+if _NEED_IMPORT_TORCH:
+ import torch
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel
+
+import os
+from typing import Optional, Union, List, Dict, Mapping
+from pathlib import Path
+
+from .ddp import TorchDDPDriver
+from fastNLP.core.drivers.torch_driver.utils import (
+ _DDPWrappingModel,
+)
+
+from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, \
+ FASTNLP_GLOBAL_RANK, rank_zero_call
+from fastNLP.core.drivers.torch_driver.utils import DummyGradScaler
+from fastNLP.core.log import logger
+from fastNLP.core.utils import check_user_specific_params
+from .utils import optimizer_state_to_device
+
+
+"""
+参考文档:
+1. https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/
+2. https://pytorch.org/docs/stable/fsdp.html?highlight=fsdp
+3. https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
+4. https://engineering.fb.com/2021/07/15/open-source/fsdp/
+"""
+
+class TorchFSDPDriver(TorchDDPDriver):
+ r"""
+ Driver for pytorch's own fully sharded data parallel; read
+ `the docs `_
+ to learn more:
+
+ .. note::
+
+ ``TorchFSDPDriver`` behaves mostly like ``TorchDDPDriver``; if you are not familiar with ``TorchDDPDriver``,
+ you can read :class:`~fastNLP.core.drivers.TorchDDPDriver` first;
+
+ .. warning::
+
+ ``TorchFSDPDriver`` does not support checkpoint-based resumable training yet, but saving and loading models is supported;
+
+ Note that when saving and loading models through the checkpointcallback, you can control the saving behaviour by passing
+ ``torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True/False, 'load_on_rank0': True/False}}`` when initializing the ``Trainer``:
+
+ 1. save/load_on_rank0 = True: model parameters from every rank are gathered onto rank0 when saving and loading; beware this may cause OOM;
+ 2. save/load_on_rank0 = False: each rank saves and loads only its own shard of the model parameters;
+
+ :param model: the ``model`` argument passed to ``Trainer``
+ :param parallel_device: the ``gpu`` devices used for distributed training
+ :param is_pull_by_torch_run: whether the current script was launched by ``python -m torch.distributed.launch``
+ :param fp16: whether to enable fp16 training
+ :param torch_kwargs:
+
+ * *fsdp_kwargs* --
+ * *set_grad_to_none* -- whether to set grads to ``None`` after each optimizer update during training
+ * *non_blocking* -- the non_blocking argument used for the :meth:`torch.Tensor.to` method
+ * *gradscaler_kwargs* -- arguments for :class:`torch.amp.cuda.GradScaler` when ``fp16=True``
+ :kwargs:
+ * *wo_auto_param_call* (``bool``) -- whether to disable our ``auto_param_call`` that automatically matches batch fields against forward-function parameters during training
+
+ .. note::
+
+ For a detailed description of this parameter, see :class:`~fastNLP.core.controllers.Trainer`; the ``auto_param_call`` function is documented at :func:`fastNLP.core.utils.auto_param_call`.
+
+ """
+
+ def __init__(
+ self,
+ model,
+ parallel_device: Optional[Union[List["torch.device"], "torch.device"]],
+ is_pull_by_torch_run: bool = False,
+ fp16: bool = False,
+ torch_kwargs: Dict = None,
+ **kwargs
+ ):
+
+ # a lot has been added over time, so mind where this super call sits;
+ super(TorchDDPDriver, self).__init__(model, fp16=fp16, torch_kwargs=torch_kwargs, **kwargs)
+
+ if isinstance(model, torch.nn.DataParallel):
+ raise ValueError(f"Parameter `model` can not be `DataParallel` in `TorchDDPDriver`, it should be "
+ f"`torch.nn.Module` or `torch.nn.parallel.DistributedDataParallel` type.")
+
+ # if the user initialized DDP outside by themselves, it must have been launched via python -m torch.distributed.launch;
+ self.is_pull_by_torch_run = is_pull_by_torch_run
+ self.parallel_device = parallel_device
+ if not is_pull_by_torch_run and parallel_device is None:
+ raise ValueError(
+ "Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused "
+ "when your value of parameter `device` is `None` in your `Trainer` instance.")
+
+ # note the logic in initialize_torch_driver: when is_pull_by_torch_run, parallel_device is set directly to the current process's gpu;
+ if is_pull_by_torch_run:
+ self.model_device = parallel_device
+ else:
+ # our model_device is always a torch.device, never a list;
+ self.model_device = parallel_device[self.local_rank]
+
+ # the case where the user initialized FSDP outside by themselves;
+ self.outside_ddp = False
+ if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
+ "fastnlp_torch_launch_not_ddp" not in os.environ:
+ # when the user initialized DDP outside, we require the model they pass in to be one already wrapped by DistributedDataParallel;
+ if not isinstance(model, FullyShardedDataParallel):
+ raise RuntimeError(
+ "It is not allowed to input a normal model instead of `FullyShardedDataParallel` when"
+ "you initialize the ddp process out of our control.")
+ if isinstance(model, DistributedDataParallel):
+ logger.warning("You are using `TorchFSDPDriver`, but you have initialized your model as "
+ "`DistributedDataParallel`, which will make the `FullyShardedDataParallel` not work "
+ "as expected. You could just delete `DistributedDataParallel` wrap operation.")
+
+ self.outside_ddp = True
+ # a model can only be wrapped by DistributedDataParallel after it has been moved to its device, so when the user initializes
+ # DDP outside, TorchDDPDriver simply sets model_device to None;
+ self.model_device = None
+
+ # when the user initializes DDP outside we set model_device to None; they can move the data to the right machine via `data_device`;
+ self._data_device = kwargs.get("data_device", None)
+ if isinstance(self._data_device, int):
+ if self._data_device < 0:
+ raise ValueError("Parameter `data_device` can not be smaller than 0.")
+ _could_use_device_num = torch.cuda.device_count()
+ if self._data_device >= _could_use_device_num:
+ raise ValueError("The gpu device that parameter `device` specifies is not existed.")
+ self._data_device = torch.device(f"cuda:{self._data_device}")
+ elif isinstance(self._data_device, str):
+ self._data_device = torch.device(self._data_device)
+ elif self._data_device is not None and not isinstance(self._data_device, torch.device):
+ raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
+
+ self._master_port = None
+ # world_size is the total number of gpus globally;
+ self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
+ self.global_rank = 0
+
+ self._fsdp_kwargs = self._torch_kwargs.get("fsdp_kwargs", {})
+ self._save_on_rank0 = self._fsdp_kwargs.get("save_on_rank0", False)
+ if "save_on_rank0" in self._fsdp_kwargs:
+ self._fsdp_kwargs.pop("save_on_rank0")
+ self._load_on_rank0 = self._fsdp_kwargs.get("load_on_rank0", False)
+ if "load_on_rank0" in self._fsdp_kwargs:
+ self._fsdp_kwargs.pop("load_on_rank0")
+
+ if self._save_on_rank0 != self._load_on_rank0:
+ logger.warning(f"Notice the behavior between ``save`` and ``load`` is not matched, you choose "
+ f"{'save on rank0' if self._save_on_rank0 else 'save on each rank'}, but "
+ f"{'load on rank0' if self._load_on_rank0 else 'load on each rank'}!")
+
+ check_user_specific_params(self._fsdp_kwargs, FullyShardedDataParallel.__init__, FullyShardedDataParallel.__name__)
+ if "cpu_offload" in self._fsdp_kwargs and kwargs.get("accumulation_steps", 1) != 1:
+ logger.warning("It is not supported ``accumulation_steps`` when using ``cpu_offload`` in "
+ "``FullyShardedDataParallel``.")
+
+ self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
+ assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
+ if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
+ os.makedirs(name=self.output_from_new_proc, exist_ok=True)
+ self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
+
+ self._has_setup = False # this flag exists because evaluator also runs setup, which is obviously neither needed nor wanted;
+ self._has_ddpwrapped = False # whether the passed-in model has been wrapped by _has_ddpwrapped;
+
+ def configure_ddp(self):
+ torch.cuda.set_device(self.model_device)
+ if not isinstance(self.model, FullyShardedDataParallel):
+ self.model = FullyShardedDataParallel(
+ # note self.model_device is of `torch.device` type, hence self.model_device.index;
+ _DDPWrappingModel(self.model), device_id=self.model_device.index,
+ **self._fsdp_kwargs
+ )
+
+ # the model must be wrapped by FullyShardedDataParallel before its parameters are handed to the optimizers, so the optimizers are re-instantiated here;
+ for i in range(len(self.optimizers)):
+ self.optimizers[i] = type(self.optimizers[i])(self.model.parameters(), **self.optimizers[i].defaults)
+
+ self._has_ddpwrapped = True
+
+ def unwrap_model(self):
+ """
+ Note this function must not be removed, because it is called at specific points, e.g. by ddp in get_model_call_fn;
+ using it to inspect the original model's structure is fine;
+ but using it to obtain the original model's actual parameters is not, because under FullyShardedDataParallel the model is split into several
+ shards, and the model on each gpu is only one part of the whole model.
+ """
+ _module = self.model.module.module
+ if isinstance(_module, _DDPWrappingModel):
+ return _module.model
+ else:
+ return _module
+
+ def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs):
+ """
+ Save the model to ``filepath``.
+
+ :param filepath: file path
+ :param only_state_dict: whether to save only the weights; must be ``True`` in ``TorchFSDPDriver`` .
+ :param kwargs:
+ :return:
+ """
+ filepath = Path(filepath)
+ prefix = filepath.parent
+ filename = filepath.name
+ _filename = filename.split('.')
+ filename, suffix = _filename[0], '.'.join(_filename[1:])
+ if only_state_dict:
+ if self._save_on_rank0:
+ full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+ with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
+ state_dict = self.model.state_dict()
+ rank_zero_call(torch.save)(state_dict, filepath)
+ else:
+ # insert a 'rank0/1' field to distinguish this from the gather-to-rank0 saving mode;
+ _filename = filename.split('_')
+ filename = _filename[0] + f"_rank{int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))}_" + _filename[1]
+ filepath = prefix.joinpath(filename + "." + suffix)
+ with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
+ state_dict = self.model.state_dict()
+ torch.save(state_dict, filepath)
+ else:
+ raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")
+
+ def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs):
+ """
+ Load weights from ``filepath`` and assign them to the current driver's model.
+
+ :param filepath: path of the saved weights or model
+ :param only_state_dict: whether the saved content is only weights; must be ``True`` in ``TorchFSDPDriver`` .
+ :param kwargs:
+ :return:
+ """
+ if only_state_dict is False:
+ raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")
+ filepath = Path(filepath)
+ prefix = filepath.parent
+ filename = filepath.name
+ _filename = filename.split('.')
+ filename, suffix = _filename[0], '.'.join(_filename[1:])
+
+ if not self._load_on_rank0:
+ _filename = filename.split('_')
+ filename = _filename[0] + f"_rank{int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))}_" + _filename[1]
+ filepath = prefix.joinpath(filename + "." + suffix)
+ states = torch.load(filepath)
+ else:
+ states = torch.load(filepath, map_location="cpu")
+
+ if isinstance(states, dict) and only_state_dict is False:
+ logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use "
+ f"`only_state_dict=True`")
+ elif not isinstance(states, dict) and only_state_dict is True:
+ logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use "
+ f"`only_state_dict=False`")
+ if not isinstance(states, Mapping):
+ states = states.state_dict()
+
+ if self._load_on_rank0:
+ with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
+ self.model.load_state_dict(states)
+ else:
+ with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
+ self.model.load_state_dict(states)
+
+ def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+ raise RuntimeError("``TorchFSDPDriver`` does not support ``save_checkpoint`` function for now, there is some "
+ "technical issues that needs to solve. You can implement your own breakpoint retraining "
+ "by rewriting this function. The important thing is how to save and load the optimizers' state dict, "
+ "you can see ``https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict``.")
+
    def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
        # Deliberately unsupported, mirroring ``save_checkpoint``: restoring FSDP
        # optimizer state dicts is not yet integrated with fastNLP; users must
        # override this method for breakpoint retraining.
        raise RuntimeError("``TorchFSDPDriver`` does not support ``load_checkpoint`` function for now, there is some "
                           "technical issues that needs to solve. You can implement your own breakpoint retraining "
                           "by rewriting this function. The important thing is how to save and load the optimizers' state dict, "
                           "you can see ``https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict``.")
+
    # todo The double-underscored functions below are not supported yet;
    # pytorch 1.12's fsdp interface for saving and loading the optimizer state dict is
    # too unwieldy to be reconciled with the fastNLP framework;
    def __get_optimizer_state(self):
        """
        Gather the full (un-sharded) state dict of every optimizer.

        NOTE(review): per the original author, ``full_optim_state_dict`` returns the
        consolidated state on rank 0 only and an empty dict on every other rank, so
        the actual saving must be guarded to run on rank 0 alone.
        """
        optimizers_state_dict = {}
        for i in range(len(self.optimizers)):
            # All other ranks receive an empty dict here; only rank 0 may do the real save;
            optimizer_state = FullyShardedDataParallel.full_optim_state_dict(self.model, self.optimizers[i])
            if self._save_on_rank0:
                with FullyShardedDataParallel.summon_full_params(self.model):
                    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
                        # Unwrap DDP-style wrappers to rekey the state onto raw parameter ids.
                        unwrapped_model = self.model.module.module
                        optimizer_state = FullyShardedDataParallel.rekey_optim_state_dict(
                            optimizer_state, OptimStateKeyType.PARAM_ID, unwrapped_model)
            if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
                # Park the gathered tensors on cpu so the checkpoint does not pin gpu memory.
                optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
            optimizers_state_dict[f"optimizer{i}"] = optimizer_state  # no deepcopy here; testing showed it is not needed;
        return optimizers_state_dict
+
    # Kept as a standalone method because with fsdp every process must execute this
    # function; it therefore must not be wrapped in rank_zero_call;
    def __save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
        """
        Write a full training checkpoint into ``folder``: sampler states, model,
        optimizers and (when fp16 is active) the grad scaler.

        Only ``only_state_dict=True`` is supported under ``TorchFSDPDriver``.
        """
        if not only_state_dict:
            raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")

        # 1. the sampler's state;
        num_consumed_batches = states.pop('num_consumed_batches')
        states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches)

        # 2. save the model's state;
        if should_save_model:
            if not os.path.exists(folder):
                os.mkdir(folder)
            model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
            self.save_model(model_path, only_state_dict=True)

        # 3. save the optimizers' state;
        states["optimizers_state_dict"] = self.get_optimizer_state()
        logger.debug("Save optimizer state dict.")

        # 4. save the fp16 state
        if not isinstance(self.grad_scaler, DummyGradScaler):
            grad_scaler_state_dict = self.grad_scaler.state_dict()
            states['grad_scaler_state_dict'] = grad_scaler_state_dict

        # make sure only rank 0 performs the actual write;
        rank_zero_call(torch.save)(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
+
    def __load_optimizer_state(self, states):
        """
        Restore every optimizer from the full optimizer state dicts in ``states``.

        Rekeys (when the checkpoint was consolidated on rank 0) and re-shards the full
        state dict onto this rank, then rebuilds each optimizer over the unwrapped
        model's parameters.
        """
        assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
                                                    f"checkpoint it is:{len(states)}"

        with FullyShardedDataParallel.summon_full_params(self.model):
            # Unwrap DDP-style wrappers to reach the user's raw model.
            unwrapped_model = self.model.module.module

            for i in range(len(self.optimizers)):
                optimizer_state = states[f'optimizer{i}']
                if self._load_on_rank0:
                    # Checkpoint keys are parameter ids; translate back to parameter names.
                    optimizer_state = FullyShardedDataParallel.rekey_optim_state_dict(optimizer_state, OptimStateKeyType.PARAM_NAME, unwrapped_model)
                optimizer_state = FullyShardedDataParallel.shard_full_optim_state_dict(optimizer_state, unwrapped_model)
                # Re-create the optimizer with the same class and defaults, then load the sharded state.
                optimizer: torch.optim.Optimizer = type(self.optimizers[i])(unwrapped_model.parameters(), **self.optimizers[i].defaults)
                optimizer.load_state_dict(optimizer_state)
                self.optimizers[i] = optimizer

        logger.debug("Load optimizer state dict.")
+
    def __load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
        """
        Restore a training checkpoint from ``folder``: optimizers, model, grad scaler
        and sampler states.

        Only ``only_state_dict=True`` is supported under ``TorchFSDPDriver``.

        :return: the remaining checkpoint entries merged with the restored sampler state.
        """
        if not only_state_dict:
            raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")

        states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))

        # 1. restore the optimizers' state;
        optimizers_state_dict = states.pop("optimizers_state_dict")
        self.load_optimizer_state(optimizers_state_dict)

        # 2. restore the model's state;
        if should_load_model:
            self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)

        # 3. restore the fp16 state
        if "grad_scaler_state_dict" in states:
            grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
            if not isinstance(self.grad_scaler, DummyGradScaler):
                self.grad_scaler.load_state_dict(grad_scaler_state_dict)
                logger.debug("Load grad_scaler state dict...")
        elif not isinstance(self.grad_scaler, DummyGradScaler):
            # Checkpoint has no scaler state but the current run uses fp16: warn, do not fail.
            logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
                                     f"the training process may be unstable.")

        # 4. restore the sampler's state;
        sampler_states = states.pop('sampler_states')
        states_ret = self.load_sampler_state(dataloader, sampler_states)
        states.update(states_ret)

        return states
+
diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py
new file mode 100644
index 00000000..8306d33c
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/utils.py
@@ -0,0 +1,415 @@
+import os
+
+from typing import Any, Dict, Optional, Union
+from enum import IntEnum
+import contextlib
+import random
+import numpy as np
+import inspect
+
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.utils import get_global_seed
+from fastNLP.envs import (
+ get_global_rank,
+ FASTNLP_BACKEND_LAUNCH,
+ FASTNLP_GLOBAL_SEED,
+)
+from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.utils import auto_param_call, apply_to_collection
+from fastNLP.core.log import logger
+
+if _NEED_IMPORT_TORCH:
+ import torch
+ # import torch.nn as nn
+ from torch.nn import Module
+ from torch.utils.data import DataLoader
+ from torch.utils.data import RandomSampler as TorchRandomSampler
+ from torch.utils.data import SequentialSampler as TorchSequentialSampler
+ from torch.utils.data import BatchSampler as TorchBatchSampler
+
+else:
+ from fastNLP.core.utils.dummy_class import DummyClass as Module
+
+
+__all__ = [
+ 'torch_seed_everything',
+ 'optimizer_state_to_device'
+]
+
def torch_seed_everything(seed: int = None, add_global_rank_to_seed: bool = True) -> int:
    r"""
    Seed the pseudo-random generators of **torch**, **numpy** and **python.random**.

    :param seed: integer seed for the global random state; when ``None`` a seed is
        generated automatically (timestamp based, or 42 when launched by a backend).
    :param add_global_rank_to_seed: whether different ranks should draw different
        random numbers in distributed training; if ``True`` the current
        ``global_rank`` is added to the seed.
    :return: the seed that was actually applied (including the rank offset).
    """
    seed_upper = np.iinfo(np.uint32).max
    seed_lower = np.iinfo(np.uint32).min

    if seed is None:
        seed = 42 if os.getenv(FASTNLP_BACKEND_LAUNCH) == "1" else get_global_seed()
        logger.info(f"'FASTNLP_GLOBAL_SEED' is set to {seed} automatically.")
    if not isinstance(seed, int):
        seed = int(seed)

    if seed < seed_lower or seed > seed_upper:
        logger.rank_zero_warning("Your seed value is too big or too small for numpy, we will choose a random seed for you.")
        seed %= seed_upper

    # Publish the shared seed before applying the per-rank offset.
    os.environ[FASTNLP_GLOBAL_SEED] = f"{seed}"
    if add_global_rank_to_seed:
        seed += get_global_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed
+
+
class ForwardState(IntEnum):
    """Enumerates the mode a model forward pass is executed in."""
    TRAIN = 0
    VALIDATE = 1
    TEST = 2
    PREDICT = 3
+
+
class _DDPWrappingModel(Module):
    """
    Wraps the user's model so that customized step functions (``train_step``,
    ``evaluate_step``, ...) still run through ``DistributedDataParallel.forward``
    during DDP training.

    DDP requires the call to go through the wrapped module's ``forward`` to run
    correctly; fastNLP, however, asks users to implement different handler functions
    per mode. A DDP-wrapped model cannot see those extra methods, and extracting
    ``model = model.module`` during training leads to errors (parameters diverge
    across gpus). Hence this wrapper. For a more detailed explanation see the ddp
    design of 'pytorch_lightning'.
    """

    def __init__(self, model: Module):
        super(_DDPWrappingModel, self).__init__()
        # The user's original (unwrapped) model.
        self.model = model

    def forward(self, batch, **kwargs) -> Dict:
        """
        Dispatch to the fastNLP step function carried in ``kwargs``.

        pytorch_lightning unwraps the model first at this point; that does not seem
        necessary for us — this note is kept in case the need arises later.
        """
        # The concrete step function and its signature function are smuggled in
        # through kwargs by the driver.
        fn = kwargs.pop("fastnlp_fn")
        signature_fn = kwargs.pop("fastnlp_signature_fn")
        wo_auto_param_call = kwargs.pop("wo_auto_param_call")

        if isinstance(batch, Dict) and not wo_auto_param_call:
            return auto_param_call(fn, batch, signature_fn=signature_fn)
        else:
            return fn(batch)
+
class _DeepSpeedWrappingModel(_DDPWrappingModel):
    """
    Subclass of ``_DDPWrappingModel``; the difference is that float tensors in the
    batch are converted to float16 before ``forward`` when fp16 is enabled.
    """

    def __init__(self, model: Module, fp16):
        super(_DeepSpeedWrappingModel, self).__init__(model)
        # Whether float tensors in the batch should be cast to half precision.
        self.fp16 = fp16

    def forward(self, batch, **kwargs):
        if self.fp16:
            batch = self._move_float_tensors_to_half(batch)

        return super().forward(batch, **kwargs)

    @staticmethod
    def batch_to(data):
        # Cast one tensor to float16.
        return data.half()

    def _move_float_tensors_to_half(self, batch: Any):
        # Apply the half-precision cast to every FloatTensor (cpu or cuda) nested
        # anywhere inside the batch collection.
        batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to)
        return batch
+
+
class DummyGradScaler:
    """
    A no-op stand-in for pytorch's ``GradScaler``.

    Passing this object around when amp is disabled avoids scattering large numbers
    of ``if fp16:`` checks throughout the training code.
    """

    def __init__(self, *args, **kwargs):
        # Accept and ignore whatever a real GradScaler constructor would take.
        pass

    def scale(self, outputs):
        # No loss scaling: hand the outputs straight back.
        return outputs

    def unscale_(self, optimizer):
        pass

    def step(self, optimizer, *args, **kwargs):
        # Without amp there is nothing to do beyond stepping the optimizer itself.
        optimizer.step(*args, **kwargs)

    def update(self, new_scale=None):
        pass

    def get_scale(self):
        # The effective scale factor is always 1.
        return 1.0

    def is_enabled(self):
        return False

    def state_dict(self):
        # Nothing to persist.
        return {}

    def load_state_dict(self, state_dict):
        pass
+
+
def _build_fp16_env(dummy=False):
    """
    Return the ``(autocast, GradScaler)`` pair used for mixed-precision training.

    :param dummy: when ``True``, return no-op substitutes (``contextlib.ExitStack``
        and ``DummyGradScaler``) so fp32 code can use the same call sites.
    :raises RuntimeError: when cuda is unavailable or torch predates amp (< 1.6).
    """
    if dummy:
        return contextlib.ExitStack, DummyGradScaler

    if not torch.cuda.is_available():
        raise RuntimeError("Pytorch is not installed in gpu version, please use device='cpu'.")
    if torch.cuda.get_device_capability(0)[0] < 7:
        # Pre-Volta devices lack tensor cores, so fp16 brings no speedup.
        logger.rank_zero_warning(
            "NOTE: your device does NOT support faster training with fp16, "
            "please switch to FP32 which is likely to be faster"
        )
    try:
        from torch.cuda.amp import autocast, GradScaler
    except ImportError:
        raise RuntimeError("torch version too low (less than 1.6)")
    return autocast, GradScaler
+
+
def replace_sampler(dataloader: "DataLoader", sampler):
    r"""
    Replace the sampler of ``dataloader`` (the logic builds a brand-new dataloader):

    Users may have subclassed DataLoader with a customized class; that is why we call
    `inspect.signature(dataloader)` instead of `inspect.signature(DataLoader)`, and why
    the new dataloader is created from the user's class rather than DataLoader itself.

    A customized dataloader must guarantee two things:

    1. its __init__ accepts **kwargs, so that we can inject the sampler into the
       underlying DataLoader construction;
    2. every parameter appearing in __init__ is stored as an instance attribute of the
       same name, e.g. self.one_arg_name = one_arg_name, because attribute access is
       the only way we can recover the actual argument values;

    """

    # collect the public instance attributes;
    instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}

    # 'multiprocessing_context' is a user-defined function;
    if getattr(dataloader, 'multiprocessing_context', None) is not None:
        instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context

    # grab the parameter signature of the dataloader's '__init__';
    init_params = dict(inspect.signature(dataloader.__init__).parameters)

    # handle users whose DataLoader subclass forwards arguments to the parent via **kwargs
    has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
    if has_variadic_kwargs and isinstance(dataloader, DataLoader):
        # guard against users having written super().__init__(**kwargs)
        for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
            if key not in init_params and key != 'self':
                init_params[key] = value

    # record every constructor argument whose current value differs from its default,
    # so it can be passed again on re-instantiation;
    non_default_params = {name for name, p in init_params.items() if
                          name in instance_attrs and p.default != instance_attrs[name]}
    # add `dataset` as it might have been replaced with `*args`
    non_default_params.add("dataset")

    reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
    if isinstance(dataloader, DataLoader):
        reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})

    batch_sampler = getattr(dataloader, "batch_sampler")
    if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
        raise RuntimeError("It should not be running here, please report a bug to us.")

    # required constructor parameters we could not recover a value for;
    required_args = {
        p.name
        for p in init_params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
        and p.name not in reconstruct_args
    }

    # these parameters were not found among the attributes, so re-initialization is impossible
    if required_args:
        required_args = sorted(required_args)
        dataloader_self_name = dataloader.__class__.__name__
        raise Exception(
            f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
            f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its "
            f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be "
            f"`{dataloader_self_name}`'s attribute."
        )

    # this error targets dataloaders that are customized subclasses whose __init__ lacks **kwargs;
    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = reconstruct_args.keys() - init_params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_self_name = dataloader.__class__.__name__
            raise Exception(
                f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found."
            )
        # without **kwargs, make sure only parameters the signature accepts are passed
        if not isinstance(dataloader, DataLoader):
            reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}

    return type(dataloader)(**reconstruct_args)
+
+
def replace_batch_sampler(dataloader, new_batch_sampler):
    r"""
    Re-instantiate ``dataloader`` with its ``batch_sampler`` replaced by
    ``new_batch_sampler``.
    """
    # Collect the public instance attributes as candidate constructor arguments,
    # dropping everything a fresh batch_sampler supersedes or that is derived state.
    excluded = ("batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind")
    reconstruct_kwargs = {
        name: getattr(dataloader, name)
        for name in dataloader.__dict__
        if not name.startswith("_") and name not in excluded
    }
    reconstruct_kwargs["batch_sampler"] = new_batch_sampler

    if not isinstance(dataloader, DataLoader):
        # Customized dataloader classes without **kwargs in __init__ only get the
        # parameters their signature actually declares.
        signature_params = dict(inspect.signature(dataloader.__init__).parameters)
        accepts_var_kwargs = any(p.kind is p.VAR_KEYWORD for p in signature_params.values())
        if not accepts_var_kwargs:
            reconstruct_kwargs = {k: v for k, v in reconstruct_kwargs.items() if k in signature_params}

    return type(dataloader)(**reconstruct_kwargs)
+
+
def optimizer_state_to_device(state, device):
    r"""
    Move every tensor inside an optimizer ``state_dict`` to the given device.

    :param state: ``optimzier.state_dict()``.
    :param device: the destination device.
    :return: a new state dict whose tensors live on ``device`` (tensors are cloned;
        non-tensor entries are kept as-is).
    """
    def _convert(value):
        # Recurse through nested dicts; clone tensors onto the target device.
        if isinstance(value, dict):
            return optimizer_state_to_device(value, device)
        if isinstance(value, torch.Tensor):
            return value.to(device).clone()
        return value

    return {key: _convert(value) for key, value in state.items()}
+
+
+def _check_dataloader_args_for_distributed(args, controller='Trainer'):
+ """
+ 检查 dataloader 的 sampler 情况,如果用户替换了自己定制的 sampler ,为了防止
+ 在分布式训练中出现错误会报错。
+ """
+ error_flag = (type(args.sampler) not in {TorchRandomSampler, TorchSequentialSampler})
+ if controller == 'Trainer':
+ mode = 'training'
+ substitution = 'fastNLP.RandomSampler'
+ error_flag = (type(args.batch_sampler) != TorchBatchSampler) or error_flag
+ else: # Evaluator
+ mode = 'evaluation'
+ substitution = 'fastNLP.UnrepeatedSequentialSampler'
+ if error_flag:
+ raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
+ f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
+ f"``{substitution}``. The customized sampler should set for distributed running "
+ f"before initializing ``{controller}`` , and then set the "
+ f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
+ f"\n Current batch_sampler: {type(args.batch_sampler)}"
+ f"\n Current sampler: {type(args.sampler)}")
+
+def _create_default_config(
+ zero_optimization: bool = True,
+ zero_allow_untested_optimizer: bool = True,
+ logging_batch_size_per_gpu: Union[str, int] = "auto",
+ partition_activations: bool = False,
+ cpu_checkpointing: bool = False,
+ contiguous_memory_optimization: bool = False,
+ synchronize_checkpoint_boundary: bool = False,
+ offload_optimizer: bool = False,
+ offload_parameters: bool = False,
+ offload_params_device: str = "cpu",
+ nvme_path: str = "/local_nvme",
+ params_buffer_count: int = 5,
+ params_buffer_size: int = 100_000_000,
+ max_in_cpu: int = 1_000_000_000,
+ offload_optimizer_device: str = "cpu",
+ optimizer_buffer_count: int = 4,
+ pin_memory: bool = False,
+ block_size: int = 1048576,
+ queue_depth: int = 8,
+ single_submit: bool = False,
+ overlap_events: bool = True,
+ thread_count: int = 1,
+ stage: int = 2,
+ contiguous_gradients: bool = True,
+ overlap_comm: bool = True,
+ allgather_partitions: bool = True,
+ reduce_scatter: bool = True,
+ allgather_bucket_size: int = 200_000_000,
+ reduce_bucket_size: int = 200_000_000,
+ sub_group_size: int = 1_000_000_000_000,
+) -> Dict:
+ cfg = {
+ "activation_checkpointing": {
+ "partition_activations": partition_activations,
+ "cpu_checkpointing": cpu_checkpointing,
+ "contiguous_memory_optimization": contiguous_memory_optimization,
+ "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
+ },
+ "aio": {
+ "block_size": block_size,
+ "queue_depth": queue_depth,
+ "single_submit": single_submit,
+ "overlap_events": overlap_events,
+ "thread_count": thread_count,
+ },
+ }
+ zero_kwargs = {
+ "stage": stage,
+ "contiguous_gradients": contiguous_gradients,
+ "overlap_comm": overlap_comm,
+ "allgather_partitions": allgather_partitions,
+ "reduce_scatter": reduce_scatter,
+ "allgather_bucket_size": allgather_bucket_size,
+ "reduce_bucket_size": reduce_bucket_size,
+ "sub_group_size": sub_group_size,
+ }
+ if zero_optimization:
+ zero_config = zero_kwargs
+
+ if offload_optimizer:
+ zero_config["offload_optimizer"] = {
+ "device": offload_optimizer_device,
+ "nvme_path": nvme_path,
+ "buffer_count": optimizer_buffer_count,
+ "pin_memory": pin_memory,
+ }
+ if offload_parameters:
+ zero_config["offload_param"] = {
+ "device": offload_params_device,
+ "nvme_path": nvme_path,
+ "buffer_count": params_buffer_count,
+ "buffer_size": params_buffer_size,
+ "max_in_cpu": max_in_cpu,
+ "pin_memory": pin_memory,
+ }
+ cfg = {
+ "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
+ "zero_optimization": zero_config,
+ **cfg,
+ }
+ if logging_batch_size_per_gpu != "auto":
+ cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
+ return cfg
\ No newline at end of file
diff --git a/fastNLP/core/drivers/utils.py b/fastNLP/core/drivers/utils.py
new file mode 100644
index 00000000..f4f4ed0a
--- /dev/null
+++ b/fastNLP/core/drivers/utils.py
@@ -0,0 +1,32 @@
+from typing import List
+import subprocess
+
+__all__ = []
+
def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None):
    r"""
    Spawn a new process running ``command`` via ``subprocess.Popen``.

    :param output_from_new_proc: one of ``["ignore", "all", "only_error"]``, meaning:
        * ``"ignore"`` -- discard all output of the spawned process;
        * ``"only_error"`` -- only the error stream is forwarded;
        * ``"all"`` -- every output of the child process is forwarded.
        * any other value is treated as a directory, in which two files named
          {rank}_std.log and {rank}_err.log are created (existing files are
          overwritten) to receive stdout and stderr respectively.
    :param command: the command to launch.
    :param env_copy: environment variables to inject into the child.
    :param rank: global_rank; required when redirecting to log files.
    :return: the process opened with ``subprocess.Popen``.
    """
    if output_from_new_proc == "all":
        return subprocess.Popen(command, env=env_copy)
    if output_from_new_proc == "only_error":
        return subprocess.Popen(command, env=env_copy, stdout=subprocess.DEVNULL)
    if output_from_new_proc == "ignore":
        return subprocess.Popen(command, env=env_copy, stdout=subprocess.DEVNULL,
                                stderr=subprocess.DEVNULL)

    # Anything else names a log directory; a rank is needed to name the files.
    assert rank is not None
    # Fix for a file-handle leak: Popen duplicates the descriptors into the child at
    # spawn time, so the parent's copies can (and should) be closed right away — the
    # child keeps writing to the log files unaffected.
    with open(output_from_new_proc + f'/{rank}_std.log', 'w') as std_f, \
            open(output_from_new_proc + f'/{rank}_err.log', 'w') as err_f:
        return subprocess.Popen(command, env=env_copy, stdout=std_f, stderr=err_f)
diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py
deleted file mode 100644
index 9834a653..00000000
--- a/fastNLP/core/field.py
+++ /dev/null
@@ -1,696 +0,0 @@
-r"""
-.. todo::
- doc
-"""
-
-__all__ = [
- "Padder",
- "AutoPadder",
- "EngChar2DPadder",
-]
-
-from abc import abstractmethod
-from collections import Counter
-from copy import deepcopy
-from numbers import Number
-from typing import Any
-
-import numpy as np
-import torch
-
-from ._logger import logger
-from .utils import _is_iterable
-
-
-class SetInputOrTargetException(Exception):
- def __init__(self, msg, index=None, field_name=None):
- super().__init__(msg)
- self.msg = msg
- self.index = index # 标示在哪个数据遭遇到问题了
- self.field_name = field_name # 标示当前field的名称
-
-
-class AppendToTargetOrInputException(Exception):
- def __init__(self, msg, index=None, field_name=None):
- super().__init__(msg)
- self.msg = msg
- self.index = index # 标示在哪个数据遭遇到问题了
- self.field_name = field_name # 标示当前field的名称
-
-
-def _get_ele_type_and_dim(cell: Any, dim=0):
- r"""
- 识别cell的类别与dimension的数量
-
- numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
- :param cell:
- :param dim:
- :return:
- """
- if isinstance(cell, (str, Number, np.bool_)):
- if hasattr(cell, 'dtype'):
- return cell.dtype.type, dim
- return type(cell), dim
- elif isinstance(cell, list):
- dim += 1
- res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
- types = set([i for i, j in res])
- dims = set([j for i, j in res])
- if len(types) > 1:
- raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
- elif len(types) == 0:
- raise SetInputOrTargetException("Empty value encountered.")
- if len(dims) > 1:
- raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
- return types.pop(), dims.pop()
- elif isinstance(cell, torch.Tensor):
- return cell.dtype, cell.dim() + dim # 如果是torch.mean的结果是0
- elif isinstance(cell, np.ndarray):
- if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了
- return cell.dtype.type, cell.ndim + dim # dtype.type返回的会是np.int32, np.float等
- # 否则需要继续往下iterate
- dim += 1
- res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
- types = set([i for i, j in res])
- dims = set([j for i, j in res])
- if len(types) > 1:
- raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
- elif len(types) == 0:
- raise SetInputOrTargetException("Empty value encountered.")
- if len(dims) > 1:
- raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
- return types.pop(), dims.pop()
- else: # 包含tuple, set, dict以及其它的类型
- raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")
-
-
-class Padder:
- r"""
- 所有padder都需要继承这个类,并覆盖__call__方法。
- 用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。
-
- .. py:function:: __call__(self, contents, field_name, field_ele_dtype):
-
- """
-
- def __init__(self, pad_val=0, **kwargs):
- r"""
-
- :param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前
- deepcopy一份。
- :param str, field_name: field的名称。
- :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。
- :return: np.array([padded_element])
- """
- self.pad_val = pad_val
-
- def set_pad_val(self, pad_val):
- self.pad_val = pad_val
-
- def get_pad_val(self):
- return self.pad_val
-
- @abstractmethod
- def __call__(self, contents, field_name, field_ele_dtype, dim: int):
- r"""
- 传入的是List内容。假设有以下的DataSet。
-
- :param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前
- deepcopy一份。
- :param str, field_name: field的名称。
- :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,
- 该这个值为None。
- :param dim: 这个field的维度。当ignore_type为True时,该值为None
- :return: np.array([padded_element])
-
- Example::
-
- from fastNLP import DataSet
- from fastNLP import Instance
- dataset = DataSet()
- dataset.append(Instance(sent='this is a demo', length=4,
- chars=[['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']]))
- dataset.append(Instance(sent='another one', length=2,
- chars=[['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']]))
- 如果调用
- batch = dataset.get([0,1], pad=True)
- sent这个field的padder的__call__会接收到的内容会是
- [
- 'this is a demo',
- 'another one'
- ]
-
- length这个field的padder的__call__会接收到的内容会是
- [4, 2]
-
- chars这个field的padder的__call__会接收到的内容会是
- [
- [['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']],
- [['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']]
- ]
-
- 即把每个instance中某个field的内容合成一个List传入
-
- """
- raise NotImplementedError
-
-
-class AutoPadder(Padder):
- r"""
- 根据contents的数据自动判定是否需要做padding。
-
- 1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类
- 型为str, [[1,2], ...]的元素类型为int)的数据不为数值类型则不会进行pad
-
- 2 如果元素类型为数值类型,比如np.int64, np.float64, int, float, torch.int64等
-
- 2.1 如果该field的内容为数值类型(包括int, float等),比如为seq_len, 则不进行padding
-
- 2.2 如果该field的内容等价于一维list, 那么会将Batch中的List pad为一样长。
-
- 2.3 如果该field的内容等价于二维list,那么会按照英语character padding的方式进行padding。如果是character padding建议使用
- :class: fastNLP.EngChar2DPadder.
-
- 2.4 如果该field的内容等价于三维list,则如果每个instance在每个维度上相等,会组成一个batch的tensor返回,这种情况应该是为图片
- 的情况。
-
- 3 其它情况不进行处理,返回一个np.array类型。
- """
-
- def __init__(self, pad_val=0):
- super().__init__(pad_val=pad_val)
-
- def __call__(self, contents, field_name, field_ele_dtype, dim):
- if field_ele_dtype:
- if dim > 3:
- return np.array(contents)
- if isinstance(field_ele_dtype, type) and \
- (issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)):
- if dim == 0:
- array = np.array(contents, dtype=field_ele_dtype)
- elif dim == 1:
- max_len = max(map(len, contents))
- array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
- for i, content_i in enumerate(contents):
- array[i, :len(content_i)] = content_i
- elif dim == 2:
- max_len = max(map(len, contents))
- max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
- content_i in contents])
- array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype)
- for i, content_i in enumerate(contents):
- for j, content_ii in enumerate(content_i):
- array[i, j, :len(content_ii)] = content_ii
- else:
- shape = np.shape(contents)
- if len(shape) == 4: # 说明各dimension是相同的大小
- array = np.array(contents, dtype=field_ele_dtype)
- else:
- raise RuntimeError(
- f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
- return array
- elif str(field_ele_dtype).startswith('torch'):
- if dim == 0:
- tensor = torch.tensor(contents).to(field_ele_dtype)
- elif dim == 1:
- max_len = max(map(len, contents))
- tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype)
- for i, content_i in enumerate(contents):
- tensor[i, :len(content_i)] = content_i.clone().detach()
- elif dim == 2:
- max_len = max(map(len, contents))
- max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
- content_i in contents])
- tensor = torch.full((len(contents), max_len, max_word_len), fill_value=self.pad_val,
- dtype=field_ele_dtype)
- for i, content_i in enumerate(contents):
- for j, content_ii in enumerate(content_i):
- tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
- else:
- shapes = set([np.shape(content_i) for content_i in contents])
- if len(shapes) > 1:
- raise RuntimeError(
- f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
- shape = shapes.pop()
- if len(shape) == 3:
- tensor = torch.full([len(contents)] + list(shape), fill_value=self.pad_val,
- dtype=field_ele_dtype)
- for i, content_i in enumerate(contents):
- tensor[i] = content_i.clone().detach().to(field_ele_dtype)
- else:
- raise RuntimeError(
- f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
- return tensor
- else:
- return np.array(contents) # 不进行任何操作
- else:
- return np.array(contents)
-
-
-class EngChar2DPadder(Padder):
- r"""
- 用于为英语执行character级别的2D padding操作。对应的field内容应该类似[['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']],
- 但这个Padder只能处理index为int的情况。
-
- padded过后的batch内容,形状为(batch_size, max_sentence_length, max_word_length). max_sentence_length为这个batch中最大句
- 子长度;max_word_length为这个batch中最长的word的长度::
-
- from fastNLP import DataSet
- from fastNLP import EngChar2DPadder
- from fastNLP import Vocabulary
- dataset = DataSet({'sent': ['This is the first demo', 'This is the second demo']})
- dataset.apply(lambda ins:[list(word) for word in ins['sent'].split()], new_field_name='chars')
- vocab = Vocabulary()
- vocab.from_dataset(dataset, field_name='chars')
- vocab.index_dataset(dataset, field_name='chars')
- dataset.set_input('chars')
- padder = EngChar2DPadder()
- dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder
-
- """
-
- def __init__(self, pad_val=0, pad_length=0):
- r"""
- :param pad_val: int, pad的位置使用该index
- :param pad_length: int, 如果为0则取一个batch中最大的单词长度作为padding长度。如果为大于0的数,则将所有单词的长度
- 都pad或截取到该长度.
- """
- super().__init__(pad_val=pad_val)
-
- self.pad_length = pad_length
-
- def __call__(self, contents, field_name, field_ele_dtype, dim):
- r"""
- 期望输入类似于
- [
- [[0, 2], [2, 3, 4], ..],
- [[9, 8, 2, 4], [1, 2,], ...],
- ....
- ]
-
- :param contents:
- :param field_name:
- :param field_ele_dtype
- :return:
- """
- if field_ele_dtype not in (np.int64, np.float64, int, float):
- raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format(
- field_name, field_ele_dtype
- ))
- assert dim == 2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions."
- if self.pad_length < 1:
- max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents])
- else:
- max_char_length = self.pad_length
- max_sent_length = max(len(word_lst) for word_lst in contents)
- batch_size = len(contents)
- dtype = type(contents[0][0][0])
-
- padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val,
- dtype=dtype)
- for b_idx, word_lst in enumerate(contents):
- for c_idx, char_lst in enumerate(word_lst):
- chars = char_lst[:max_char_length]
- padded_array[b_idx, c_idx, :len(chars)] = chars
-
- return padded_array
-
-
-class FieldArray:
- def __init__(self, name, content, is_target=False, is_input=False, padder=AutoPadder(), ignore_type=False,
- use_1st_ins_infer_dim_type=True):
- if len(content) == 0:
- raise RuntimeError("Empty fieldarray is not allowed.")
- _content = content
- try:
- _content = list(_content)
- except BaseException as e:
- logger.error(f"Cannot convert content(of type:{type(content)}) into list.")
- raise e
- self.name = name
- self.content = _content
- self._ignore_type = ignore_type
- # 根据input的情况设置input,target等
- self._cell_ndim = None # 多少维度, 如果value是1, dim为0; 如果value是[1, 2], dim=2
- self.dtype = None # 最内层的element都是什么类型的
- self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
- self._is_input = False
- self._is_target = False
-
- if is_input:
- self.is_input = is_input
- if is_target:
- self.is_target = is_target
-
- self.set_padder(padder)
-
- @property
- def ignore_type(self):
- return self._ignore_type
-
- @ignore_type.setter
- def ignore_type(self, value):
- if value:
- self._cell_ndim = None
- self.dtype = None
- self._ignore_type = value
-
- @property
- def is_input(self):
- return self._is_input
-
- @is_input.setter
- def is_input(self, value):
- r"""
- 当 field_array.is_input = True / False 时被调用
- """
- # 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False)
- if value is True and \
- self._is_target is False and \
- self._ignore_type is False:
- self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type)
- if value is False and self._is_target is False:
- self.dtype = None
- self._cell_ndim = None
- self._is_input = value
-
- @property
- def is_target(self):
- return self._is_target
-
- @is_target.setter
- def is_target(self, value):
- r"""
- 当 field_array.is_target = True / False 时被调用
- """
- if value is True and \
- self._is_input is False and \
- self._ignore_type is False:
- self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type)
- if value is False and self._is_input is False:
- self.dtype = None
- self._cell_ndim = None
- self._is_target = value
-
- def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True):
- r"""
- 检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有
- 通过将直接报错.
-
- :param bool only_check_1st_ins_dim_type: 是否只检查第一个元素的type和dim
- :return:
- """
- cell_0 = self.content[0]
- index = 0
- try:
- type_0, dim_0 = _get_ele_type_and_dim(cell_0)
- if not only_check_1st_ins_dim_type:
- for cell in self.content[1:]:
- index += 1
- type_i, dim_i = _get_ele_type_and_dim(cell)
- if type_i != type_0:
- raise SetInputOrTargetException(
- "Type:{} in index {} is different from the first element with type:{}."
- ".".format(type_i, index, type_0))
- if dim_0 != dim_i:
- raise SetInputOrTargetException(
- "Dimension:{} in index {} is different from the first element with "
- "dimension:{}.".format(dim_i, index, dim_0))
- self._cell_ndim = dim_0
- self.dtype = type_0
- except SetInputOrTargetException as e:
- e.index = index
- raise e
-
- def append(self, val: Any):
- r"""
- :param val: 把该val append到fieldarray。
- :return:
- """
- if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type:
- type_, dim_ = _get_ele_type_and_dim(val)
- if self.dtype != type_:
- raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with "
- f"previous values(type:{self.dtype}).")
- if self._cell_ndim != dim_:
- raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with "
- f"previous values(dim:{self._cell_ndim}).")
- self.content.append(val)
- else:
- self.content.append(val)
-
- def pop(self, index):
- r"""
- 删除该field中index处的元素
- :param int index: 从0开始的数据下标。
- :return:
- """
- self.content.pop(index)
-
- def __getitem__(self, indices):
- return self.get(indices, pad=False)
-
- def __setitem__(self, idx, val):
- assert isinstance(idx, int)
- if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型
- type_, dim_ = _get_ele_type_and_dim(val)
- if self.dtype != type_:
- raise RuntimeError(f"Value(type:{type_}) are of different types with "
- f"other values(type:{self.dtype}).")
- if self._cell_ndim != dim_:
- raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with "
- f"previous values(dim:{self._cell_ndim}).")
- self.content[idx] = val
-
- def get(self, indices, pad=True):
- r"""
- 根据给定的indices返回内容。
-
- :param int,List[int] indices: 获取indices对应的内容。
- :param bool pad: 是否对返回的结果进行padding。仅对: (1) indices为List[int]; (2)padder不为None; (3)field设置了input
- 或target,有效
- :return: 根据给定的indices返回的内容,可能是单个值或ndarray
- """
- if isinstance(indices, int):
- return self.content[indices]
-
- contents = [self.content[i] for i in indices]
- if self.padder is None or pad is False:
- return np.array(contents)
- elif self.is_input or self.is_target:
- return self.pad(contents)
- else:
- return np.array(contents)
-
- def pad(self, contents):
- r"""
- 传入list的contents,将contents使用padder进行padding,contents必须为从本FieldArray中取出的。
-
- :param list contents:
- :return:
- """
- return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
-
- def set_padder(self, padder):
- r"""
- 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。
-
- :param padder: :class:`~fastNLP.Padder` 类型,设置为None即删除padder。
- """
- if padder is not None:
- assert isinstance(padder, Padder), "padder must be of type `fastNLP.core.Padder`."
- self.padder = deepcopy(padder)
- else:
- self.padder = None
-
- def set_pad_val(self, pad_val):
- r"""
- 修改padder的pad_val.
-
- :param int pad_val: 该field的pad值设置为该值。
- """
- if self.padder is not None:
- self.padder.set_pad_val(pad_val)
- return self
-
- def __len__(self):
- r"""
- Returns the size of FieldArray.
-
- :return int length:
- """
- return len(self.content)
-
- def to(self, other):
- r"""
- 将other的属性复制给本FieldArray(other必须为FieldArray类型).
- 属性包括 is_input, is_target, padder, ignore_type
-
- :param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性
- :return: :class:`~fastNLP.FieldArray`
- """
- assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other))
-
- self.ignore_type = other.ignore_type
- self.is_input = other.is_input
- self.is_target = other.is_target
- self.padder = other.padder
-
- return self
-
- def split(self, sep: str = None, inplace: bool = True):
- r"""
- 依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值
-
- :param sep: 分割符,如果为None则直接调用str.split()。
- :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
- :return: List[List[str]] or self
- """
- new_contents = []
- for index, cell in enumerate(self.content):
- try:
- new_contents.append(cell.split(sep))
- except Exception as e:
- logger.error(f"Exception happens when process value in index {index}.")
- raise e
- return self._after_process(new_contents, inplace=inplace)
-
- def int(self, inplace: bool = True):
- r"""
- 将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
- (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
-
- :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
- :return: List[int], List[List[int]], self
- """
- new_contents = []
- for index, cell in enumerate(self.content):
- try:
- if isinstance(cell, list):
- new_contents.append([int(value) for value in cell])
- else:
- new_contents.append(int(cell))
- except Exception as e:
- logger.error(f"Exception happens when process value in index {index}.")
- raise e
- return self._after_process(new_contents, inplace=inplace)
-
- def float(self, inplace=True):
- r"""
- 将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
- (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
-
- :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
- :return:
- """
- new_contents = []
- for index, cell in enumerate(self.content):
- try:
- if isinstance(cell, list):
- new_contents.append([float(value) for value in cell])
- else:
- new_contents.append(float(cell))
- except Exception as e:
- logger.error(f"Exception happens when process value in index {index}.")
- raise e
- return self._after_process(new_contents, inplace=inplace)
-
- def bool(self, inplace=True):
- r"""
- 将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
- (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
-
- :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
- :return:
- """
- new_contents = []
- for index, cell in enumerate(self.content):
- try:
- if isinstance(cell, list):
- new_contents.append([bool(value) for value in cell])
- else:
- new_contents.append(bool(cell))
- except Exception as e:
- logger.error(f"Exception happens when process value in index {index}.")
- raise e
-
- return self._after_process(new_contents, inplace=inplace)
-
- def lower(self, inplace=True):
- r"""
- 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
- (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
-
- :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
- :return: List[int], List[List[int]], self
- """
- new_contents = []
- for index, cell in enumerate(self.content):
- try:
- if isinstance(cell, list):
- new_contents.append([value.lower() for value in cell])
- else:
- new_contents.append(cell.lower())
- except Exception as e:
- logger.error(f"Exception happens when process value in index {index}.")
- raise e
- return self._after_process(new_contents, inplace=inplace)
-
- def upper(self, inplace=True):
- r"""
- 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
- (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
-
- :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
- :return: List[int], List[List[int]], self
- """
- new_contents = []
- for index, cell in enumerate(self.content):
- try:
- if isinstance(cell, list):
- new_contents.append([value.upper() for value in cell])
- else:
- new_contents.append(cell.upper())
- except Exception as e:
- logger.error(f"Exception happens when process value in index {index}.")
- raise e
- return self._after_process(new_contents, inplace=inplace)
-
- def value_count(self):
- r"""
- 返回该field下不同value的数量。多用于统计label数量
-
- :return: Counter, key是label,value是出现次数
- """
- count = Counter()
-
- def cum(cell):
- if _is_iterable(cell) and not isinstance(cell, str):
- for cell_ in cell:
- cum(cell_)
- else:
- count[cell] += 1
-
- for cell in self.content:
- cum(cell)
- return count
-
- def _after_process(self, new_contents, inplace):
- r"""
- 当调用处理函数之后,决定是否要替换field。
-
- :param new_contents:
- :param inplace:
- :return: self或者生成的content
- """
- if inplace:
- self.content = new_contents
- try:
- self.is_input = self.is_input
- self.is_target = self.is_input
- except SetInputOrTargetException as e:
- logger.error("The newly generated field cannot be set as input or target.")
- raise e
- return self
- else:
- return new_contents
diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py
deleted file mode 100644
index 83e3903e..00000000
--- a/fastNLP/core/instance.py
+++ /dev/null
@@ -1,61 +0,0 @@
-r"""
-instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。
-便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 中的表格
-
-"""
-
-__all__ = [
- "Instance"
-]
-
-from .utils import pretty_table_printer
-
-
-class Instance(object):
- r"""
- Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。
- Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示::
-
- >>>from fastNLP import Instance
- >>>ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
- >>>ins["field_1"]
- [1, 1, 1]
- >>>ins.add_field("field_3", [3, 3, 3])
- >>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))})
- """
-
- def __init__(self, **fields):
-
- self.fields = fields
-
- def add_field(self, field_name, field):
- r"""
- 向Instance中增加一个field
-
- :param str field_name: 新增field的名称
- :param Any field: 新增field的内容
- """
- self.fields[field_name] = field
-
- def items(self):
- r"""
- 返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value
-
- :return: 一个迭代器
- """
- return self.fields.items()
-
- def __contains__(self, item):
- return item in self.fields
-
- def __getitem__(self, name):
- if name in self.fields:
- return self.fields[name]
- else:
- raise KeyError("{} not found".format(name))
-
- def __setitem__(self, name, field):
- return self.add_field(name, field)
-
- def __repr__(self):
- return str(pretty_table_printer(self))
diff --git a/fastNLP/core/log/__init__.py b/fastNLP/core/log/__init__.py
new file mode 100644
index 00000000..d1d95f20
--- /dev/null
+++ b/fastNLP/core/log/__init__.py
@@ -0,0 +1,8 @@
+__all__ = [
+ 'logger',
+ "print"
+]
+
+from .logger import logger
+from .print import print
+
diff --git a/fastNLP/core/log/handler.py b/fastNLP/core/log/handler.py
new file mode 100644
index 00000000..40931c26
--- /dev/null
+++ b/fastNLP/core/log/handler.py
@@ -0,0 +1,90 @@
+import logging
+import sys
+from logging import getLevelName
+
+try:
+ from tqdm.auto import tqdm
+except ImportError:
+ tqdm = None
+
+__all__ = []
+
+if tqdm is not None:
+ class TqdmLoggingHandler(logging.Handler):
+ def __init__(self, level=logging.INFO):
+ super().__init__(level)
+
+ def emit(self, record):
+ try:
+ msg = self.format(record)
+ tqdm.write(msg)
+ self.flush()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except:
+ self.handleError(record)
+else:
+ class TqdmLoggingHandler(logging.StreamHandler):
+ def __init__(self, level=logging.INFO):
+ super().__init__(sys.stdout)
+ self.setLevel(level)
+
+
+class StdoutStreamHandler(logging.StreamHandler):
+ """
+ 重载 StreamHandler 使得替换 sys.stdout 的时候能够生效。
+
+ """
+ def __init__(self):
+ super(StdoutStreamHandler, self).__init__()
+
+ def flush(self):
+ """
+ Flushes the stream.
+ """
+ self.acquire()
+ try:
+ sys.stdout.flush()
+ finally:
+ self.release()
+
+ def emit(self, record):
+ """
+ Emit a record.
+
+ If a formatter is specified, it is used to format the record.
+ The record is then written to the stream with a trailing newline. If
+ exception information is present, it is formatted using
+ traceback.print_exception and appended to the stream. If the stream
+ has an 'encoding' attribute, it is used to determine how to do the
+ output to the stream.
+ """
+ try:
+ msg = self.format(record)
+ stream = sys.stdout
+ # issue 35046: merged two stream.writes into one.
+ stream.write(msg + self.terminator)
+ self.flush()
+ except RecursionError: # See issue 36272
+ raise
+ except Exception:
+ self.handleError(record)
+
+ def setStream(self, stream):
+ """
+ Sets the StreamHandler's stream to the specified value,
+ if it is different.
+
+ Returns the old stream, if the stream was changed, or None
+ if it wasn't.
+ """
+ raise RuntimeError("Cannot set the stream of FStreamHandler.")
+
+ def __repr__(self):
+ level = getLevelName(self.level)
+ name = getattr(sys.stdout, 'name', '')
+ # bpo-36015: name can be an int
+ name = str(name)
+ if name:
+ name += ' '
+ return '<%s %s(%s)>' % (self.__class__.__name__, name, level)
diff --git a/fastNLP/core/log/highlighter.py b/fastNLP/core/log/highlighter.py
new file mode 100644
index 00000000..2935366d
--- /dev/null
+++ b/fastNLP/core/log/highlighter.py
@@ -0,0 +1,9 @@
+from rich.highlighter import Highlighter
+
+__all__ = []
+class ColorHighlighter(Highlighter):
+ def __init__(self, color='black'):
+ self.color = color
+
+ def highlight(self, text):
+ text.stylize(self.color)
\ No newline at end of file
diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py
new file mode 100644
index 00000000..8761c2fe
--- /dev/null
+++ b/fastNLP/core/log/logger.py
@@ -0,0 +1,386 @@
+r"""
+:class:`Logger` 是 **fastNLP** 中记录日志的模块,**logger** 封装了 logging 模块的 Logger,
+具体使用方式与直接使用 :class:`logging.Logger` 相同,同时也新增一些简单好用的API
+
+使用方式::
+
+ from fastNLP import logger
+
+ # logger 可以和 logging.Logger 一样使用
+ logger.info('your msg')
+ logger.error('your msg')
+
+ # logger 新增的API
+ # 将日志输出到文件,以及输出的日志等级
+ logger.add_file('/path/to/log', level='INFO')
+ # 定义在命令行中的显示格式和日志等级
+ logger.set_stdout('tqdm', level='WARN')
+ # 仅警告一次
+ logger.warning_once('your msg')
+ # 分布式训练下,仅在 rank 0 输出警告
+ logger.rank_zero_warning('your msg')
+
+"""
+
+
+import logging
+import logging.config
+from logging import DEBUG, ERROR, INFO, WARNING, CRITICAL, raiseExceptions
+import os
+import sys
+import warnings
+from pathlib import Path
+from typing import Optional, Union
+from rich.logging import RichHandler
+
+__all__ = [
+ 'logger'
+]
+
+from fastNLP.core.log.handler import StdoutStreamHandler, TqdmLoggingHandler
+from fastNLP.envs.env import FASTNLP_LOG_LEVEL, FASTNLP_GLOBAL_RANK, FASTNLP_LAUNCH_TIME, FASTNLP_BACKEND_LAUNCH
+from fastNLP.envs.distributed import is_cur_env_distributed
+
+
+ROOT_NAME = 'fastNLP'
+
+
+class LoggerSingleton(type):
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(LoggerSingleton, cls).__call__(*args, **kwargs)
+ return cls._instances[cls]
+
+
+class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
+ def __init__(self, name):
+ super().__init__(name)
+ self._warning_msgs = set()
+
+ def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False,
+ mode: str = "w"):
+ """
+ 将日志输出到 path 中。
+
+ :param path: 若 path 为文件路径(通过 path 是否包含后缀判定 path 是否表示文件名,例如 output.log 会被认为是文件,而
+ output 则认为是文件夹)则直接写入到给定文件中;如果判定为文件夹,则是在该文件夹下以 时间戳 创建一个日志文件。
+        :param level: 可选 ['INFO', 'WARNING', 'DEBUG', 'ERROR', 'AUTO'], 其中AUTO表示根据环境变量 "FASTNLP_LOG_LEVEL" 进行
+ 设置。
+ :param remove_other_handlers: 是否移除其它 handler ,如果移除,则terminal中将不会有 log 输出。
+ :param mode: 可选为['w', 'a'],如果传入的 path 是存在的文件,'w' 会覆盖原有内容 'a' 则会在文件结尾处继续添加。
+ :return:
+ """
+ r"""添加日志输出文件和输出级别"""
+ if level == 'AUTO':
+ level = parse_level()
+ return _add_file_handler(self, path, level, remove_other_handlers, mode)
+
+ def set_stdout(self, stdout: str = 'raw', level: str = 'AUTO'):
+ """
+ 设置 log 的 terminal 输出形式。
+
+ :param stdout: 可选['rich', 'naive', 'raw', 'none']。
+        :param level: 可选 ['INFO', 'WARNING', 'DEBUG', 'ERROR', 'AUTO'], 其中AUTO表示根据环境变量 "FASTNLP_LOG_LEVEL" 进行
+ 设置。
+ :return:
+ """
+ r"""设置标准输出格式和输出级别"""
+ if level == 'AUTO':
+ level = parse_level()
+ return _set_stdout_handler(self, stdout, level)
+
+ def debug(self, msg, *args, **kwargs):
+ """
+ Delegate a debug call to the underlying log.
+ """
+ if self.isEnabledFor(DEBUG):
+ kwargs = self._add_rank_info(kwargs)
+ self._log(DEBUG, msg, args, **kwargs)
+
+ def info(self, msg, *args, **kwargs):
+ """
+ Delegate an info call to the underlying log.
+ """
+ if self.isEnabledFor(INFO):
+ kwargs = self._add_rank_info(kwargs)
+ self._log(INFO, msg, args, **kwargs)
+
+ def warning(self, msg, *args, **kwargs):
+ """
+ Delegate a warning call to the underlying log.
+ """
+ if self.isEnabledFor(WARNING):
+ kwargs = self._add_rank_info(kwargs)
+ self._log(WARNING, msg, args, **kwargs)
+
+ def warning_once(self, msg, *args, **kwargs):
+ """
+ 相同的 warning 内容只会 warning 一次
+
+ :param msg:
+ :param args:
+ :param kwargs:
+ :return:
+ """
+ if msg not in self._warning_msgs:
+ if self.isEnabledFor(WARNING):
+ kwargs = self._add_rank_info(kwargs)
+ self._log(WARNING, msg, args, **kwargs)
+ self._warning_msgs.add(msg)
+
+ def rank_zero_warning(self, msg, *args, once=False, **kwargs):
+ """
+ 只在 rank 0 上 warning 。
+
+ :param msg:
+ :param args:
+ :param once: 是否只 warning 一次
+ :param kwargs:
+ :return:
+ """
+ if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0':
+ if once:
+ if msg in self._warning_msgs:
+ return
+ self._warning_msgs.add(msg)
+
+ if self.isEnabledFor(WARNING):
+ kwargs = self._add_rank_info(kwargs)
+ self._log(WARNING, msg, args, **kwargs)
+
+ def warn(self, msg, *args, **kwargs):
+ if self.isEnabledFor(WARNING):
+ kwargs = self._add_rank_info(kwargs)
+ self._log(WARNING, msg, args, **kwargs)
+
+ def error(self, msg, *args, **kwargs):
+ """
+ Delegate an error call to the underlying log.
+ """
+ if self.isEnabledFor(ERROR):
+ kwargs = self._add_rank_info(kwargs)
+ self._log(ERROR, msg, args, **kwargs)
+
+ def exception(self, msg, *args, exc_info=True, **kwargs):
+ """
+ Delegate an exception call to the underlying log.
+ """
+ kwargs = self._add_rank_info(kwargs)
+ self.error(msg, *args, exc_info=exc_info, **kwargs)
+
+ def critical(self, msg, *args, **kwargs):
+ """
+ Delegate a critical call to the underlying log.
+ """
+ if self.isEnabledFor(CRITICAL):
+ kwargs = self._add_rank_info(kwargs)
+ self._log(CRITICAL, msg, args, **kwargs)
+
+ def log(self, level, msg, *args, **kwargs):
+ """
+ Delegate a log call to the underlying log, after adding
+ contextual information from this adapter instance.
+ """
+ if not isinstance(level, int):
+ if raiseExceptions:
+ raise TypeError("level must be an integer")
+ else:
+ return
+ if self.isEnabledFor(level):
+ kwargs = self._add_rank_info(kwargs)
+ self._log(level, msg, args, **kwargs)
+
+ def _add_rank_info(self, kwargs):
+ if is_cur_env_distributed():
+ extra = kwargs.get('extra', {})
+ extra.update({"rank": int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))})
+ kwargs["extra"] = extra
+ return kwargs
+
+ def setLevel(self, level) -> None:
+ """
+ 设置当前 logger 以及其 handler 的 log 级别
+
+ :param level:
+ :return:
+ """
+ if isinstance(level, str):
+ level = level.upper()
+ super().setLevel(level)
+ for handler in self.handlers:
+ handler.setLevel(level)
+
+ def _set_distributed(self):
+ """
+ 在 fastNLP 拉起进程的时候,调用一下这个方法,使得能够输出 rank 信息
+
+ :return:
+ """
+ for handler in self.handlers:
+ if isinstance(handler, logging.FileHandler):
+ formatter = logging.Formatter(fmt='Rank: %(rank)s - %(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
+ datefmt='%Y/%m/%d %H:%M:%S')
+ else:
+ formatter = logging.Formatter('Rank: %(rank)s - %(message)s')
+ handler.setFormatter(formatter)
+
+
+def _get_level(level):
+ if not isinstance(level, int):
+ level = level.lower()
+ level = {'info': logging.INFO, 'debug': logging.DEBUG,
+ 'warn': logging.WARN, 'warning': logging.WARNING,
+ 'error': logging.ERROR}[level]
+ return level
+
+
+def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]] = None, level: str = 'INFO',
+ remove_other_handlers: bool = False, mode: str = "w"):
+ if path is None:
+ path = Path.cwd()
+ if isinstance(path, str):
+ path = Path(path)
+ if not isinstance(path, Path):
+ raise TypeError("Parameter `path` can only be `str` or `pathlib.Path` type.")
+ if not path.exists():
+ head, tail = os.path.splitext(path)
+ if tail == '': # 说明没有后缀,理解为是一个folder
+ path.mkdir(parents=True, exist_ok=True)
+ else:
+ # 主进程会帮助我们创建文件夹,但是由于主从进程几乎是同步的,因此到这里时子进程也会尝试创建文件夹,即使主进程会做这件事情;
+ dirname = os.path.dirname(path)
+ os.makedirs(dirname, exist_ok=True)
+ if path.is_dir():
+        if os.environ.get(FASTNLP_BACKEND_LAUNCH, '0') == '1':
+ # 如果是通过 python -m xxx.launch 等启动的话,FASTNLP_LAUNCH_TIME这个名称可能是不一致的。
+ path = path.joinpath(f"RANK-{os.environ.get(FASTNLP_GLOBAL_RANK, '0')}-" +
+ os.environ.get(FASTNLP_LAUNCH_TIME) + '.log')
+ else:
+ path = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + '.log')
+
+ if not isinstance(remove_other_handlers, bool):
+ raise TypeError("Parameter `remove_other_handlers` can only be `bool` type.")
+
+ if not isinstance(mode, str):
+        raise TypeError("Parameter `mode` can only be `str` type.")
+ if mode not in {"w", "a"}:
+        raise ValueError("Parameter `mode` can only be one of these values: ('w', 'a').")
+
+ for h in _logger.handlers:
+ if isinstance(h, logging.FileHandler):
+ if os.path.abspath(path) == h.baseFilename:
+ # file path already added
+                return h
+
+ # File Handler
+ if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
+ if os.path.exists(path):
+ assert os.path.isfile(path)
+ warnings.warn('log already exists in {}'.format(path))
+
+ dirname = os.path.abspath(os.path.dirname(path))
+ os.makedirs(dirname, exist_ok=True)
+
+    # 这里只要检测到是分布式训练,我们就将 mode 改为 "a";这样会导致的一个问题在于,如果第二次训练也是分布式训练,logger记录的log不会重新
+ # 覆盖掉原文件,而是会接着上一次的 log 继续添加;
+ # 这样做主要是为了解决这样的情形所导致的问题:在分布式训练中,进程 1 比 进程 0 先运行到这里,然后使得进程 0 将进程 1 的 log 覆盖掉;
+ if is_cur_env_distributed():# and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) != 0:
+ mode = "a"
+
+ file_handler = logging.FileHandler(path, mode=mode)
+    _logger.info(f"Writing log to file:{os.path.abspath(path)}")
+ file_handler.setLevel(_get_level(level))
+
+ if is_cur_env_distributed():
+ file_formatter = logging.Formatter(fmt='Rank: %(rank)s - %(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
+ datefmt='%Y/%m/%d %H:%M:%S')
+ else:
+ file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
+ datefmt='%Y/%m/%d %H:%M:%S')
+
+ file_handler.setFormatter(file_formatter)
+ _logger.addHandler(file_handler)
+
+ if remove_other_handlers:
+ _need_remove_handlers = []
+ for i, h in enumerate(_logger.handlers):
+ if not isinstance(h, logging.FileHandler):
+ _need_remove_handlers.append(h)
+ for handler in _need_remove_handlers:
+ _logger.removeHandler(handler)
+
+ return file_handler
+
+
+def _set_stdout_handler(_logger, stdout='raw', level='INFO'):
+ level = _get_level(level)
+ supported_stdout = ['none', 'raw', 'tqdm', 'naive', 'rich']
+ if stdout not in supported_stdout:
+ raise ValueError('stdout must in one of {}'.format(supported_stdout))
+ # make sure to initialize _logger only once
+ stream_handler = None
+ _handlers = (logging.StreamHandler, TqdmLoggingHandler, StdoutStreamHandler, RichHandler)
+ for i, h in enumerate(_logger.handlers):
+ if isinstance(h, _handlers):
+ stream_handler = h
+ break
+ if stream_handler is not None:
+ _logger.removeHandler(stream_handler)
+ del stream_handler
+
+ # Stream Handler
+ if stdout == 'raw':
+ stream_handler = StdoutStreamHandler()
+ elif stdout == 'rich':
+ stream_handler = RichHandler(level=level, log_time_format="[%X]")
+ elif stdout == 'naive':
+ stream_handler = logging.StreamHandler(sys.stdout)
+ elif stdout == 'tqdm':
+ stream_handler = TqdmLoggingHandler(level)
+ else:
+ stream_handler = None
+
+ if stream_handler is not None:
+ if is_cur_env_distributed():
+ stream_formatter = logging.Formatter('Rank: %(rank)s - %(message)s')
+ else:
+ stream_formatter = logging.Formatter('%(message)s')
+ stream_handler.setLevel(level)
+ stream_handler.setFormatter(stream_formatter)
+ _logger.addHandler(stream_handler)
+
+ return stream_handler
+
+
+def _init_logger(path=None, stdout='rich', level='INFO'):
+ r"""initialize _logger"""
+ level = _get_level(level)
+
+ logger = FastNLPLogger(ROOT_NAME)
+
+ logger.propagate = False
+
+ _set_stdout_handler(logger, stdout, level)
+
+ # File Handler
+ if path is not None:
+ _add_file_handler(logger, path, level)
+
+ logger.setLevel(level)
+
+ return logger
+
+
+def parse_level():
+    if os.environ.get(FASTNLP_LOG_LEVEL, 'AUTO') == 'AUTO':
+ level = 'WARNING' if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) != 0 else "INFO"
+ else:
+ level = os.environ[FASTNLP_LOG_LEVEL]
+ return level
+
+
+logger = _init_logger(path=None, stdout='rich', level=parse_level())
+logger.debug("The environment variables are as following:")
+logger.debug(os.environ)
diff --git a/fastNLP/core/log/print.py b/fastNLP/core/log/print.py
new file mode 100644
index 00000000..49facdc4
--- /dev/null
+++ b/fastNLP/core/log/print.py
@@ -0,0 +1,27 @@
+__all__ = [
+ 'print'
+]
+from logging import INFO
+from .logger import logger
+
+
+def print(*args, sep=' ', end='\n', file=None, flush=False):
+ """
+ 用来重定向 print 函数至 logger.info 的函数。
+
+ Example::
+
+ from fastNLP import print
+ print("This is a test") # 等价于调用了 logger.info("This is a test")
+
+ :param args: 需要打印的内容
+ :param sep: 存在多个输入时,使用的间隔。
+ :param end: 该参数在当前设置无意义,因为结尾一定会被加入 ``'\\\\n'`` 。
+ :param file: 该参数无意义。
+ :param flush: 该参数无意义。
+ :return:
+ """
+ line = sep.join(map(str, args))
+ if logger.isEnabledFor(INFO):
+ kwargs = logger._add_rank_info({})
+ logger._log(INFO, line, None, **kwargs)
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
deleted file mode 100644
index 3bce8733..00000000
--- a/fastNLP/core/losses.py
+++ /dev/null
@@ -1,480 +0,0 @@
-r"""
-losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。
-
-"""
-__all__ = [
- "LossBase",
-
- "LossFunc",
- "LossInForward",
-
- "CrossEntropyLoss",
- "BCELoss",
- "BCEWithLogits",
- "L1Loss",
- "NLLLoss",
- "MSELoss",
-
- "CMRC2018Loss"
-
-]
-
-import inspect
-from collections import defaultdict
-
-import torch
-import torch.nn.functional as F
-
-from .utils import _CheckError
-from .utils import _CheckRes
-from .utils import _build_args
-from .utils import _check_arg_dict_list
-from .utils import _check_function_or_method
-from .utils import _get_func_signature
-from .utils import seq_len_to_mask
-from ..core.const import Const
-
-
-class LossBase(object):
- r"""
- 所有loss的基类。如果需要结合到Trainer之中需要实现get_loss方法
- """
-
- def __init__(self):
- self._param_map = {} # key是fun的参数,value是以该值从传入的dict取出value
- self._checked = False
-
- @property
- def param_map(self):
- if len(self._param_map) == 0: # 如果为空说明还没有初始化
- func_spect = inspect.getfullargspec(self.get_loss)
- func_args = [arg for arg in func_spect.args if arg != 'self']
- for arg in func_args:
- self._param_map[arg] = arg
- return self._param_map
-
- def get_loss(self, *args, **kwargs):
- """
-
- :param args:
- :param kwargs:
- :return: torch.Tensor
- """
- raise NotImplementedError
-
- def _init_param_map(self, key_map=None, **kwargs):
- r"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map
-
- :param dict key_map: 表示key的映射关系
- :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系
- :return: None
- """
- value_counter = defaultdict(set)
- if key_map is not None:
- if not isinstance(key_map, dict):
- raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
- for key, value in key_map.items():
- if value is None:
- self._param_map[key] = key
- continue
- if not isinstance(key, str):
- raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
- if not isinstance(value, str):
- raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
- self._param_map[key] = value
- value_counter[value].add(key)
- for key, value in kwargs.items():
- if value is None:
- self._param_map[key] = key
- continue
- if not isinstance(value, str):
- raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
- self._param_map[key] = value
- value_counter[value].add(key)
- for value, key_set in value_counter.items():
- if len(key_set) > 1:
- raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")
-
- # check consistence between signature and _param_map
- func_spect = inspect.getfullargspec(self.get_loss)
- func_args = [arg for arg in func_spect.args if arg != 'self']
- for func_param, input_param in self._param_map.items():
- if func_param not in func_args:
- raise NameError(
- f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the "
- f"initialization parameters, or change its signature.")
-
- # evaluate should not have varargs.
- # if func_spect.varargs:
- # raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
- # f"positional argument.).")
-
- def __call__(self, pred_dict, target_dict, check=False):
- r"""
- :param dict pred_dict: 模型的forward函数返回的dict
- :param dict target_dict: DataSet.batch_y里的键-值对所组成的dict
- :param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查
- :return:
- """
-
- if not self._checked:
- # 1. check consistence between signature and _param_map
- func_spect = inspect.getfullargspec(self.get_loss)
- func_args = set([arg for arg in func_spect.args if arg != 'self'])
- for func_arg, input_arg in self._param_map.items():
- if func_arg not in func_args:
- raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.")
-
- # 2. only part of the _param_map are passed, left are not
- for arg in func_args:
- if arg not in self._param_map:
- self._param_map[arg] = arg # This param does not need mapping.
- self._evaluate_args = func_args
- self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()}
-
- mapped_pred_dict = {}
- mapped_target_dict = {}
- for input_arg, mapped_arg in self._reverse_param_map.items():
- if input_arg in pred_dict:
- mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
- if input_arg in target_dict:
- mapped_target_dict[mapped_arg] = target_dict[input_arg]
-
- # missing
- if not self._checked:
- duplicated = []
- for input_arg, mapped_arg in self._reverse_param_map.items():
- if input_arg in pred_dict and input_arg in target_dict:
- duplicated.append(input_arg)
- check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict])
- # replace missing.
- missing = check_res.missing
- replaced_missing = list(missing)
- for idx, func_arg in enumerate(missing):
- # Don't delete `` in this information, nor add ``
- replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \
- f"in `{self.__class__.__name__}`)"
-
- check_res = _CheckRes(missing=replaced_missing,
- unused=check_res.unused,
- duplicated=duplicated,
- required=check_res.required,
- all_needed=check_res.all_needed,
- varargs=check_res.varargs)
-
- if check_res.missing or check_res.duplicated:
- raise _CheckError(check_res=check_res,
- func_signature=_get_func_signature(self.get_loss))
- self._checked = True
-
- refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)
-
- loss = self.get_loss(**refined_args)
- self._checked = True
-
- return loss
-
-
-class LossFunc(LossBase):
- r"""
- 提供给用户使用自定义损失函数的类
-
- :param func: 用户自行定义的损失函数,应当为一个函数。
- :param dict key_map: 参数映射表。键为Model/DataSet参数名,值为损失函数参数名。
- fastNLP的trainer将在训练时从模型返回值或者训练数据DataSet的target=True的field中
- 找到相对应的参数名为value的参数,并传入func中作为参数名为key的参数
- :param kwargs: 除了参数映射表以外可以用key word args的方式设置参数映射关系
-
- 使用方法::
-
- import torch.nn.functional as F
- loss_func = LossFunc(F.cross_entropy, input="pred", target="label")
- # 这表示构建了一个损失函数类,由func计算损失函数,其中将从模型返回值或者DataSet的target=True的field
- # 当中找到一个参数名为`pred`的参数传入func一个参数名为`input`的参数;找到一个参数名为`label`的参数
- # 传入func作为一个名为`target`的参数
-
- """
-
- def __init__(self, func, key_map=None, **kwargs):
-
- super(LossFunc, self).__init__()
- _check_function_or_method(func)
- self.get_loss = func
- if key_map is not None:
- if not isinstance(key_map, dict):
- raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}")
- self._init_param_map(key_map, **kwargs)
-
-
-class CrossEntropyLoss(LossBase):
- r"""
- 交叉熵损失函数
-
- :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
- :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
- :param seq_len: 句子的长度, 长度之外的token不会计算loss。
- :param int class_in_dim: 在序列标注的场景中,pred可能的shape为(batch_size, max_len, num_classes)
- 或(batch_size, num_classes, max_len), CrossEntropyLoss需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第
- 二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等,
- 那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显示设置该值)。其它大于0的值则认为该维度是class的维度。
- :param ignore_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替
- 传入seq_len.
- :param str reduction: 支持 `mean` ,`sum` 和 `none` .
-
- Example::
-
- loss = CrossEntropyLoss(pred='pred', target='label', padding_idx=0)
-
- """
-
- def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, ignore_idx=-100, reduction='mean', **kwargs):
- super(CrossEntropyLoss, self).__init__()
- self._init_param_map(pred=pred, target=target, seq_len=seq_len)
- ignore_idx = kwargs.pop('padding_idx', ignore_idx)
- self.ignore_idx = ignore_idx
- assert reduction in ('mean', 'sum', 'none')
- self.reduction = reduction
- self.class_in_dim = class_in_dim
-
- def get_loss(self, pred, target, seq_len=None):
- if seq_len is not None and target.dim()>1:
- mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False)
- target = target.masked_fill(mask, self.ignore_idx)
-
- if pred.dim() > 2:
- if self.class_in_dim == -1:
- if pred.size(1) != target.size(1): # 有可能顺序替换了
- pred = pred.transpose(1, 2)
- else:
- pred = pred.transpose(-1, self.class_in_dim)
- pred = pred.reshape(-1, pred.size(-1))
- target = target.reshape(-1)
-
- return F.cross_entropy(input=pred, target=target,
- ignore_index=self.ignore_idx, reduction=self.reduction)
-
-
-class L1Loss(LossBase):
- r"""
- L1损失函数
-
- :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
- :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target`
- :param str reduction: 支持'mean','sum'和'none'.
-
- """
-
- def __init__(self, pred=None, target=None, reduction='mean'):
- super(L1Loss, self).__init__()
- self._init_param_map(pred=pred, target=target)
- assert reduction in ('mean', 'sum', 'none')
- self.reduction = reduction
-
- def get_loss(self, pred, target):
- return F.l1_loss(input=pred, target=target, reduction=self.reduction)
-
-
-class MSELoss(LossBase):
- r"""
- MSE损失函数
-
- :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
- :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target`
- :param str reduction: 支持'mean','sum'和'none'.
-
- """
-
- def __init__(self, pred=None, target=None, reduction='mean'):
- super(MSELoss, self).__init__()
- self._init_param_map(pred=pred, target=target)
- assert reduction in ('mean', 'sum', 'none')
- self.reduction = reduction
-
- def get_loss(self, pred, target):
- return F.mse_loss(input=pred, target=target, reduction=self.reduction)
-
-
-class BCELoss(LossBase):
- r"""
- 二分类交叉熵损失函数
-
- :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
- :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
- :param str reduction: 支持 `mean` ,`sum` 和 `none` .
- """
-
- def __init__(self, pred=None, target=None, reduction='mean'):
- super(BCELoss, self).__init__()
- self._init_param_map(pred=pred, target=target)
- assert reduction in ('mean', 'sum', 'none')
- self.reduction = reduction
-
- def get_loss(self, pred, target):
- return F.binary_cross_entropy(input=pred, target=target, reduction=self.reduction)
-
-
-class BCEWithLogits(LossBase):
- r"""
- 二分类交叉熵损失函数, 传入数据之前不需要做sigmoid操作
-
- :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
- :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
- :param int class_in_dim: 在序列标注的场景中,pred可能的shape为(batch_size, max_len, num_classes)
- 或(batch_size, num_classes, max_len), BCEWithLogits需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第
- 二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等,
- 那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显示设置该值)。其它大于0的值则认为该维度是class的维度。
- :param str reduction: 支持 `mean` ,`sum` 和 `none` .
- """
-
- def __init__(self, pred=None, target=None, class_in_dim=-1, reduction='mean'):
- super(BCEWithLogits, self).__init__()
- self._init_param_map(pred=pred, target=target)
- assert reduction in ('mean', 'sum', 'none')
- self.reduction = reduction
- self.class_in_dim = class_in_dim
-
- def get_loss(self, pred, target):
- if pred.dim() > 2:
- if self.class_in_dim == -1:
- if pred.size(1) != target.size(1): # 有可能顺序替换了
- pred = pred.transpose(1, 2)
- else:
- pred = pred.transpose(-1, self.class_in_dim)
- pred = pred.reshape(-1)
- target = target.reshape(-1)
-
- return F.binary_cross_entropy_with_logits(input=pred, target=target, reduction=self.reduction)
-
-
-class NLLLoss(LossBase):
- r"""
- 负对数似然损失函数
- """
-
- def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, ignore_idx=-100, reduction='mean'):
- r"""
-
- :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
- :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
- :param seq_len: 句子的长度, 长度之外的token不会计算loss。仅在输出为3d时需要
- :param int class_in_dim: 在序列标注的场景中,pred可能的shape为(batch_size, max_len, num_classes)
- 或(batch_size, num_classes, max_len), CrossEntropyLoss需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第
- 二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等,
- 那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显示设置该值)。其它大于0的值则认为该维度是class的维度。
- :param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替
- 传入seq_len.
- :param str reduction: 支持 `mean` ,`sum` 和 `none` .
- """
- super(NLLLoss, self).__init__()
- self._init_param_map(pred=pred, target=target, seq_len=seq_len)
- assert reduction in ('mean', 'sum', 'none')
- self.reduction = reduction
- self.ignore_idx = ignore_idx
- self.class_in_dim = class_in_dim
-
- def get_loss(self, pred, target, seq_len=None):
- if seq_len is not None and target.dim()>1:
- mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False)
- target = target.masked_fill(mask, self.ignore_idx)
-
- if pred.dim() > 2:
- if self.class_in_dim == -1:
- if pred.size(1) != target.size(1): # 有可能顺序替换了
- pred = pred.transpose(1, 2)
- else:
- pred = pred.transpose(-1, self.class_in_dim)
- pred = pred.reshape(-1, pred.size(-1))
- target = target.reshape(-1)
-
- return F.nll_loss(input=pred, target=target, ignore_index=self.ignore_idx, reduction=self.reduction)
-
-
-class LossInForward(LossBase):
- r"""
- 从forward()函数返回结果中获取loss
- """
-
- def __init__(self, loss_key=Const.LOSS):
- r"""
-
- :param str loss_key: 在forward函数中loss的键名,默认为loss
- """
- super().__init__()
- if not isinstance(loss_key, str):
- raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
- self.loss_key = loss_key
-
- def get_loss(self, **kwargs):
- if self.loss_key not in kwargs:
- check_res = _CheckRes(
- missing=[self.loss_key + f"(assign to `{self.loss_key}` in `{self.__class__.__name__}`"],
- unused=[],
- duplicated=[],
- required=[],
- all_needed=[],
- varargs=[])
- raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.get_loss))
- return kwargs[self.loss_key]
-
- def __call__(self, pred_dict, target_dict, check=False):
-
- loss = self.get_loss(**pred_dict)
-
- if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
- if not isinstance(loss, torch.Tensor):
- raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}")
- loss = torch.sum(loss) / (loss.view(-1)).size(0)
- # raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")
-
- return loss
-
-
-class CMRC2018Loss(LossBase):
- r"""
- 用于计算CMRC2018中文问答任务。
-
- """
- def __init__(self, target_start=None, target_end=None, context_len=None, pred_start=None, pred_end=None,
- reduction='mean'):
- super().__init__()
-
- assert reduction in ('mean', 'sum')
-
- self._init_param_map(target_start=target_start, target_end=target_end, context_len=context_len,
- pred_start=pred_start, pred_end=pred_end)
- self.reduction = reduction
-
- def get_loss(self, target_start, target_end, context_len, pred_start, pred_end):
- r"""
-
- :param target_start: batch_size
- :param target_end: batch_size
- :param context_len: batch_size
- :param pred_start: batch_size x max_len
- :param pred_end: batch_size x max_len
- :return:
- """
- batch_size, max_len = pred_end.size()
- mask = seq_len_to_mask(context_len, max_len).eq(False)
-
- pred_start = pred_start.masked_fill(mask, float('-inf'))
- pred_end = pred_end.masked_fill(mask, float('-inf'))
-
- start_loss = F.cross_entropy(pred_start, target_start, reduction='sum')
- end_loss = F.cross_entropy(pred_end, target_end, reduction='sum')
-
- loss = start_loss + end_loss
-
- if self.reduction == 'mean':
- loss = loss / batch_size
-
- return loss/2
-
-def _prepare_losser(losser):
- if losser is None:
- losser = LossInForward()
- return losser
- elif isinstance(losser, LossBase):
- return losser
- else:
- raise TypeError(f"Type of loss should be `fastNLP.LossBase`, got {type(losser)}")
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
deleted file mode 100644
index 31f69cb9..00000000
--- a/fastNLP/core/metrics.py
+++ /dev/null
@@ -1,1246 +0,0 @@
-r"""
-metrics 模块实现了 fastNLP 所需的各种常用衡量指标,一般做为 :class:`~fastNLP.Trainer` 的参数使用。
-
-"""
-__all__ = [
- "MetricBase",
- "AccuracyMetric",
- "SpanFPreRecMetric",
- "CMRC2018Metric",
- "ClassifyFPreRecMetric",
- "ConfusionMatrixMetric"
-]
-
-import inspect
-import warnings
-from abc import abstractmethod
-from collections import defaultdict
-from typing import Union
-from copy import deepcopy
-import re
-
-import numpy as np
-import torch
-
-from .utils import _CheckError
-from .utils import _CheckRes
-from .utils import _build_args
-from .utils import _check_arg_dict_list
-from .utils import _get_func_signature
-from .utils import seq_len_to_mask
-from .vocabulary import Vocabulary
-from .utils import ConfusionMatrix
-
-
-class MetricBase(object):
- r"""
- 所有metrics的基类,所有的传入到Trainer, Tester的Metric需要继承自该对象,需要覆盖写入evaluate(), get_metric()方法。
-
- evaluate(xxx)中传入的是一个batch的数据。
-
- get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值
-
- 以分类问题中,Accuracy计算为例
- 假设model的forward返回dict中包含 `pred` 这个key, 并且该key需要用于Accuracy::
-
- class Model(nn.Module):
- def __init__(xxx):
- # do something
- def forward(self, xxx):
- # do something
- return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes
-
- 假设dataset中 `label` 这个field是需要预测的值,并且该field被设置为了target
- 对应的AccMetric可以按如下的定义, version1, 只使用这一次::
-
- class AccMetric(MetricBase):
- def __init__(self):
- super().__init__()
-
- # 根据你的情况自定义指标
- self.corr_num = 0
- self.total = 0
-
- def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value
- # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric
- self.total += label.size(0)
- self.corr_num += label.eq(pred).sum().item()
-
- def get_metric(self, reset=True): # 在这里定义如何计算metric
- acc = self.corr_num/self.total
- if reset: # 是否清零以便重新计算
- self.corr_num = 0
- self.total = 0
- return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中
-
-
- version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred::
-
- class AccMetric(MetricBase):
- def __init__(self, label=None, pred=None):
- # 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时,
- # acc_metric = AccMetric(label='y', pred='pred_y')即可。
- # 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对
- # 应的的值
- super().__init__()
- self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可
- # 如果没有注册该则效果与version1就是一样的
-
- # 根据你的情况自定义指标
- self.corr_num = 0
- self.total = 0
-
- def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。
- # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric
- self.total += label.size(0)
- self.corr_num += label.eq(pred).sum().item()
-
- def get_metric(self, reset=True): # 在这里定义如何计算metric
- acc = self.corr_num/self.total
- if reset: # 是否清零以便重新计算
- self.corr_num = 0
- self.total = 0
- return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中
-
-
- ``MetricBase`` 将会在输入的字典 ``pred_dict`` 和 ``target_dict`` 中进行检查.
- ``pred_dict`` 是模型当中 ``forward()`` 函数或者 ``predict()`` 函数的返回值.
- ``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的 ``is_target`` 被设置为True.
-
- ``MetricBase`` 会进行以下的类型检测:
-
- 1. self.evaluate当中是否有varargs, 这是不支持的.
- 2. self.evaluate当中所需要的参数是否既不在 ``pred_dict`` 也不在 ``target_dict`` .
- 3. self.evaluate当中所需要的参数是否既在 ``pred_dict`` 也在 ``target_dict`` .
-
- 除此以外,在参数被传入self.evaluate以前,这个函数会检测 ``pred_dict`` 和 ``target_dict`` 当中没有被用到的参数
- 如果kwargs是self.evaluate的参数,则不会检测
-
-
- self.evaluate将计算一个批次(batch)的评价指标,并累计。 没有返回值
- self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值
-
- """
-
- def __init__(self):
- self._param_map = {} # key is param in function, value is input param.
- self._checked = False
- self._metric_name = self.__class__.__name__
-
- @property
- def param_map(self):
- if len(self._param_map) == 0: # 如果为空说明还没有初始化
- func_spect = inspect.getfullargspec(self.evaluate)
- func_args = [arg for arg in func_spect.args if arg != 'self']
- for arg in func_args:
- self._param_map[arg] = arg
- return self._param_map
-
- @abstractmethod
- def evaluate(self, *args, **kwargs):
- raise NotImplementedError
-
- @abstractmethod
- def get_metric(self, reset=True):
- raise NotImplemented
-
- def set_metric_name(self, name: str):
- r"""
- 设置metric的名称,默认是Metric的class name.
-
- :param str name:
- :return: self
- """
- self._metric_name = name
- return self
-
- def get_metric_name(self):
- r"""
- 返回metric的名称
-
- :return:
- """
- return self._metric_name
-
- def _init_param_map(self, key_map=None, **kwargs):
- r"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map
-
- :param dict key_map: 表示key的映射关系
- :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系
- :return: None
- """
- value_counter = defaultdict(set)
- if key_map is not None:
- if not isinstance(key_map, dict):
- raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
- for key, value in key_map.items():
- if value is None:
- self._param_map[key] = key
- continue
- if not isinstance(key, str):
- raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
- if not isinstance(value, str):
- raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
- self._param_map[key] = value
- value_counter[value].add(key)
- for key, value in kwargs.items():
- if value is None:
- self._param_map[key] = key
- continue
- if not isinstance(value, str):
- raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
- self._param_map[key] = value
- value_counter[value].add(key)
- for value, key_set in value_counter.items():
- if len(key_set) > 1:
- raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")
-
- # check consistence between signature and _param_map
- func_spect = inspect.getfullargspec(self.evaluate)
- func_args = [arg for arg in func_spect.args if arg != 'self']
- for func_param, input_param in self._param_map.items():
- if func_param not in func_args:
- raise NameError(
- f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
- f"initialization parameters, or change its signature.")
-
- def __call__(self, pred_dict, target_dict):
- r"""
- 这个方法会调用self.evaluate 方法.
- 在调用之前,会进行以下检测:
- 1. self.evaluate当中是否有varargs, 这是不支持的.
- 2. self.evaluate当中所需要的参数是否既不在``pred_dict``也不在``target_dict``.
- 3. self.evaluate当中所需要的参数是否既在``pred_dict``也在``target_dict``.
-
- 除此以外,在参数被传入self.evaluate以前,这个函数会检测``pred_dict``和``target_dict``当中没有被用到的参数
- 如果kwargs是self.evaluate的参数,则不会检测
- :param pred_dict: 模型的forward函数或者predict函数返回的dict
- :param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容)
- :return:
- """
-
- if not self._checked:
- if not callable(self.evaluate):
- raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
- # 1. check consistence between signature and _param_map
- func_spect = inspect.getfullargspec(self.evaluate)
- func_args = set([arg for arg in func_spect.args if arg != 'self'])
- for func_arg, input_arg in self._param_map.items():
- if func_arg not in func_args:
- raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.")
-
- # 2. only part of the _param_map are passed, left are not
- for arg in func_args:
- if arg not in self._param_map:
- self._param_map[arg] = arg # This param does not need mapping.
- self._evaluate_args = func_args
- self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()}
-
- # need to wrap inputs in dict.
- mapped_pred_dict = {}
- mapped_target_dict = {}
- for input_arg, mapped_arg in self._reverse_param_map.items():
- if input_arg in pred_dict:
- mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
- if input_arg in target_dict:
- mapped_target_dict[mapped_arg] = target_dict[input_arg]
-
- # missing
- if not self._checked:
- duplicated = []
- for input_arg, mapped_arg in self._reverse_param_map.items():
- if input_arg in pred_dict and input_arg in target_dict:
- duplicated.append(input_arg)
- check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
- # only check missing.
- # replace missing.
- missing = check_res.missing
- replaced_missing = list(missing)
- for idx, func_arg in enumerate(missing):
- # Don't delete `` in this information, nor add ``
- replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \
- f"in `{self.__class__.__name__}`)"
-
- check_res = _CheckRes(missing=replaced_missing,
- unused=check_res.unused,
- duplicated=duplicated,
- required=check_res.required,
- all_needed=check_res.all_needed,
- varargs=check_res.varargs)
-
- if check_res.missing or check_res.duplicated:
- raise _CheckError(check_res=check_res,
- func_signature=_get_func_signature(self.evaluate))
- self._checked = True
- refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)
-
- self.evaluate(**refined_args)
-
- return
-
-
-class ConfusionMatrixMetric(MetricBase):
- r"""
- 分类问题计算混淆矩阵的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` )
- 最后返回结果为::
-
- dict,{'confusion_matrix': ConfusionMatrix实例}
-
- ConfusionMatrix实例的print()函数将输出矩阵字符串。
-
- .. code ::
-
- pred_dict = {"pred": torch.Tensor([2,1,3])}
- target_dict = {'target': torch.Tensor([2,2,1])}
- metric = ConfusionMatrixMetric()
- metric(pred_dict=pred_dict, target_dict=target_dict, )
- print(metric.get_metric())
-
- .. code ::
-
- {'confusion_matrix':
- target 1.0 2.0 3.0 all
- pred
- 1.0 0 1 0 1
- 2.0 0 1 0 1
- 3.0 1 0 0 1
- all 1 2 0 3
- }
-
- """
- def __init__(self,
- vocab=None,
- pred=None,
- target=None,
- seq_len=None,
- print_ratio=False
- ):
- r"""
- :param vocab: vocab词表类,要求有to_word()方法。
- :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
- :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
- :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len`
- :param print_ratio: 限制print的输出,false only for result, true for result, percent(dim=0), percent(dim = 1)
- """
- super().__init__()
- self._init_param_map(pred=pred, target=target, seq_len=seq_len)
- self.confusion_matrix = ConfusionMatrix(
- vocab=vocab,
- print_ratio=print_ratio,
- )
-
- def evaluate(self, pred, target, seq_len=None):
- r"""
- evaluate函数将针对一个批次的预测结果做评价指标的累计
-
- :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
- torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
- :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),
- torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len])
- :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, torch.Size([B]), 或者torch.Size([B]).
- """
- if not isinstance(pred, torch.Tensor):
- raise TypeError(
- f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(pred)}.")
- if not isinstance(target, torch.Tensor):
- raise TypeError(
- f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(target)}.")
-
- if seq_len is not None and not isinstance(seq_len, torch.Tensor):
- raise TypeError(
- f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(seq_len)}.")
-
- if pred.dim() == target.dim():
- if torch.numel(pred) !=torch.numel(target):
- raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have same dimensions with target, they should have same element numbers. while target have "
- f"element numbers:{torch.numel(target)}, pred have element numbers: {torch.numel(pred)}")
-
- pass
- elif pred.dim() == target.dim() + 1:
- pred = pred.argmax(dim=-1)
- if seq_len is None and target.dim() > 1:
- warnings.warn("You are not passing `seq_len` to exclude pad.")
- else:
- raise RuntimeError(
- f"In {_get_func_signature(self.evaluate)}, when pred have "
- f"size:{pred.size()}, target should have size: {pred.size()} or "
- f"{pred.size()[:-1]}, got {target.size()}.")
-
- target = target.to(pred)
- if seq_len is not None and target.dim() > 1:
- for p, t, l in zip(pred.tolist(), target.tolist(),
- seq_len.tolist()):
- l = int(l)
- self.confusion_matrix.add_pred_target(p[:l], t[:l])
- elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出
- for p, t in zip(pred.tolist(), target.tolist()):
- self.confusion_matrix.add_pred_target(p, t)
- else:
- self.confusion_matrix.add_pred_target(pred.tolist(),
- target.tolist())
-
- def get_metric(self, reset=True):
- r"""
- get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.
- :param bool reset: 在调用完get_metric后是否清空评价指标统计量.
- :return dict evaluate_result: {"confusion_matrix": ConfusionMatrix}
- """
- confusion = {'confusion_matrix': deepcopy(self.confusion_matrix)}
- if reset:
- self.confusion_matrix.clear()
- return confusion
-
-
-
-
-
-class AccuracyMetric(MetricBase):
- r"""
- 准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` )
- """
-
- def __init__(self, pred=None, target=None, seq_len=None):
- r"""
-
- :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
- :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
- :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len`
- """
-
- super().__init__()
-
- self._init_param_map(pred=pred, target=target, seq_len=seq_len)
-
- self.total = 0
- self.acc_count = 0
-
- def evaluate(self, pred, target, seq_len=None):
- r"""
- evaluate函数将针对一个批次的预测结果做评价指标的累计
-
- :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
- torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
- :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),
- torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len])
- :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]).
- 如果mask也被传进来的话seq_len会被忽略.
-
- """
- # TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value
- if not isinstance(pred, torch.Tensor):
- raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(pred)}.")
- if not isinstance(target, torch.Tensor):
- raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(target)}.")
-
- if seq_len is not None and not isinstance(seq_len, torch.Tensor):
- raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(seq_len)}.")
-
- if seq_len is not None and target.dim() > 1:
- max_len = target.size(1)
- masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
- else:
- masks = None
-
- if pred.dim() == target.dim():
- if torch.numel(pred) !=torch.numel(target):
- raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have same dimensions with target, they should have same element numbers. while target have "
- f"element numbers:{torch.numel(target)}, pred have element numbers: {torch.numel(pred)}")
-
- pass
- elif pred.dim() == target.dim() + 1:
- pred = pred.argmax(dim=-1)
- if seq_len is None and target.dim() > 1:
- warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
- else:
- raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
- f"size:{pred.size()}, target should have size: {pred.size()} or "
- f"{pred.size()[:-1]}, got {target.size()}.")
-
- target = target.to(pred)
- if masks is not None:
- self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(False), 0)).item()
- self.total += torch.sum(masks).item()
- else:
- self.acc_count += torch.sum(torch.eq(pred, target)).item()
- self.total += np.prod(list(pred.size()))
-
- def get_metric(self, reset=True):
- r"""
- get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.
-
- :param bool reset: 在调用完get_metric后是否清空评价指标统计量.
- :return dict evaluate_result: {"acc": float}
- """
- evaluate_result = {'acc': round(float(self.acc_count) / (self.total + 1e-12), 6)}
- if reset:
- self.acc_count = 0
- self.total = 0
- return evaluate_result
-
-class ClassifyFPreRecMetric(MetricBase):
- r"""
- 分类问题计算FPR值的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` )
-
- 最后得到的metric结果为::
-
- {
- 'f': xxx, # 这里使用f考虑以后可以计算f_beta值
- 'pre': xxx,
- 'rec':xxx
- }
-
- 若only_gross=False, 即还会返回各个label的metric统计值::
-
- {
- 'f': xxx,
- 'pre': xxx,
- 'rec':xxx,
- 'f-label': xxx,
- 'pre-label': xxx,
- 'rec-label':xxx,
- ...
- }
-
- """
-
- def __init__(self, tag_vocab=None, pred=None, target=None, seq_len=None, ignore_labels=None,
- only_gross=True, f_type='micro', beta=1):
- r"""
-
- :param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` . 默认值为None。若为None则使用数字来作为标签内容,否则使用vocab来作为标签内容。
- :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据
- :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据
- :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。
- :param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label
- :param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec
- :param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同)
- :param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
- """
-
- if tag_vocab:
- if not isinstance(tag_vocab, Vocabulary):
- raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
- if f_type not in ('micro', 'macro'):
- raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
-
- self.ignore_labels = ignore_labels
- self.f_type = f_type
- self.beta = beta
- self.beta_square = self.beta ** 2
- self.only_gross = only_gross
-
- super().__init__()
- self._init_param_map(pred=pred, target=target, seq_len=seq_len)
-
- self.tag_vocab = tag_vocab
-
- self._tp, self._fp, self._fn = defaultdict(int), defaultdict(int), defaultdict(int)
- # tp: truth=T, classify=T; fp: truth=T, classify=F; fn: truth=F, classify=T
-
- def evaluate(self, pred, target, seq_len=None):
- r"""
- evaluate函数将针对一个批次的预测结果做评价指标的累计
-
- :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
- torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
- :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),
- torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len])
- :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]).
- 如果mask也被传进来的话seq_len会被忽略.
-
- """
- # TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value
- if not isinstance(pred, torch.Tensor):
- raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(pred)}.")
- if not isinstance(target, torch.Tensor):
- raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(target)}.")
-
- if seq_len is not None and not isinstance(seq_len, torch.Tensor):
- raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(seq_len)}.")
-
- if seq_len is not None and target.dim() > 1:
- max_len = target.size(1)
- masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
- else:
- masks = torch.ones_like(target).long().to(target.device)
-
- masks = masks.eq(1)
-
- if pred.dim() == target.dim():
- if torch.numel(pred) !=torch.numel(target):
- raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have same dimensions with target, they should have same element numbers. while target have "
- f"element numbers:{torch.numel(target)}, pred have element numbers: {torch.numel(pred)}")
-
- pass
- elif pred.dim() == target.dim() + 1:
- pred = pred.argmax(dim=-1)
- if seq_len is None and target.dim() > 1:
- warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
- else:
- raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
- f"size:{pred.size()}, target should have size: {pred.size()} or "
- f"{pred.size()[:-1]}, got {target.size()}.")
-
- target = target.to(pred)
- target = target.masked_select(masks)
- pred = pred.masked_select(masks)
- target_idxes = set(target.reshape(-1).tolist())
- for target_idx in target_idxes:
- self._tp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target != target_idx, 0)).item()
- self._fp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0)).item()
- self._fn[target_idx] += torch.sum((pred != target_idx).long().masked_fill(target != target_idx, 0)).item()
-
- def get_metric(self, reset=True):
- r"""
- get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.
-
- :param bool reset: 在调用完get_metric后是否清空评价指标统计量.
- :return dict evaluate_result: {"acc": float}
- """
- evaluate_result = {}
- if not self.only_gross or self.f_type == 'macro':
- tags = set(self._fn.keys())
- tags.update(set(self._fp.keys()))
- tags.update(set(self._tp.keys()))
- f_sum = 0
- pre_sum = 0
- rec_sum = 0
- for tag in tags:
- if self.tag_vocab is not None:
- tag_name = self.tag_vocab.to_word(tag)
- else:
- tag_name = int(tag)
- tp = self._tp[tag]
- fn = self._fn[tag]
- fp = self._fp[tag]
- f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
- f_sum += f
- pre_sum += pre
- rec_sum += rec
- if not self.only_gross and tag != '': # tag!=''防止无tag的情况
- f_key = 'f-{}'.format(tag_name)
- pre_key = 'pre-{}'.format(tag_name)
- rec_key = 'rec-{}'.format(tag_name)
- evaluate_result[f_key] = f
- evaluate_result[pre_key] = pre
- evaluate_result[rec_key] = rec
-
- if self.f_type == 'macro':
- evaluate_result['f'] = f_sum / len(tags)
- evaluate_result['pre'] = pre_sum / len(tags)
- evaluate_result['rec'] = rec_sum / len(tags)
-
- if self.f_type == 'micro':
- f, pre, rec = _compute_f_pre_rec(self.beta_square,
- sum(self._tp.values()),
- sum(self._fn.values()),
- sum(self._fp.values()))
- evaluate_result['f'] = f
- evaluate_result['pre'] = pre
- evaluate_result['rec'] = rec
-
- if reset:
- self._tp = defaultdict(int)
- self._fp = defaultdict(int)
- self._fn = defaultdict(int)
-
- for key, value in evaluate_result.items():
- evaluate_result[key] = round(value, 6)
-
- return evaluate_result
-
-
-def _bmes_tag_to_spans(tags, ignore_labels=None):
- r"""
- 给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。
- 返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间)
- 也可以是单纯的['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列
-
- :param tags: List[str],
- :param ignore_labels: List[str], 在该list中的label将被忽略
- :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
- """
- ignore_labels = set(ignore_labels) if ignore_labels else set()
-
- spans = []
- prev_bmes_tag = None
- for idx, tag in enumerate(tags):
- tag = tag.lower()
- bmes_tag, label = tag[:1], tag[2:]
- if bmes_tag in ('b', 's'):
- spans.append((label, [idx, idx]))
- elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
- spans[-1][1][1] = idx
- else:
- spans.append((label, [idx, idx]))
- prev_bmes_tag = bmes_tag
- return [(span[0], (span[1][0], span[1][1] + 1))
- for span in spans
- if span[0] not in ignore_labels
- ]
-
-
-def _bmeso_tag_to_spans(tags, ignore_labels=None):
- r"""
- 给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。
- 返回[('singer', (1, 4))] (左闭右开区间)
-
- :param tags: List[str],
- :param ignore_labels: List[str], 在该list中的label将被忽略
- :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
- """
- ignore_labels = set(ignore_labels) if ignore_labels else set()
-
- spans = []
- prev_bmes_tag = None
- for idx, tag in enumerate(tags):
- tag = tag.lower()
- bmes_tag, label = tag[:1], tag[2:]
- if bmes_tag in ('b', 's'):
- spans.append((label, [idx, idx]))
- elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
- spans[-1][1][1] = idx
- elif bmes_tag == 'o':
- pass
- else:
- spans.append((label, [idx, idx]))
- prev_bmes_tag = bmes_tag
- return [(span[0], (span[1][0], span[1][1] + 1))
- for span in spans
- if span[0] not in ignore_labels
- ]
-
-
-def _bioes_tag_to_spans(tags, ignore_labels=None):
- r"""
- 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。
- 返回[('singer', (1, 4))] (左闭右开区间)
-
- :param tags: List[str],
- :param ignore_labels: List[str], 在该list中的label将被忽略
- :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
- """
- ignore_labels = set(ignore_labels) if ignore_labels else set()
-
- spans = []
- prev_bioes_tag = None
- for idx, tag in enumerate(tags):
- tag = tag.lower()
- bioes_tag, label = tag[:1], tag[2:]
- if bioes_tag in ('b', 's'):
- spans.append((label, [idx, idx]))
- elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
- spans[-1][1][1] = idx
- elif bioes_tag == 'o':
- pass
- else:
- spans.append((label, [idx, idx]))
- prev_bioes_tag = bioes_tag
- return [(span[0], (span[1][0], span[1][1] + 1))
- for span in spans
- if span[0] not in ignore_labels
- ]
-
-
-def _bio_tag_to_spans(tags, ignore_labels=None):
- r"""
- 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。
- 返回[('singer', (1, 4))] (左闭右开区间)
-
- :param tags: List[str],
- :param ignore_labels: List[str], 在该list中的label将被忽略
- :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
- """
- ignore_labels = set(ignore_labels) if ignore_labels else set()
-
- spans = []
- prev_bio_tag = None
- for idx, tag in enumerate(tags):
- tag = tag.lower()
- bio_tag, label = tag[:1], tag[2:]
- if bio_tag == 'b':
- spans.append((label, [idx, idx]))
- elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]:
- spans[-1][1][1] = idx
- elif bio_tag == 'o': # o tag does not count
- pass
- else:
- spans.append((label, [idx, idx]))
- prev_bio_tag = bio_tag
- return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]
-
-
-def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str:
- r"""
- 给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio
-
- :param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。
- :return:
- """
- tag_set = set()
- unk_token = ''
- pad_token = ''
- if isinstance(tag_vocab, Vocabulary):
- unk_token = tag_vocab.unknown
- pad_token = tag_vocab.padding
- tag_vocab = tag_vocab.idx2word
- for idx, tag in tag_vocab.items():
- if tag in (unk_token, pad_token):
- continue
- tag = tag[:1].lower()
- tag_set.add(tag)
-
- bmes_tag_set = set('bmes')
- if tag_set == bmes_tag_set:
- return 'bmes'
- bio_tag_set = set('bio')
- if tag_set == bio_tag_set:
- return 'bio'
- bmeso_tag_set = set('bmeso')
- if tag_set == bmeso_tag_set:
- return 'bmeso'
- bioes_tag_set = set('bioes')
- if tag_set == bioes_tag_set:
- return 'bioes'
- raise RuntimeError("encoding_type cannot be inferred automatically. Only support "
- "'bio', 'bmes', 'bmeso', 'bioes' type.")
-
-
-def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):
- r"""
- 检查vocab中的tag是否与encoding_type是匹配的
-
- :param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。
- :param encoding_type: bio, bmes, bioes, bmeso
- :return:
- """
- tag_set = set()
- unk_token = ''
- pad_token = ''
- if isinstance(tag_vocab, Vocabulary):
- unk_token = tag_vocab.unknown
- pad_token = tag_vocab.padding
- tag_vocab = tag_vocab.idx2word
- for idx, tag in tag_vocab.items():
- if tag in (unk_token, pad_token):
- continue
- tag = tag[:1].lower()
- tag_set.add(tag)
-
- tags = encoding_type
- for tag in tag_set:
- assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \
- f"encoding_type."
- tags = tags.replace(tag, '') # 删除该值
- if tags: # 如果不为空,说明出现了未使用的tag
- warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
- "encoding_type.")
-
-
-class SpanFPreRecMetric(MetricBase):
- r"""
- 在序列标注问题中,以span的方式计算F, pre, rec.
- 比如中文Part of speech中,会以character的方式进行标注,句子 `中国在亚洲` 对应的POS可能为(以BMES为例)
- ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。
- 最后得到的metric结果为::
-
- {
- 'f': xxx, # 这里使用f考虑以后可以计算f_beta值
- 'pre': xxx,
- 'rec':xxx
- }
-
- 若only_gross=False, 即还会返回各个label的metric统计值::
-
- {
- 'f': xxx,
- 'pre': xxx,
- 'rec':xxx,
- 'f-label': xxx,
- 'pre-label': xxx,
- 'rec-label':xxx,
- ...
- }
- """
-
- def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type=None, ignore_labels=None,
- only_gross=True, f_type='micro', beta=1):
- r"""
-
- :param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN),
- 在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'.
- :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据
- :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据
- :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。
- :param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断.
- :param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label
- :param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec
- :param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同)
- :param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
- """
-
- if not isinstance(tag_vocab, Vocabulary):
- raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
- if f_type not in ('micro', 'macro'):
- raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
-
- if encoding_type:
- encoding_type = encoding_type.lower()
- _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
- self.encoding_type = encoding_type
- else:
- self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab)
-
- if self.encoding_type == 'bmes':
- self.tag_to_span_func = _bmes_tag_to_spans
- elif self.encoding_type == 'bio':
- self.tag_to_span_func = _bio_tag_to_spans
- elif self.encoding_type == 'bmeso':
- self.tag_to_span_func = _bmeso_tag_to_spans
- elif self.encoding_type == 'bioes':
- self.tag_to_span_func = _bioes_tag_to_spans
- else:
- raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.")
-
- self.ignore_labels = ignore_labels
- self.f_type = f_type
- self.beta = beta
- self.beta_square = self.beta ** 2
- self.only_gross = only_gross
-
- super().__init__()
- self._init_param_map(pred=pred, target=target, seq_len=seq_len)
-
- self.tag_vocab = tag_vocab
-
- self._true_positives = defaultdict(int)
- self._false_positives = defaultdict(int)
- self._false_negatives = defaultdict(int)
-
- def evaluate(self, pred, target, seq_len):
- r"""evaluate函数将针对一个批次的预测结果做评价指标的累计
-
- :param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果
- :param target: [batch, seq_len], 真实值
- :param seq_len: [batch] 文本长度标记
- :return:
- """
- if not isinstance(pred, torch.Tensor):
- raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(pred)}.")
- if not isinstance(target, torch.Tensor):
- raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(target)}.")
-
- if not isinstance(seq_len, torch.Tensor):
- raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
- f"got {type(seq_len)}.")
-
- if pred.size() == target.size() and len(target.size()) == 2:
- pass
- elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
- num_classes = pred.size(-1)
- pred = pred.argmax(dim=-1)
- if (target >= num_classes).any():
- raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
- "id >= {}, the number of classes.".format(num_classes))
- else:
- raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
- f"size:{pred.size()}, target should have size: {pred.size()} or "
- f"{pred.size()[:-1]}, got {target.size()}.")
-
- batch_size = pred.size(0)
- pred = pred.tolist()
- target = target.tolist()
- for i in range(batch_size):
- pred_tags = pred[i][:int(seq_len[i])]
- gold_tags = target[i][:int(seq_len[i])]
-
- pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
- gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]
-
- pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
- gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)
-
- for span in pred_spans:
- if span in gold_spans:
- self._true_positives[span[0]] += 1
- gold_spans.remove(span)
- else:
- self._false_positives[span[0]] += 1
- for span in gold_spans:
- self._false_negatives[span[0]] += 1
-
- def get_metric(self, reset=True):
- r"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果."""
- evaluate_result = {}
- if not self.only_gross or self.f_type == 'macro':
- tags = set(self._false_negatives.keys())
- tags.update(set(self._false_positives.keys()))
- tags.update(set(self._true_positives.keys()))
- f_sum = 0
- pre_sum = 0
- rec_sum = 0
- for tag in tags:
- tp = self._true_positives[tag]
- fn = self._false_negatives[tag]
- fp = self._false_positives[tag]
- f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
- f_sum += f
- pre_sum += pre
- rec_sum += rec
- if not self.only_gross and tag != '': # tag!=''防止无tag的情况
- f_key = 'f-{}'.format(tag)
- pre_key = 'pre-{}'.format(tag)
- rec_key = 'rec-{}'.format(tag)
- evaluate_result[f_key] = f
- evaluate_result[pre_key] = pre
- evaluate_result[rec_key] = rec
-
- if self.f_type == 'macro':
- evaluate_result['f'] = f_sum / len(tags)
- evaluate_result['pre'] = pre_sum / len(tags)
- evaluate_result['rec'] = rec_sum / len(tags)
-
- if self.f_type == 'micro':
- f, pre, rec = _compute_f_pre_rec(self.beta_square,
- sum(self._true_positives.values()),
- sum(self._false_negatives.values()),
- sum(self._false_positives.values()))
- evaluate_result['f'] = f
- evaluate_result['pre'] = pre
- evaluate_result['rec'] = rec
-
- if reset:
- self._true_positives = defaultdict(int)
- self._false_positives = defaultdict(int)
- self._false_negatives = defaultdict(int)
-
- for key, value in evaluate_result.items():
- evaluate_result[key] = round(value, 6)
-
- return evaluate_result
-
-
-def _compute_f_pre_rec(beta_square, tp, fn, fp):
- r"""
-
- :param tp: int, true positive
- :param fn: int, false negative
- :param fp: int, false positive
- :return: (f, pre, rec)
- """
- pre = tp / (fp + tp + 1e-13)
- rec = tp / (fn + tp + 1e-13)
- f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
-
- return f, pre, rec
-
-
-def _prepare_metrics(metrics):
- r"""
-
- Prepare list of Metric based on input
- :param metrics:
- :return: List[fastNLP.MetricBase]
- """
- _metrics = []
- if metrics:
- if isinstance(metrics, list):
- for metric in metrics:
- if isinstance(metric, type):
- metric = metric()
- if isinstance(metric, MetricBase):
- metric_name = metric.__class__.__name__
- if not callable(metric.evaluate):
- raise TypeError(f"{metric_name}.evaluate must be callable, got {type(metric.evaluate)}.")
- if not callable(metric.get_metric):
- raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.")
- _metrics.append(metric)
- else:
- raise TypeError(
- f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
- elif isinstance(metrics, MetricBase):
- _metrics = [metrics]
- else:
- raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, "
- f"got {type(metrics)}.")
- return _metrics
-
-
-def _accuracy_topk(y_true, y_prob, k=1):
- r"""Compute accuracy of y_true matching top-k probable labels in y_prob.
-
- :param y_true: ndarray, true label, [n_samples]
- :param y_prob: ndarray, label probabilities, [n_samples, n_classes]
- :param k: int, k in top-k
- :returns acc: accuracy of top-k
-
- """
- y_pred_topk = np.argsort(y_prob, axis=-1)[:, -1:-k - 1:-1]
- y_true_tile = np.tile(np.expand_dims(y_true, axis=1), (1, k))
- y_match = np.any(y_pred_topk == y_true_tile, axis=-1)
- acc = np.sum(y_match) / y_match.shape[0]
- return acc
-
-
-def _pred_topk(y_prob, k=1):
- r"""Return top-k predicted labels and corresponding probabilities.
-
- :param y_prob: ndarray, size [n_samples, n_classes], probabilities on labels
- :param k: int, k of top-k
- :returns (y_pred_topk, y_prob_topk):
- y_pred_topk: ndarray, size [n_samples, k], predicted top-k labels
- y_prob_topk: ndarray, size [n_samples, k], probabilities for top-k labels
-
- """
- y_pred_topk = np.argsort(y_prob, axis=-1)[:, -1:-k - 1:-1]
- x_axis_index = np.tile(
- np.arange(len(y_prob))[:, np.newaxis],
- (1, k))
- y_prob_topk = y_prob[x_axis_index, y_pred_topk]
- return y_pred_topk, y_prob_topk
-
-
-class CMRC2018Metric(MetricBase):
- r"""
- CRMC2018任务的评价metric
- """
- def __init__(self, answers=None, raw_chars=None, context_len=None, pred_start=None, pred_end=None):
- super().__init__()
- self._init_param_map(answers=answers, raw_chars=raw_chars, context_len=context_len, pred_start=pred_start,
- pred_end=pred_end)
- self.em = 0
- self.total = 0
- self.f1 = 0
-
- def evaluate(self, answers, raw_chars, pred_start, pred_end, context_len=None):
- r"""
-
- :param list[str] answers: 如[["答案1", "答案2", "答案3"], [...], ...]
- :param list[str] raw_chars: [["这", "是", ...], [...]]
- :param tensor pred_start: batch_size x length 或 batch_size,
- :param tensor pred_end: batch_size x length 或 batch_size(是闭区间,包含end位置),
- :param tensor context_len: context长度, batch_size
- :return:
- """
- if pred_start.dim() > 1:
- batch_size, max_len = pred_start.size()
- context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False)
- pred_start.masked_fill_(context_mask, float('-inf'))
- pred_end.masked_fill_(context_mask, float('-inf'))
- max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True) # batch_size,
- pred_start_mask = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0) # 只能预测这之后的值
- pred_end.masked_fill_(pred_start_mask, float('-inf'))
- pred_end_index = pred_end.argmax(dim=-1) + 1
- else:
- pred_start_index = pred_start
- pred_end_index = pred_end + 1
- pred_ans = []
- for index, (start, end) in enumerate(zip(pred_start_index.flatten().tolist(), pred_end_index.tolist())):
- pred_ans.append(''.join(raw_chars[index][start:end]))
- for answer, pred_an in zip(answers, pred_ans):
- pred_an = pred_an.strip()
- self.f1 += _calc_cmrc2018_f1_score(answer, pred_an)
- self.total += 1
- self.em += _calc_cmrc2018_em_score(answer, pred_an)
-
- def get_metric(self, reset=True):
- eval_res = {'f1': round(self.f1 / self.total*100, 2), 'em': round(self.em / self.total*100, 2)}
- if reset:
- self.em = 0
- self.total = 0
- self.f1 = 0
- return eval_res
-
-# split Chinese
-def _cn_segmentation(in_str, rm_punc=False):
- in_str = str(in_str).lower().strip()
- segs_out = []
- temp_str = ""
- sp_char = {'-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':', '?', '!', '“', '”', ';', '’', '《',
- '》', '……', '·', '、', '「', '」', '(', ')', '-', '~', '『', '』'}
- for char in in_str:
- if rm_punc and char in sp_char:
- continue
- if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
- if temp_str != "":
- ss = list(temp_str)
- segs_out.extend(ss)
- temp_str = ""
- segs_out.append(char)
- else:
- temp_str += char
-
- # handling last part
- if temp_str != "":
- ss = list(temp_str)
- segs_out.extend(ss)
-
- return segs_out
-
-
-# remove punctuation
-def _remove_punctuation(in_str):
- in_str = str(in_str).lower().strip()
- sp_char = ['-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=',
- ',', '。', ':', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、',
- '「', '」', '(', ')', '-', '~', '『', '』']
- out_segs = []
- for char in in_str:
- if char in sp_char:
- continue
- else:
- out_segs.append(char)
- return ''.join(out_segs)
-
-
-# find longest common string
-def _find_lcs(s1, s2):
- m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
- mmax = 0
- p = 0
- for i in range(len(s1)):
- for j in range(len(s2)):
- if s1[i] == s2[j]:
- m[i + 1][j + 1] = m[i][j] + 1
- if m[i + 1][j + 1] > mmax:
- mmax = m[i + 1][j + 1]
- p = i + 1
- return s1[p - mmax:p], mmax
-
-
-def _calc_cmrc2018_f1_score(answers, prediction):
- f1_scores = []
- for ans in answers:
- ans_segs = _cn_segmentation(ans, rm_punc=True)
- prediction_segs = _cn_segmentation(prediction, rm_punc=True)
- lcs, lcs_len = _find_lcs(ans_segs, prediction_segs)
- if lcs_len == 0:
- f1_scores.append(0)
- continue
- precision = 1.0 * lcs_len / len(prediction_segs)
- recall = 1.0 * lcs_len / len(ans_segs)
- f1 = (2 * precision * recall) / (precision + recall)
- f1_scores.append(f1)
- return max(f1_scores)
-
-
-def _calc_cmrc2018_em_score(answers, prediction):
- em = 0
- for ans in answers:
- ans_ = _remove_punctuation(ans)
- prediction_ = _remove_punctuation(prediction)
- if ans_ == prediction_:
- em = 1
- break
- return em
diff --git a/fastNLP/core/metrics/__init__.py b/fastNLP/core/metrics/__init__.py
new file mode 100644
index 00000000..b7f572e8
--- /dev/null
+++ b/fastNLP/core/metrics/__init__.py
@@ -0,0 +1,12 @@
+__all__ = [
+ "Metric",
+ "Accuracy",
+ "TransformersAccuracy",
+ 'SpanFPreRecMetric',
+ 'ClassifyFPreRecMetric',
+]
+
+from .metric import Metric
+from .accuracy import Accuracy, TransformersAccuracy
+from .span_f1_pre_rec_metric import SpanFPreRecMetric
+from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py
new file mode 100644
index 00000000..5475046e
--- /dev/null
+++ b/fastNLP/core/metrics/accuracy.py
@@ -0,0 +1,99 @@
+__all__ = [
+ 'Accuracy',
+ "TransformersAccuracy"
+]
+
+from typing import Union
+
+import numpy as np
+
+from fastNLP.core.metrics.metric import Metric
+from fastNLP.core.metrics.backend import Backend
+from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
+from fastNLP.core.log import logger
+
+
+class Accuracy(Metric):
+ """
+ 计算 准确率 的 metric 。
+
+ :param backend: 目前支持五种类型的backend, ``['auto', 'torch', 'paddle', 'jittor', 'oneflow']``。其中 ``'auto'`` 表示根据实际调用
+ :meth:`update` 函数时传入的参数决定具体的 backend ,一般情况下直接使用 ``'auto'`` 即可。
+ :param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric,
+ 当 ``backend`` 不支持分布式时,该参数无意义。如果为 ``None`` ,将在 :class:`~fastNLP.core.controllers.Evaluator`
+ 中根据 ``sampler`` 是否使用分布式进行自动设置。
+ """
+ def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None):
+ super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
+ self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend)
+ self.register_element(name='total', value=0, aggregate_method="sum", backend=backend)
+
+ def get_metric(self) -> dict:
+ r"""
+ :meth:`get_metric` 函数将根据 :meth:`update` 函数累计的评价指标统计量来计算最终的评价结果。
+
+ :return: 包含以下内容的字典:``{"acc": float, 'total': float, 'correct': float}``;
+ """
+ evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)}
+ return evaluate_result
+
+ def update(self, pred, target, seq_len=None):
+ r"""
+ :meth:`update` 函数将针对一个批次的预测结果做评价指标的累计。
+
+ :param pred: 预测的 tensor, tensor 的形状可以是 ``[B,]`` 、``[B, n_classes]`` 、
+ ``[B, max_len]`` 或 ``[B, max_len, n_classes]``
+ :param target: 真实值的 tensor, tensor 的形状可以是 ``[B,]`` 、``[B,]`` 、``[B, max_len]``
+ 或 ``[B, max_len]``
+ :param seq_len: 序列长度标记, 标记的形状可以是 ``None``, 或者 ``[B]`` 。
+ 如果 mask 也被传进来的话 ``seq_len`` 会被忽略
+ """
+ # 为了兼容不同框架,我们将输入变量全部转为numpy类型来进行计算。
+ pred = self.tensor2numpy(pred)
+ target = self.tensor2numpy(target)
+ if seq_len is not None:
+ seq_len = self.tensor2numpy(seq_len)
+
+ if seq_len is not None and target.ndim > 1:
+ max_len = target.shape[1]
+ masks = seq_len_to_mask(seq_len, max_len)
+ else:
+ masks = None
+
+ if pred.ndim == target.ndim:
+ if np.prod(pred.shape) != np.prod(target.shape):
+ raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers."
+ f" while target have shape:{target.shape}, "
+ f"pred have shape: {pred.shape}")
+
+ elif pred.ndim == target.ndim + 1:
+ pred = pred.argmax(axis=-1)
+ if seq_len is None and target.ndim > 1:
+ logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.")
+
+ else:
+ raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or "
+ f"{pred.shape[:-1]}, got {target.shape}.")
+
+ if masks is not None:
+ self.total += masks.sum().item()
+ self.correct += ((pred == target) * masks).sum().item()
+ else:
+ self.total += np.prod(list(pred.shape)).item()
+ self.correct += (target == pred).sum().item()
+
+
+class TransformersAccuracy(Accuracy):
+ """
+ 适配 :mod:`transformers` 中相关模型的 Accuracy metric 。
+ """
+ def update(self, logits, labels, attention_mask=None):
+ r"""
+ :meth:`update` 函数将针对一个批次的预测结果做评价指标的累计。
+
+ :param logits: 形状为 ``[B, n_classes]`` 或 ``[B, max_len, n_classes]`` 。
+ :param labels: 形状为 ``[B, ]`` 或 ``[B, max_len]``
+ :param attention_mask: 序列长度标记。
+ """
+ seq_len = attention_mask.sum(dim=-1)
+ super().update(pred=logits, target=labels, seq_len=seq_len)
\ No newline at end of file
diff --git a/fastNLP/core/metrics/backend/__init__.py b/fastNLP/core/metrics/backend/__init__.py
new file mode 100644
index 00000000..d09efb53
--- /dev/null
+++ b/fastNLP/core/metrics/backend/__init__.py
@@ -0,0 +1,16 @@
+__all__ = [
+ 'Backend',
+ 'AutoBackend',
+ 'TorchBackend',
+ 'PaddleBackend',
+ 'JittorBackend',
+ 'OneflowBackend',
+]
+
+
+from .backend import Backend
+from .auto_backend import AutoBackend
+from .torch_backend import TorchBackend
+from .paddle_backend import PaddleBackend
+from .jittor_backend import JittorBackend
+from .oneflow_backend import OneflowBackend
diff --git a/fastNLP/core/metrics/backend/auto_backend.py b/fastNLP/core/metrics/backend/auto_backend.py
new file mode 100644
index 00000000..3bb04232
--- /dev/null
+++ b/fastNLP/core/metrics/backend/auto_backend.py
@@ -0,0 +1,115 @@
+from typing import Union
+import sys
+
+from ....envs.imports import SUPPORT_BACKENDS
+from ...log import logger
+
+from .backend import Backend
+from .torch_backend.backend import TorchBackend
+from .paddle_backend.backend import PaddleBackend
+from .jittor_backend.backend import JittorBackend
+from .oneflow_backend.backend import OneflowBackend
+
+__all__ = []
+
+class AutoBackend(Backend):
+ """
+ 不需要初始化 ``backend`` 的 :class:`AutoBackend`,能够根据 :meth:`get_metric` 时候判断输入数据类型来选择 ``backend``。
+ """
+
+ def __init__(self, backend: Union[str, Backend, None]):
+ """
+ 初始化 backend.
+
+ :param backend: 目前支持三种值,为 ``[str, Backend, None]``。
+
+ * 当 backend 为 :class:`str` 时, 其只能为 ``'auto'``;
+ * 当 backend 为 ``Backend`` 对象时, 其直接使用该对象方法覆盖 :class:`AutoBackend`;
+ * 当 backend 为 ``None`` 时, 根据 :meth:`get_metric` 时候判断输入数据类型来选择 ``backend``;
+
+ """
+ super(AutoBackend, self).__init__()
+ if backend != 'auto':
+ self._convert_backend(backend)
+
+ def _convert_backend(self, backend):
+ """
+ 将 AutoBackend 转换为合适的 Backend 对象
+
+ :param backend: 传入的 backend 值。
+
+ * 当 backend 为 ``'torch'`` 时, 选择 :class:`~fastNLP.core.metric.TorchBackend`
+            * 当 backend 为 ``'paddle'`` 时, 选择 :class:`~fastNLP.core.metric.PaddleBackend`
+ * 当 backend 为 ``'jittor'`` 时, 选择 :class:`~fastNLP.core.metric.JittorBackend`
+ * 当 backend 为 ``'oneflow'`` 时, 选择 :class:`~fastNLP.core.metric.OneflowBackend`
+ * 当 backend 为 ``None`` 时, 直接初始化
+
+ """
+ if isinstance(backend, Backend):
+ self.__class__ = backend.__class__
+ # 如果是str,直接选择就好了
+ elif backend == 'torch':
+ self.__class__ = TorchBackend
+ elif backend == 'paddle':
+ self.__class__ = PaddleBackend
+ elif backend == 'jittor':
+ self.__class__ = JittorBackend
+ elif backend == 'oneflow':
+ self.__class__ = OneflowBackend
+ elif backend is None:
+ # 不用做任何事情就可以初始化了
+ pass
+ else:
+ raise RuntimeError(f"We did not support `{backend}` to be used as backend for now.")
+ self._specified = True
+
+ def choose_real_backend(self, args):
+ """
+ 根据 args 参数类型来选择需要真正初始化的 backend
+
+ :param args: args 参数, 可能为 ``'jittor'``, ``'torch'``, ``'paddle'``, ``'oneflow'``, ``'numpy'`` 类型, 能够检测并选择真正的 backend。
+
+ """
+ assert not self.is_specified(), "This method should not be called after backend has been specified. " \
+ "This must be a bug, please report."
+ types = []
+ for arg in args:
+ types.append(str(type(arg)))
+
+ torch_types = []
+ jittor_types = []
+ paddle_types = []
+ for type_name in types:
+ if 'torch' in type_name:
+ torch_types.append(type_name)
+ if 'paddle' in type_name:
+ paddle_types.append(type_name)
+ if 'jittor' in type_name:
+ jittor_types.append(type_name)
+
+ # 根据 https://stackoverflow.com/a/3464154 ,可以通过这种方法实现切换成真实的 backend 上
+ if len(torch_types) > 0 and len(jittor_types) == 0 and len(paddle_types) == 0:
+ backend = 'torch'
+ elif len(torch_types) == 0 and len(jittor_types) > 0 and len(paddle_types) == 0:
+ backend = 'jittor'
+ elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) > 0:
+ backend = 'paddle'
+ elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) == 0:
+ backend = None
+ # 尝试通过 modules 的方式自动寻找
+ find_backends = []
+ for backend in SUPPORT_BACKENDS:
+ if backend in sys.modules:
+ find_backends.append(backend)
+ if len(find_backends) == 1:
+ backend = find_backends[0]
+ logger.debug(f'Find backend:{backend} through sys.modules.')
+ else:
+ logger.debug(f'Cannot find backend through sys.modules, since find:{find_backends}.')
+ else:
+ types = list(set(torch_types + jittor_types + paddle_types))
+ raise RuntimeError(
+ f"Mixture of tensor type:{types} have been accept, please manually set backend instead of "
+ f"using backend=auto.")
+
+ self._convert_backend(backend)
diff --git a/fastNLP/core/metrics/backend/backend.py b/fastNLP/core/metrics/backend/backend.py
new file mode 100644
index 00000000..b5afb3db
--- /dev/null
+++ b/fastNLP/core/metrics/backend/backend.py
@@ -0,0 +1,90 @@
+from ..utils import AggregateMethodError
+
+__all__ = []
+
class Backend:
    """
    Base class of every evaluation backend. ``Backend`` and all of its subclasses
    must be stateless: apart from the ``_specified`` flag set at construction
    time, methods compute purely from their arguments.
    """

    def __init__(self):
        # False for the base class: it represents the "backend not chosen yet" state.
        self._specified = False

    def aggregate(self, tensor, method: str):
        """
        Aggregate ``tensor`` according to ``method`` and return the result.

        The base backend has no distributed support, so any non-``None``
        ``method`` is an error; ``only_warn=True`` lets the caller downgrade it
        to a warning and retry with ``method=None``.

        :param tensor: the value to aggregate
        :param method: the aggregation method
        :raises AggregateMethodError: when ``method`` is not ``None``
        """
        if method is not None:
            # BUGFIX: the error must be *raised* — the original `return` handed the
            # exception object back as the aggregated value, so Element.aggregate's
            # `except AggregateMethodError` / only_warn handling could never trigger.
            raise AggregateMethodError(should_have_aggregate_method=False, only_warn=True)

        return tensor

    def create_tensor(self, value: float):
        """
        Create a tensor initialized with ``value``; the base backend simply
        keeps the plain python value.

        :param value: initial value
        """
        return value

    def fill_value(self, tensor, value: float):
        """
        Return ``tensor`` with its content replaced by ``value`` (for the base
        backend the value itself is the tensor).

        :param tensor: the incoming tensor
        :param value: value to fill in
        """
        return value

    def get_scalar(self, tensor) -> float:
        """
        Return the scalar value of ``tensor``.

        :param tensor: the incoming tensor
        :return: the scalar
        """
        return tensor

    def is_specified(self) -> bool:
        """
        Whether this backend is bound to a concrete framework.

        :return: ``True`` only for concrete framework backends
        """
        return self._specified

    def tensor2numpy(self, tensor):
        """
        Convert ``tensor`` to a :class:`numpy.array`; the base backend returns
        it unchanged.

        :param tensor: the incoming tensor
        :return:
        """
        return tensor

    def move_tensor_to_device(self, tensor, device):
        """
        Move ``tensor`` to ``device``; a no-op for the base backend.

        :param tensor: the incoming tensor
        :param device: device identifier such as ``'cpu'`` or ``'cuda:0'``
        """
        return tensor

    def all_gather_object(self, obj, group=None):
        """
        Gather ``obj`` from every rank; returns a :class:`list` with one entry
        per rank. Not supported by the base backend.

        :param obj:
        :param group:
        :raises RuntimeError: when called on an undetermined ``AutoBackend``
        :raises NotImplementedError: for every other backend lacking an override
        """
        if self.__class__.__name__ == 'AutoBackend':
            raise RuntimeError("fastNLP cannot determine the backend automatically, please pass in the backend through "
                               "initialization.")

        raise NotImplementedError(f"all_gather_object() function is not implemented for {self.__class__.__name__}.")
+
diff --git a/fastNLP/core/metrics/backend/jittor_backend/__init__.py b/fastNLP/core/metrics/backend/jittor_backend/__init__.py
new file mode 100644
index 00000000..01aed511
--- /dev/null
+++ b/fastNLP/core/metrics/backend/jittor_backend/__init__.py
@@ -0,0 +1,5 @@
+__all__ = [
+ "JittorBackend",
+]
+
+from .backend import JittorBackend
\ No newline at end of file
diff --git a/fastNLP/core/metrics/backend/jittor_backend/backend.py b/fastNLP/core/metrics/backend/jittor_backend/backend.py
new file mode 100644
index 00000000..a3fcfb5a
--- /dev/null
+++ b/fastNLP/core/metrics/backend/jittor_backend/backend.py
@@ -0,0 +1,73 @@
+import numpy as np
+
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+from fastNLP.core.metrics.backend import Backend
+
+if _NEED_IMPORT_JITTOR:
+ import jittor
+
+__all__ = []
+
class JittorBackend(Backend):
    """
    Backend implementation for jittor.
    """

    def __init__(self):
        super(JittorBackend, self).__init__()
        self._specified = True

    def aggregate(self, tensor, method: str):
        """
        Aggregate the result according to ``method``; no distributed support is
        implemented for jittor, so the tensor is returned unchanged.
        """
        return tensor

    def create_tensor(self, value: float):
        """
        Create a jittor tensor initialized with ``value``.
        """
        value = jittor.Var(value)
        return value

    def fill_value(self, tensor, value: float):
        """
        Return a tensor shaped like ``tensor`` filled with ``value``.
        """
        value = jittor.full_like(tensor, value)
        return value

    def get_scalar(self, tensor) -> float:
        """
        Return the scalar value of ``tensor``.

        :param tensor:
        :return:
        """
        return tensor.item()

    def is_specified(self) -> bool:
        """
        Whether this backend is bound to a concrete framework (always ``True``).

        :return:
        """
        return self._specified

    def tensor2numpy(self, tensor):
        """
        Convert ``tensor`` to a numpy array.

        :param tensor:
        :return:
        :raises ValueError: if ``tensor`` is neither a jittor Var nor an ndarray
        """
        if isinstance(tensor, jittor.Var):
            return tensor.detach().numpy()
        # BUGFIX: np.array is a function, not a type — isinstance(x, np.array)
        # raised TypeError; the type to check against is np.ndarray.
        elif isinstance(tensor, np.ndarray):
            return tensor
        else:
            raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

    def move_tensor_to_device(self, tensor, device):
        """
        jittor exposes no device-transfer function, so this is a no-op.
        """
        return tensor
diff --git a/fastNLP/core/metrics/backend/oneflow_backend/__init__.py b/fastNLP/core/metrics/backend/oneflow_backend/__init__.py
new file mode 100644
index 00000000..f5774a14
--- /dev/null
+++ b/fastNLP/core/metrics/backend/oneflow_backend/__init__.py
@@ -0,0 +1,5 @@
+__all__ = [
+ "OneflowBackend",
+]
+
+from .backend import OneflowBackend
\ No newline at end of file
diff --git a/fastNLP/core/metrics/backend/oneflow_backend/backend.py b/fastNLP/core/metrics/backend/oneflow_backend/backend.py
new file mode 100644
index 00000000..cf19e382
--- /dev/null
+++ b/fastNLP/core/metrics/backend/oneflow_backend/backend.py
@@ -0,0 +1,130 @@
+from typing import List
+
+import numpy as np
+
+from fastNLP.core.metrics.backend import Backend
+from fastNLP.core.metrics.utils import AggregateMethodError
+from fastNLP.core.utils import is_in_oneflow_dist
+from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
+from fastNLP.core.drivers.oneflow_driver.dist_utils import fastnlp_oneflow_all_gather
+
+
+if _NEED_IMPORT_ONEFLOW:
+ import oneflow
+ import oneflow.comm as comm
+
+__all__ = []
+
class OneflowBackend(Backend):
    """
    Backend implementation for oneflow.
    """

    def __init__(self):
        super().__init__()
        self._specified = True

    def aggregate(self, tensor, method: str):
        """
        In a distributed run, gather ``tensor`` from every rank, stack the
        results and reduce them along dim 0 with ``method``; otherwise (or for
        non-tensor input) return the value unchanged.

        :param tensor: tensor to aggregate
        :param method: one of ``['sum', 'mean', 'max', 'min']``, applied over
            the stacked per-rank results along dim ``0``
        """
        if isinstance(tensor, oneflow.Tensor):
            # TODO no oneflow API was found to check whether the distributed
            # environment has been initialized.
            if is_in_oneflow_dist():
                if method is None:
                    raise AggregateMethodError(should_have_aggregate_method=True)
                gathered = self.all_gather_object(tensor)
                if isinstance(gathered[0], oneflow.Tensor):
                    gathered = oneflow.stack(gathered)
                tensor = gathered
                # reduce over the rank dimension
                if method in ('sum', 'mean'):
                    tensor = getattr(oneflow, method)(tensor, dim=0)
                elif method in ('max', 'min'):
                    # oneflow.max/min with dim return (values, indices); keep the values
                    tensor = getattr(oneflow, method)(tensor, dim=0)[0]
                else:
                    raise AggregateMethodError(should_have_aggregate_method=False)

        return tensor

    def create_tensor(self, value: float):
        """
        Create a one-element tensor holding ``value``.

        :param value: initial value of the tensor
        """
        return oneflow.ones(1).fill_(value)

    def fill_value(self, tensor, value: float):
        """
        Fill ``tensor`` in place with ``value`` and return it.

        :param tensor: the incoming tensor
        :param value: value to fill in
        """
        tensor.fill_(value)
        return tensor

    def get_scalar(self, tensor) -> float:
        """
        Return the python scalar stored in ``tensor``.

        :param tensor: the incoming tensor
        """
        return tensor.item()

    def tensor2numpy(self, tensor) -> np.array:
        """
        Convert ``tensor`` to a numpy value; used mainly inside metric
        computation. Plain ndarrays, floats and ints pass through unchanged.

        :param tensor: the incoming tensor
        """
        if isinstance(tensor, oneflow.Tensor):
            return tensor.cpu().detach().numpy()
        if isinstance(tensor, (np.ndarray, float, int)):
            return tensor
        raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

    @staticmethod
    def is_distributed() -> bool:
        """
        Whether the process is running under oneflow ddp.

        :return:
        """
        return is_in_oneflow_dist()

    def move_tensor_to_device(self, tensor, device):
        """
        Move ``tensor`` to ``device``.

        :param tensor: tensor to move
        :param device: device name such as ``"cpu"`` or ``"cuda:0"``
        """
        return tensor.to(device)

    def all_gather_object(self, obj, group=None) -> List:
        """
        Gather ``obj`` from every rank; returns a list ordered by rank. Outside
        of a distributed run the list contains only the local object.

        :param obj:
        :param group:
        """
        if not self.is_distributed():
            return [obj]
        return fastnlp_oneflow_all_gather(obj)
+
diff --git a/fastNLP/core/metrics/backend/paddle_backend/__init__.py b/fastNLP/core/metrics/backend/paddle_backend/__init__.py
new file mode 100644
index 00000000..1f409e32
--- /dev/null
+++ b/fastNLP/core/metrics/backend/paddle_backend/__init__.py
@@ -0,0 +1,5 @@
+__all__ = [
+ 'PaddleBackend'
+]
+
+from .backend import PaddleBackend
diff --git a/fastNLP/core/metrics/backend/paddle_backend/backend.py b/fastNLP/core/metrics/backend/paddle_backend/backend.py
new file mode 100644
index 00000000..fc88bf10
--- /dev/null
+++ b/fastNLP/core/metrics/backend/paddle_backend/backend.py
@@ -0,0 +1,129 @@
+import os
+from typing import List, Any
+
+import numpy as np
+
+from fastNLP.core.metrics.backend import Backend
+from fastNLP.core.utils.paddle_utils import paddle_to, _convert_data_device, is_in_paddle_dist
+from fastNLP.core.metrics.utils import AggregateMethodError
+from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+
+if _NEED_IMPORT_PADDLE:
+ import paddle
+ import paddle.distributed as dist
+ from paddle.fluid.dygraph import parallel_helper
+
+__all__ = []
+
class PaddleBackend(Backend):
    """
    Backend implementation for paddle.
    """

    def __init__(self):
        super().__init__()
        self._specified = True

    def aggregate(self, tensor, method: str):
        """
        In a distributed run, gather ``tensor`` from every rank, stack the
        results and reduce them along axis 0 with ``method``; otherwise (or for
        non-tensor input) return the value unchanged.

        :param tensor: tensor to aggregate
        :param method: one of ``['sum', 'mean', 'max', 'min']``, applied over
            the stacked per-rank results along axis ``0``
        """
        if isinstance(tensor, paddle.Tensor):
            if parallel_helper._is_parallel_ctx_initialized():
                if method is None:
                    raise AggregateMethodError(should_have_aggregate_method=True)
                tensor = self.all_gather_object(tensor)
                if isinstance(tensor[0], paddle.Tensor):
                    tensor = paddle.stack(tensor)
                # reduce over the rank dimension
                if method == 'sum':
                    tensor = paddle.sum(tensor, axis=0)
                elif method == 'mean':
                    tensor = paddle.mean(tensor, axis=0)
                elif method == 'max':
                    # BUGFIX: unlike torch, paddle.max/paddle.min return a single
                    # tensor (no indices), so the original tuple-unpacking
                    # `tensor, _ = paddle.max(...)` was wrong.
                    tensor = paddle.max(tensor, axis=0)
                elif method == 'min':
                    tensor = paddle.min(tensor, axis=0)
                else:
                    raise AggregateMethodError(should_have_aggregate_method=False)

        return tensor

    def create_tensor(self, value: float):
        """
        Create a one-element tensor holding ``value``.

        :param value: initial value of the tensor
        """
        tensor = paddle.ones((1,)).fill_(value)
        return tensor

    def fill_value(self, tensor, value: float):
        """
        Fill ``tensor`` in place with ``value`` and return it.

        :param tensor: the incoming tensor
        :param value: value to fill in
        """
        tensor.fill_(value)
        return tensor

    def get_scalar(self, tensor) -> float:
        """
        Return the python scalar stored in ``tensor``.

        :param tensor: the incoming tensor
        """
        return tensor.item()

    def tensor2numpy(self, tensor) -> np.array:
        """
        Convert ``tensor`` to a numpy value; used mainly inside metric
        computation. Plain ndarrays, floats and ints pass through unchanged.

        :param tensor: the incoming tensor
        """
        if isinstance(tensor, paddle.Tensor):
            return tensor.cpu().detach().numpy()
        # BUGFIX: np.array is a function, not a type — isinstance(x, np.array)
        # raised TypeError; the type to check against is np.ndarray.
        elif isinstance(tensor, np.ndarray):
            return tensor
        elif isinstance(tensor, (float, int)):
            return tensor
        else:
            raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

    @staticmethod
    def is_distributed() -> bool:
        """
        Whether the process is running under paddle ddp.

        :return:
        """
        return is_in_paddle_dist()

    def move_tensor_to_device(self, tensor, device):
        """
        Move ``tensor`` to ``device``; the device string is first normalized to
        a paddle data device.

        :param tensor: tensor to move
        :param device: device name such as ``"cpu"`` or ``"cuda:0"``
        """
        device = _convert_data_device(device)
        return paddle_to(tensor, device)

    def all_gather_object(self, obj, group=None) -> List:
        """
        Gather ``obj`` from every rank; returns a list ordered by rank. Outside
        of a distributed run the list contains only the local object.

        :param obj:
        :param group:
        """
        if self.is_distributed():
            obj_list = fastnlp_paddle_all_gather(obj, group=group)
            return obj_list
        return [obj]
diff --git a/fastNLP/core/metrics/backend/torch_backend/__init__.py b/fastNLP/core/metrics/backend/torch_backend/__init__.py
new file mode 100644
index 00000000..0077d97c
--- /dev/null
+++ b/fastNLP/core/metrics/backend/torch_backend/__init__.py
@@ -0,0 +1,6 @@
+__all__ = [
+ 'TorchBackend'
+]
+
+
+from .backend import TorchBackend
diff --git a/fastNLP/core/metrics/backend/torch_backend/backend.py b/fastNLP/core/metrics/backend/torch_backend/backend.py
new file mode 100644
index 00000000..5a73a4a3
--- /dev/null
+++ b/fastNLP/core/metrics/backend/torch_backend/backend.py
@@ -0,0 +1,128 @@
+from typing import Any, List, Optional
+
+import numpy as np
+
+from fastNLP.core.metrics.backend import Backend
+from fastNLP.core.metrics.utils import AggregateMethodError
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather
+
+
+if _NEED_IMPORT_TORCH:
+ import torch
+ import torch.distributed as dist
+
+__all__ = []
+
class TorchBackend(Backend):
    """
    Backend implementation for pytorch.
    """

    def __init__(self):
        super().__init__()
        self._specified = True

    def aggregate(self, tensor, method: str):
        """
        In a distributed run, gather ``tensor`` from every rank, stack the
        results and reduce them along dim 0 with ``method``; otherwise (or for
        non-tensor input) return the value unchanged.

        :param tensor: tensor to aggregate
        :param method: one of ``['sum', 'mean', 'max', 'min']``, applied over
            the stacked per-rank results along dim ``0``
        """
        if isinstance(tensor, torch.Tensor):
            if dist.is_initialized():
                if method is None:
                    raise AggregateMethodError(should_have_aggregate_method=True)
                gathered = self.all_gather_object(tensor)
                if isinstance(gathered[0], torch.Tensor):
                    gathered = torch.stack(gathered)
                tensor = gathered
                # reduce over the rank dimension
                if method in ('sum', 'mean'):
                    tensor = getattr(torch, method)(tensor, dim=0)
                elif method in ('max', 'min'):
                    # torch.max/min with dim return (values, indices); keep the values
                    tensor = getattr(torch, method)(tensor, dim=0)[0]
                else:
                    raise AggregateMethodError(should_have_aggregate_method=False)

        return tensor

    def create_tensor(self, value: float):
        """
        Create a one-element tensor holding ``value``.

        :param value: initial value of the tensor
        """
        return torch.ones(1).fill_(value)

    def fill_value(self, tensor, value: float):
        """
        Fill ``tensor`` in place with ``value`` and return it.

        :param tensor: the incoming tensor
        :param value: value to fill in
        """
        tensor.fill_(value)
        return tensor

    def get_scalar(self, tensor) -> float:
        """
        Return the python scalar stored in ``tensor``.

        :param tensor: the incoming tensor
        """
        return tensor.item()

    def tensor2numpy(self, tensor) -> np.array:
        """
        Convert ``tensor`` to a numpy value; used mainly inside metric
        computation. Plain ndarrays, floats and ints pass through unchanged.

        :param tensor: the incoming tensor
        """
        if isinstance(tensor, torch.Tensor):
            return tensor.cpu().detach().numpy()
        if isinstance(tensor, (np.ndarray, float, int)):
            return tensor
        raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

    @staticmethod
    def is_distributed() -> bool:
        """
        Whether the process is running under torch ddp.

        :return:
        """
        return dist.is_available() and dist.is_initialized()

    def move_tensor_to_device(self, tensor, device):
        """
        Move ``tensor`` to ``device``.

        :param tensor: tensor to move
        :param device: device name such as ``"cpu"`` or ``"cuda:0"``
        """
        return tensor.to(device)

    def all_gather_object(self, obj, group=None) -> List:
        """
        Gather ``obj`` from every rank; returns a list ordered by rank. Outside
        of a distributed run the list contains only the local object.

        :param obj:
        :param group:
        """
        if not self.is_distributed():
            return [obj]
        return fastnlp_torch_all_gather(obj, group=group)
+
diff --git a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py
new file mode 100644
index 00000000..711729be
--- /dev/null
+++ b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py
@@ -0,0 +1,177 @@
+__all__ = [
+ 'ClassifyFPreRecMetric'
+]
+
+from typing import Union, List
+from collections import Counter
+import numpy as np
+
+from .metric import Metric
+from .backend import Backend
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
+from .utils import _compute_f_pre_rec
+from fastNLP.core.log import logger
+
class ClassifyFPreRecMetric(Metric):
    """
    Metric computing the **F-score**, precision and recall of a classifier.

    :param tag_vocab: the label :class:`~fastNLP.core.Vocabulary`. Defaults to
        ``None``, in which case raw label indices are used in the result keys;
        otherwise the vocabulary supplies the label names.
    :param ignore_labels: a :class:`list` of :class:`str` labels excluded from
        the computation, e.g. ``['NN']`` skips the 'NN' label in POS tagging.
    :param only_gross: if ``True`` only the overall ``f1``, ``precision`` and
        ``recall`` are returned; if ``False`` the per-label ``f1``, ``pre`` and
        ``rec`` values are reported as well.
    :param f_type: ``'micro'`` or ``'macro'``:

        * ``'micro'``: pool the total TP, FN and FP counts first, then compute f, precision and recall;
        * ``'macro'``: compute f, precision and recall per label, then average (equal label weights).

    :param beta: the ``beta`` of the **f_beta** score, commonly ``0.5``, ``1``
        or ``2``. ``0.5`` weights **precision** higher than **recall**, ``1``
        weights both equally, ``2`` weights **recall** higher. The score is:

        .. math::

            f_{beta} = \\frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}

    :param backend: one of ``['torch', 'paddle', 'jittor', 'oneflow', 'auto']``;
        ``'auto'`` infers the concrete backend from the arguments actually passed
        to :meth:`update`, which is usually what you want.
    :param aggregate_when_get_metric: whether to aggregate the counters of all
        processes before computing the metric; meaningless when the backend has
        no distributed support. ``None`` lets
        :class:`~fastNLP.core.controllers.Evaluator` decide based on whether the
        sampler is distributed.
    """
    def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None,
                 only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto',
                 aggregate_when_get_metric: bool = None) -> None:
        super(ClassifyFPreRecMetric, self).__init__(backend=backend,
                                                    aggregate_when_get_metric=aggregate_when_get_metric)
        if f_type not in ('micro', 'macro'):
            raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
        if tag_vocab:
            if not isinstance(tag_vocab, Vocabulary):
                raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
        # NOTE(review): ignore_labels is stored but never applied inside this
        # class — confirm whether filtering was intended in get_metric().
        self.ignore_labels = ignore_labels
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.only_gross = only_gross

        self.tag_vocab = tag_vocab

        # plain Counters (not Elements), so reset() must clear them manually
        self._tp = Counter()
        self._fp = Counter()
        self._fn = Counter()

    def reset(self):
        """
        Reset the ``tp``, ``fp`` and ``fn`` counters.
        """
        # these are not Elements, so they must be cleared by hand
        self._tp.clear()
        self._fp.clear()
        self._fn.clear()

    def get_metric(self) -> dict:
        r"""
        Compute the final evaluation result from the statistics accumulated by
        :meth:`update`.

        :return: a dict with keys ``'f'``, ``'pre'`` and ``'rec'``; when
            ``only_gross`` is ``False``, per-label entries ``'f-{label}'``,
            ``'pre-{label}'`` and ``'rec-{label}'`` are included as well.
        """
        evaluate_result = {}

        # collect the counters of every rank via all_gather_object and sum them
        ls = self.all_gather_object([self._tp, self._fp, self._fn])
        tps, fps, fns = zip(*ls)
        _tp, _fp, _fn = Counter(), Counter(), Counter()
        for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
            for _c in cs:
                c.update(_c)

        if not self.only_gross or self.f_type == 'macro':
            tags = set(_fn.keys())
            tags.update(set(_fp.keys()))
            tags.update(set(_tp.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                if self.tag_vocab is not None:
                    tag_name = self.tag_vocab.to_word(tag)
                else:
                    tag_name = int(tag)
                tp = _tp[tag]
                fn = _fn[tag]
                fp = _fp[tag]
                if tp == fn == fp == 0:
                    continue
                f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                if not self.only_gross and tag != '':  # tag != '' guards against the no-tag case
                    f_key = 'f-{}'.format(tag_name)
                    pre_key = 'pre-{}'.format(tag_name)
                    rec_key = 'rec-{}'.format(tag_name)
                    evaluate_result[f_key] = f
                    evaluate_result[pre_key] = pre
                    evaluate_result[rec_key] = rec

            if self.f_type == 'macro':
                evaluate_result['f'] = f_sum / len(tags)
                evaluate_result['pre'] = pre_sum / len(tags)
                evaluate_result['rec'] = rec_sum / len(tags)

        if self.f_type == 'micro':
            f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec

        for key, value in evaluate_result.items():
            evaluate_result[key] = round(value, 6)

        return evaluate_result

    def update(self, pred, target, seq_len=None):
        r"""
        Accumulate the evaluation statistics for one batch of predictions.

        :param pred: prediction tensor of shape ``[B,]``, ``[B, n_classes]``,
            ``[B, max_len]`` or ``[B, max_len, n_classes]``
        :param target: ground-truth tensor of shape ``[B,]`` or ``[B, max_len]``
        :param seq_len: sequence lengths, either ``None`` or shape ``[B]``
        """
        pred = self.tensor2numpy(pred)
        target = self.tensor2numpy(target)
        if seq_len is not None:
            seq_len = self.tensor2numpy(seq_len)

        if seq_len is not None and target.ndim > 1:
            max_len = target.shape[-1]
            masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
        else:
            masks = np.ones_like(target)

        if pred.ndim == target.ndim:
            if len(pred.flatten()) != len(target.flatten()):
                # BUGFIX: the original message swapped the pred and target element counts
                raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers."
                                   f" while target have element numbers:{len(target.flatten())}, "
                                   f"pred have element numbers: {len(pred.flatten())}")

        elif pred.ndim == target.ndim + 1:
            # class dimension present: reduce it with argmax
            pred = pred.argmax(axis=-1)
            if seq_len is None and target.ndim > 1:
                logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.")
        else:
            raise RuntimeError(f"when pred have "
                               f"size:{pred.shape}, target should have size: {pred.shape} or "
                               f"{pred.shape[:-1]}, got {target.shape}.")

        # accumulate per-class tp/fp/fn, masking out padded positions
        target_idxes = set(target.reshape(-1).tolist() + pred.reshape(-1).tolist())
        for target_idx in target_idxes:
            self._tp[target_idx] += ((pred == target_idx) * (target == target_idx) * masks).sum().item()
            self._fp[target_idx] += ((pred == target_idx) * (target != target_idx) * masks).sum().item()
            self._fn[target_idx] += ((pred != target_idx) * (target == target_idx) * masks).sum().item()
diff --git a/fastNLP/core/metrics/element.py b/fastNLP/core/metrics/element.py
new file mode 100644
index 00000000..359df3b8
--- /dev/null
+++ b/fastNLP/core/metrics/element.py
@@ -0,0 +1,315 @@
+__all__ = [
+ 'Element'
+]
+
+import os
+import functools
+
+from .backend import Backend, AutoBackend
+from fastNLP.core.log import logger
+from .utils import AggregateMethodError
+from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
+
+
+def _wrap_cal_value(func):
+ @functools.wraps(func)
+ def _wrap_cal(*args, **kwargs):
+ self = args[0]
+ value = func(*args, **kwargs)
+ value = self.backend.get_scalar(value)
+ return value
+
+ return _wrap_cal
+
+
class Element:
    """
    Container for a single value tracked by a :class:`~fastNLP.core.metrics.Metric`.

    :param name: name of the element
    :param value: initial value of the element (also restored by :meth:`reset`)
    :param aggregate_method: how values gathered from multiple ranks are
        combined, one of ``['sum', 'mean', 'max', 'min']`` applied along dim 0
        of the stacked per-rank results; ``None`` disables aggregation
    :param backend: backend used to materialize the value. ``'torch'`` yields a
        :class:`torch.Tensor`, ``'paddle'`` a :class:`paddle.Tensor`,
        ``'jittor'`` a :class:`jittor.Var`. When the backend is an undetermined
        :class:`AutoBackend`, creation of the value is deferred until the
        backend is inferred from the tensors passed to :meth:`Metric.update`;
        if no tensor type is detected a plain :class:`float` is used.
    """
    def __init__(self, name, value: float, aggregate_method, backend: Backend):
        self.name = name
        self.init_value = value
        self.aggregate_method = aggregate_method
        if backend == 'auto':
            raise RuntimeError(f"You have to specify the backend for Element:{self.name}.")
        elif isinstance(backend, AutoBackend):
            self.backend = backend
        else:
            self.backend = AutoBackend(backend)

        if self.backend.is_specified():
            value = self.backend.create_tensor(self.init_value)
        else:
            # backend not decided yet; the tensor is created lazily in
            # _check_value_initialized()
            value = None
        self._value = value
        self.device = None

    def aggregate(self):
        """
        Aggregate this element across ranks using ``self.aggregate_method``.
        An ``AggregateMethodError`` with ``only_warn`` set is downgraded to a
        warning (rank 0 only) and the value is re-aggregated without a method;
        otherwise it is re-raised as a ``RuntimeError`` with a precise message.
        """
        self._check_value_initialized()
        if self.aggregate_method is None:  # no aggregation requested
            return
        try:
            self._value = self.backend.aggregate(self._value, self.aggregate_method)
        except AggregateMethodError as e:
            msg = 'If you see this message, please report a bug.'
            if self.name and e.should_have_aggregate_method:
                msg = f"Element:{self.name} has no specified `aggregate_method`."
            elif self.name and not e.should_have_aggregate_method:
                msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \
                      f'aggregate_method:{self.aggregate_method}.'
            if e.only_warn:
                if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
                    logger.warning(msg)
                self._value = self.backend.aggregate(self._value, method=None)
            else:
                raise RuntimeError(msg)

    def reset(self):
        """
        Reset the value back to ``init_value``.
        """
        if self.backend.is_specified() and self._value is not None:
            self._value = self.backend.fill_value(self._value, self.init_value)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, value):
        self._check_value_initialized()
        self._value = value

    @value.getter
    def value(self):
        self._check_value_initialized()
        return self._value

    def get_scalar(self) -> float:
        """
        Return the scalar value of this element via the backend.
        """
        self._check_value_initialized()
        return self.backend.get_scalar(self._value)

    def fill_value(self, value):
        """
        Fill this element with ``value`` using the backend's :meth:`fill_value`.
        """
        self._check_value_initialized()
        self._value = self.backend.fill_value(self._value, value)

    def to(self, device):
        """
        Move the element to ``device``.

        :param device: device name such as ``"cpu"`` or ``"cuda:0"``
        """
        # remember the device so lazily-created values land there too
        if self._value is not None:
            self._value = self.backend.move_tensor_to_device(self._value, device)
        self.device = device

    def _check_value_initialized(self):
        """
        Lazily create the value once the backend has become known; move it to
        the remembered device if one was requested before creation.
        """
        if self._value is None:
            assert self.backend.is_specified(), f"Backend is not specified, please specify backend in the Metric " \
                                                f"initialization."
            self._value = self.backend.create_tensor(self.init_value)
            if self.device is not None:
                self.to(device=self.device)

    def _check_value_when_call(self):
        # guard used by the arithmetic operators below
        if self.value is None:
            prefix = f'Element:`{self.name}`'
            raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
                                        "element, or use it after it being used by the `Metric.update()` method.")

    @_wrap_cal_value
    def __add__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value + other

    @_wrap_cal_value
    def __radd__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        # addition is commutative, so operand order does not matter
        return self.value + other

    @_wrap_cal_value
    def __sub__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value - other

    @_wrap_cal_value
    def __rsub__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        # BUGFIX: reflected subtraction computes `other - self`; the original
        # returned `self.value - other`, giving the negated result.
        return other - self.value

    @_wrap_cal_value
    def __mul__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value * other

    @_wrap_cal_value
    def __imul__(self, other):
        # NOTE(review): this returns a scalar (via the wrapper) instead of self,
        # so `e *= x` rebinds `e` to a plain number — confirm this is intended.
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value * other

    @_wrap_cal_value
    def __floordiv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value // other

    @_wrap_cal_value
    def __rfloordiv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        # BUGFIX: reflected floordiv computes `other // self`
        return other // self.value

    @_wrap_cal_value
    def __truediv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value / other

    @_wrap_cal_value
    def __rtruediv__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        # BUGFIX: reflected truediv computes `other / self`
        return other / self.value

    @_wrap_cal_value
    def __mod__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value % other

    @_wrap_cal_value
    def __rmod__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        # BUGFIX: reflected mod computes `other % self`
        return other % self.value

    @_wrap_cal_value
    def __pow__(self, other, modulo=None):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        if modulo is None:
            return self.value ** other
        else:
            return pow(self.value, other, modulo)

    @_wrap_cal_value
    def __rpow__(self, other):
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        # BUGFIX: reflected pow computes `other ** self`
        return other ** self.value

    @_wrap_cal_value
    def __lt__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value < other

    @_wrap_cal_value
    def __le__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value <= other

    @_wrap_cal_value
    def __eq__(self, other):
        # defining __eq__ implicitly sets __hash__ to None: Elements are unhashable
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value == other

    @_wrap_cal_value
    def __ne__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value != other

    @_wrap_cal_value
    def __ge__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value >= other

    @_wrap_cal_value
    def __gt__(self, other) -> bool:
        self._check_value_when_call()
        if isinstance(other, Element):
            other = other.value
        return self.value > other

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return str(self.value)

    def __getattr__(self, item):
        """
        Delegate unknown attribute access to the underlying value, so an
        Element can be used like the backend tensor it wraps.

        :param item: attribute name
        :return: the attribute of the wrapped value
        """
        try:
            if self._value is None:
                prefix = f'Element:`{self.name}`'
                raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
                                            "element, or use it after it being used by the `Metric.update()` method.")
            return getattr(self._value, item)
        except AttributeError as e:
            logger.error(f"Element:{self.name} has no `{item}` attribute.")
            raise e
diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py
new file mode 100644
index 00000000..c3d23bef
--- /dev/null
+++ b/fastNLP/core/metrics/metric.py
@@ -0,0 +1,211 @@
+__all__ = [
+ 'Metric'
+]
+
+from abc import abstractmethod
+
+from typing import Union, List
+import functools
+from contextlib import contextmanager
+import numpy as np
+
+from fastNLP.core.metrics.backend import Backend, AutoBackend
+from fastNLP.core.metrics.element import Element
+
+
+class Metric:
+ """
+ **fastNLP** 中 :class:`Metric` 的基类,自定义 :class:`Metric` 时,请继承该对象。使用该对象,将有助于减少在分布式状态下的 Metric 计算。
+
+ :param backend: 目前支持五种类型的 backend, ``['torch', 'paddle', 'jittor', 'oneflow', 'auto']``。其中 ``'auto'`` 表示根据实际调用 :meth:`update`
+ 函数时传入的参数决定具体的 backend ,大部分情况下直接使用 ``'auto'`` 即可。
+ :param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric,
+ 当 backend 不支持分布式时,该参数无意义。如果为 ``None`` ,将在 :class:`~fastNLP.core.controllers.Evaluator` 中根据
+ sampler 是否使用分布式进行自动设置。
+ """
+ def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None):
+ self.backend = AutoBackend(backend)
+ self._updated = False
+ self.get_metric = self._sync_get_metric(self.get_metric)
+ self.update = self._wrap_update(self.update)
+ self.reset = self._wrap_auto_reset_elements(self.reset)
+ self.aggregate_when_get_metric = aggregate_when_get_metric
+ self._cannot_change_element = False
+ self._elements = {}
+
+ @property
+ def elements(self) -> dict:
+ return self._elements
+
+ def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element:
+ """
+ 注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 ``self.{name}`` 进行调用,可以认为该对象即为对应 backend 的
+ tensor 直接进行加减乘除计算即可。
+
+ .. warning::
+
+ 如果想使得该 metric 可自动扩展到多卡的情况,请一定申明 ``aggregate_method`` 。
+
+ :param name: 当前 element 的名字,注册后,在 Metric 中可以通过 ``self.{name}`` 访问该变量。
+ :param value: 初始化的值。在调用 :meth:`Metric.reset` 方法时也将自动设置为该值
+ :param aggregate_method: 如何聚合多卡上的结果,如果为单卡执行,该值无意义。如果设置为 None 则表示该 element 不进行聚合。
+ :param backend: 使用的 backend 。Element 的类型会根据 ``backend`` 进行实际的初始化。例如 ``backend`` 为 ``'torch'`` 则该对象为
+ :class:`torch.Tensor` ; 如果 ``'backend'`` 为 ``'paddle'`` 则该对象为 :class:`paddle.Tensor` ;如果 ``backend`` 为
+ ``'jittor'`` , 则该对象为 :class:`jittor.Var` 。一般情况下直接默认为 ``'auto'`` 就行了, **fastNLP** 会根据实际调用 :meth:`Metric.update`
+ 函数时传入的参数进行合理的初始化,例如当传入的参数中只包含 :class:`torch.Tensor` 这一种 tensor 时(可以有其它非 tensor 类型的输入)
+ 则认为 ``backend`` 为 ``'torch'`` ;只包含 :class:`jittor.Var` 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 ``backend``
+ 为 ``'jittor'`` 。如果没有检测到任何一种 tensor ,就默认使用 :class:`float` 类型作为 element 。
+ :return: 注册的 Element 对象
+ """
+ if backend == 'auto':
+ backend = self.backend
+ else:
+ backend = AutoBackend(backend)
+
+ assert name is not None and name not in self.elements
+
+ element = Element(name=name, value=value, aggregate_method=aggregate_method, backend=backend)
+ self.elements[name] = element
+ setattr(self, name, element)
+ return element
+
+ def reset(self):
+ """
+ 在对每个 ``evaluate_dataloaders`` 遍历进行验证之前,:meth:`reset` 函数会被调用来重置每个非 element 对象;
+ 如果有非 element 的对象需要重置的时候,在本方法中写下非 element 的重置方式。注册的 element 对象则会自动 reset 为初始值。
+ """
+ pass
+
+ def _wrap_auto_reset_elements(self, reset):
+ @functools.wraps(reset)
+ def _wrap_reset(*args, **kwargs):
+ self._updated = False
+ for ele in self.elements.values():
+ ele.reset()
+ reset(*args, **kwargs)
+
+ return _wrap_reset
+
+ def _sync_get_metric(self, get_metric):
+ @functools.wraps(get_metric)
+ def _wrap_get_metric(*args, **kwargs):
+ assert self._updated, f"You have to call `{self.__class__.__name__}'s update() function before calling " \
+ f"get_metric()."
+ with self.sync(recover=True, aggregate=self.aggregate_when_get_metric):
+ results = get_metric(*args, **kwargs)
+ return results
+
+ return _wrap_get_metric
+
+ def __setattr__(self, key, value):
+ if getattr(self, '_cannot_change_element', False):
+ if key in self.elements and isinstance(value, (float, int, bool)):
+ self.elements[key].fill_value(value)
+ return
+ elif key in self.elements:
+ raise TypeError(f"self.{key} is an Element, only float/int/bool type value can be assigned to it, "
+ f"instead of {type(value)}.")
+ if isinstance(value, Element) and key not in self.elements:
+ raise RuntimeError("Please use register_element() function to add Element.")
+ attrs = self.__dict__
+ if key in attrs and isinstance(value, Element):
+ raise RuntimeError(f'`{key}` has been registered as an attribute, cannot be registered as an Element!')
+ object.__setattr__(self, key, value)
+
+ # 当调用 __getattribute__ 没有找到时才会触发这个, 保留这个的目的只是为了防止 ide 的 warning
+ def __getattr__(self, name: str) -> Element:
+ if 'elements' in self.__dict__:
+ elements = self.__dict__['elements']
+ if name in elements:
+ return elements[name]
+ raise AttributeError("`{}` object has no attribute `{}`.".format(type(self).__name__, name))
+
+ def _wrap_update(self, update):
+ @functools.wraps(update)
+ def _wrap_update(*args, **kwargs):
+ self.check_backend(*args, **kwargs)
+ self._cannot_change_element = True
+ self._updated = True
+ return update(*args, **kwargs)
+
+ return _wrap_update
+
+ def check_backend(self, *args, **kwargs):
+ """
+ 根据传入的参数的类型选择当前需要的 backend
+ """
+ if not self.backend.is_specified():
+ _args = []
+ for arg in args:
+ _args.append(arg)
+ for arg in kwargs.values():
+ _args.append(arg)
+ self.backend.choose_real_backend(_args)
+
+ @contextmanager
+ def sync(self, recover=True, aggregate=False):
+ """
+ 在这个上下文下, :meth:`Metric` 会自动先同步需要同步操作的 element 。当 ``recover`` 为 ``True`` 时,在退出环境的时候,会重新将 element 的
+ 值恢复到计算前的值。
+ """
+ keep_value = {}
+ if aggregate:
+ for name, element in self.elements.items():
+ # 保存过去的值
+ keep_value[name] = element.get_scalar()
+ # 聚合结果
+ element.aggregate()
+
+ yield
+
+ if recover and aggregate:
+ for name, element in self.elements.items():
+ # 恢复结果
+ if name in keep_value:
+ element.fill_value(value=keep_value.get(name))
+
+ @abstractmethod
+ def update(self, *args, **kwargs):
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_metric(self) -> dict:
+ raise NotImplementedError()
+
+ def set_auto_aggregate_when_get_metric(self, flag: bool):
+ """
+ 设置是否在 :meth:`get_metric` 的时候自动 aggregate
+
+ """
+ self.aggregate_when_get_metric = flag
+
+ def tensor2numpy(self, tensor) -> np.array:
+ """
+ 将 ``tensor`` 向量转为 :class:`numpy.array` 类型变量。
+
+ :param tensor:
+ :return:
+ """
+ return self.backend.tensor2numpy(tensor)
+
+ def to(self, device):
+ """
+ 将所有的 element 变量移动到 ``device`` 设备上
+
+ :param device:
+ :return:
+ """
+ for element in self.elements.values():
+ element.to(device)
+
+ def all_gather_object(self, obj, group=None)->List:
+ """
+ 给定 ``obj`` 将各个 rank 上的 ``obj`` 汇总到每个 ``obj`` 上。返回一个 list 对象,里面依次为各个 rank 对应的 ``obj`` 。
+
+ :param obj: 需要汇总的对象,必须是个 pickable 的对象。
+ :param group:
+ :return: -> List[obj0, obj1, ...] 其中 obj0 是rank 0 上的 obj;obj1 是 rank 1 上的 obj...
+ """
+ if self.aggregate_when_get_metric:
+ return self.backend.all_gather_object(obj, group=group)
+ return [obj]
\ No newline at end of file
diff --git a/fastNLP/core/metrics/span_f1_pre_rec_metric.py b/fastNLP/core/metrics/span_f1_pre_rec_metric.py
new file mode 100644
index 00000000..a3c5a722
--- /dev/null
+++ b/fastNLP/core/metrics/span_f1_pre_rec_metric.py
@@ -0,0 +1,376 @@
+__all__ = [
+ 'SpanFPreRecMetric'
+]
+
+from typing import Union, List, Optional
+from collections import Counter
+
+from fastNLP.core.metrics.backend import Backend
+from fastNLP.core.metrics.metric import Metric
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.log import logger
+from .utils import _compute_f_pre_rec
+
+
+def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):
+ r"""
+ 检查vocab中的tag是否与encoding_type是匹配的
+
+ :param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。
+ :param encoding_type: bio, bmes, bioes, bmeso
+ :return:
+ """
+ tag_set = set()
+ unk_token = ''
+ pad_token = ''
+ if isinstance(tag_vocab, Vocabulary):
+ unk_token = tag_vocab.unknown
+ pad_token = tag_vocab.padding
+ tag_vocab = tag_vocab.idx2word
+ for idx, tag in tag_vocab.items():
+ if tag in (unk_token, pad_token):
+ continue
+ tag = tag[:1].lower()
+ tag_set.add(tag)
+
+ tags = encoding_type
+ for tag in tag_set:
+ assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \
+ f"encoding_type."
+ tags = tags.replace(tag, '') # 删除该值
+ if tags: # 如果不为空,说明出现了未使用的tag
+ logger.warning(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
+ "encoding_type.")
+
+
+def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str:
+ r"""
+ 给定 Vocabulary 自动判断是哪种类型的 encoding, 支持判断 bmes, bioes, bmeso, bio
+
+ :param tag_vocab: 支持传入 tag Vocabulary; 或者传入形如 {0:"O", 1:"B-tag1"},即 index 在前,tag 在后的 dict。
+ :return:
+ """
+ tag_set = set()
+ unk_token = ''
+ pad_token = ''
+ if isinstance(tag_vocab, Vocabulary):
+ unk_token = tag_vocab.unknown
+ pad_token = tag_vocab.padding
+ tag_vocab = tag_vocab.idx2word
+ for idx, tag in tag_vocab.items():
+ if tag in (unk_token, pad_token):
+ continue
+ tag = tag[:1].lower()
+ tag_set.add(tag)
+
+ bmes_tag_set = set('bmes')
+ if tag_set == bmes_tag_set:
+ return 'bmes'
+ bio_tag_set = set('bio')
+ if tag_set == bio_tag_set:
+ return 'bio'
+ bmeso_tag_set = set('bmeso')
+ if tag_set == bmeso_tag_set:
+ return 'bmeso'
+ bioes_tag_set = set('bioes')
+ if tag_set == bioes_tag_set:
+ return 'bioes'
+ raise RuntimeError("encoding_type cannot be inferred automatically. Only support "
+ "'bio', 'bmes', 'bmeso', 'bioes' type.")
+
+
+def _bmes_tag_to_spans(tags, ignore_labels=None):
+ r"""
+ 给定一个 tags 的 list,比如 ['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。
+ 返回 [('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间)
+ 也可以是单纯的 ['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列
+
+ :param tags: List[str],
+ :param ignore_labels: List[str], 在该list中的label将被忽略
+ :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
+ """
+ ignore_labels = set(ignore_labels) if ignore_labels else set()
+
+ spans = []
+ prev_bmes_tag = None
+ for idx, tag in enumerate(tags):
+ tag = tag.lower()
+ bmes_tag, label = tag[:1], tag[2:]
+ if bmes_tag in ('b', 's'):
+ spans.append((label, [idx, idx]))
+ elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
+ spans[-1][1][1] = idx
+ else:
+ spans.append((label, [idx, idx]))
+ prev_bmes_tag = bmes_tag
+ return [(span[0], (span[1][0], span[1][1] + 1))
+ for span in spans
+ if span[0] not in ignore_labels
+ ]
+
+
+def _bmeso_tag_to_spans(tags, ignore_labels=None):
+ r"""
+ 给定一个 tags 的 list,比如 ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。
+ 返回 [('singer', (1, 4))] (左闭右开区间)
+
+ :param tags: List[str],
+ :param ignore_labels: List[str], 在该list中的label将被忽略
+ :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
+ """
+ ignore_labels = set(ignore_labels) if ignore_labels else set()
+
+ spans = []
+ prev_bmes_tag = None
+ for idx, tag in enumerate(tags):
+ tag = tag.lower()
+ bmes_tag, label = tag[:1], tag[2:]
+ if bmes_tag in ('b', 's'):
+ spans.append((label, [idx, idx]))
+ elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
+ spans[-1][1][1] = idx
+ elif bmes_tag == 'o':
+ pass
+ else:
+ spans.append((label, [idx, idx]))
+ prev_bmes_tag = bmes_tag
+ return [(span[0], (span[1][0], span[1][1] + 1))
+ for span in spans
+ if span[0] not in ignore_labels
+ ]
+
+
+def _bioes_tag_to_spans(tags, ignore_labels=None):
+ r"""
+ 给定一个 tags 的 list,比如 ['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。
+ 返回 [('singer', (1, 4))] (左闭右开区间)
+
+ :param tags: List[str],
+ :param ignore_labels: List[str], 在该list中的label将被忽略
+ :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
+ """
+ ignore_labels = set(ignore_labels) if ignore_labels else set()
+
+ spans = []
+ prev_bioes_tag = None
+ for idx, tag in enumerate(tags):
+ tag = tag.lower()
+ bioes_tag, label = tag[:1], tag[2:]
+ if bioes_tag in ('b', 's'):
+ spans.append((label, [idx, idx]))
+ elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
+ spans[-1][1][1] = idx
+ elif bioes_tag == 'o':
+ pass
+ else:
+ spans.append((label, [idx, idx]))
+ prev_bioes_tag = bioes_tag
+ return [(span[0], (span[1][0], span[1][1] + 1))
+ for span in spans
+ if span[0] not in ignore_labels
+ ]
+
+
+def _bio_tag_to_spans(tags, ignore_labels=None):
+ r"""
+ 给定一个 tags 的 list,比如 ['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。
+ 返回 [('singer', (1, 4))] (左闭右开区间)
+
+ :param tags: List[str],
+ :param ignore_labels: List[str], 在该list中的label将被忽略
+ :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
+ """
+ ignore_labels = set(ignore_labels) if ignore_labels else set()
+
+ spans = []
+ prev_bio_tag = None
+ for idx, tag in enumerate(tags):
+ tag = tag.lower()
+ bio_tag, label = tag[:1], tag[2:]
+ if bio_tag == 'b':
+ spans.append((label, [idx, idx]))
+ elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]:
+ spans[-1][1][1] = idx
+ elif bio_tag == 'o': # o tag does not count
+ pass
+ else:
+ spans.append((label, [idx, idx]))
+ prev_bio_tag = bio_tag
+ return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]
+
+
+class SpanFPreRecMetric(Metric):
+ r"""
+ 在 **序列标注** 任务中评估抽取结果匹配度的 **Metric** 。
+
+ :param tag_vocab: 标签的 :class:`~fastNLP.core.Vocabulary` 。支持的标签有 ``"B"`` (没有label);或 ``"B-xxx"`` ( ``xxx`` 为某种 label ,比如 POS 中的 NN),
+ 在解码时,会将相同 ``xxx`` 的认为是同一个 label ,比如 ['B-NN', 'E-NN'] 会被合并为一个 'NN' 。
+ :param encoding_type: 目前支持 ``['bio', 'bmes', 'bmeso', 'bioes', None]`` 。默认为 ``None`` ,通过 ``tag_vocab`` 自动判断
+ :param ignore_labels: 字符串组成的列表,这个列表中包含的内容不会被用于计算。例如在 *POS tagging* 时传入 ``['NN']`` ,则不会计算 ``'NN'`` 这个 ``label``
+ :param only_gross: 是否只计算总的 ``f1``, ``precision`` , ``recall`` 的值。如果为 ``False`` ,不仅返回总的 ``f1`` , ``pre`` , ``rec`` , 还会返回每个 label
+ ``f1`` , ``pre`` , ``rec``
+ :param f_type: ``'micro'`` 或 ``'macro'``。
+
+ - *micro* -- 通过先计算总体的 ``TP``, ``FN`` 和 ``FP`` 的数量,再计算 ``f``, ``precision``, ``recall`` ;
+ - *macro* -- 分别计算每个类别的 ``f`` , ``precision`` , ``recall`` ,然后做平均(各类别 ``f`` 的权重相同);
+
+ :param beta: **f_beta** 分数中的 ``beta`` 值。 常用为 ``beta=0.5, 1, 2`` 若为 0.5 则 **精确率** 的权重高于 **召回率** ;若为1,则两者平等;若为2,则
+ **召回率** 权重高于 **精确率** 。**f_beta** 分数的计算公式为:
+
+ .. math::
+
+ f_{beta} = \\frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}
+
+ :param backend: 目前支持五种类型的 backend, ``['torch', 'paddle', 'jittor', 'oneflow', 'auto']``。其中 ``'auto'`` 表示根据实际调用 :meth:`update`
+ 函数时传入的参数决定具体的 backend ,大部分情况下直接使用 ``'auto'`` 即可。
+ :param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric,
+ 当 backend 不支持分布式时,该参数无意义。如果为 ``None`` ,将在 :class:`~fastNLP.core.controllers.Evaluator` 中根据
+ sampler 是否使用分布式进行自动设置。
+ """
+ def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None,
+ only_gross: bool = True, f_type='micro',
+ beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None) -> None:
+ super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
+ if f_type not in ('micro', 'macro'):
+ raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
+ if not isinstance(tag_vocab, Vocabulary):
+ raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
+ if encoding_type:
+ encoding_type = encoding_type.lower()
+ _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
+ self.encoding_type = encoding_type
+ else:
+ self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab)
+
+ if self.encoding_type == 'bmes':
+ self.tag_to_span_func = _bmes_tag_to_spans
+ elif self.encoding_type == 'bio':
+ self.tag_to_span_func = _bio_tag_to_spans
+ elif self.encoding_type == 'bmeso':
+ self.tag_to_span_func = _bmeso_tag_to_spans
+ elif self.encoding_type == 'bioes':
+ self.tag_to_span_func = _bioes_tag_to_spans
+ else:
+ raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.")
+
+ self.ignore_labels = ignore_labels
+ self.f_type = f_type
+ self.beta = beta
+ self.beta_square = self.beta ** 2
+ self.only_gross = only_gross
+ self.tag_vocab = tag_vocab
+
+ self._tp = Counter()
+ self._fp = Counter()
+ self._fn = Counter()
+
+ def reset(self):
+ """
+ 重置所有元素
+ """
+ self._tp.clear()
+ self._fp.clear()
+ self._fn.clear()
+
+ def get_metric(self) -> dict:
+ """
+ :meth:`get_metric` 函数将根据 :meth:`update` 函数累计的评价指标统计量来计算最终的评价结果。
+ """
+ evaluate_result = {}
+
+ # 通过 all_gather_object 将各个卡上的结果收集过来,并加和。
+ ls = self.all_gather_object([self._tp, self._fp, self._fn])
+ tps, fps, fns = zip(*ls)
+ _tp, _fp, _fn = Counter(), Counter(), Counter()
+ for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
+ for _c in cs:
+ c.update(_c)
+
+ if not self.only_gross or self.f_type == 'macro':
+ tags = set(_fn.keys())
+ tags.update(_fp.keys())
+ tags.update(_tp.keys())
+ f_sum = 0
+ pre_sum = 0
+ rec_sum = 0
+ for tag in tags:
+ tp = _tp[tag]
+ fn = _fn[tag]
+ fp = _fp[tag]
+ if tp == fn == fp == 0:
+ continue
+
+ f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
+ f_sum += f
+ pre_sum += pre
+ rec_sum += rec
+ if not self.only_gross and tag != '': # tag!=''防止无tag的情况
+ f_key = 'f-{}'.format(tag)
+ pre_key = 'pre-{}'.format(tag)
+ rec_key = 'rec-{}'.format(tag)
+ evaluate_result[f_key] = f
+ evaluate_result[pre_key] = pre
+ evaluate_result[rec_key] = rec
+
+ if self.f_type == 'macro':
+ evaluate_result['f'] = f_sum / len(tags)
+ evaluate_result['pre'] = pre_sum / len(tags)
+ evaluate_result['rec'] = rec_sum / len(tags)
+
+ if self.f_type == 'micro':
+ f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values()))
+ evaluate_result['f'] = f
+ evaluate_result['pre'] = pre
+ evaluate_result['rec'] = rec
+
+ for key, value in evaluate_result.items():
+ evaluate_result[key] = round(value, 6)
+
+ return evaluate_result
+
+ def update(self, pred, target, seq_len) -> None:
+ r"""
+ :meth:`update` 函数将针对一个批次的预测结果做评价指标的累计。
+
+ :param pred: 预测的结果,大小为 ``[batch, seq_len]`` 或者 ``[batch, seq_len, len(tag_vocab)]``
+ :param target: 真实值,大小为 ``[batch, seq_len]``
+ :param seq_len: 文本长度标记,大小为 ``[batch]``
+ :return:
+ """
+ pred = self.tensor2numpy(pred)
+ target = self.tensor2numpy(target)
+
+ if pred.ndim == target.ndim and target.ndim == 2:
+ pass
+
+ elif pred.ndim == target.ndim + 1 and target.ndim == 2:
+ num_classes = pred.shape[-1]
+ pred = pred.argmax(axis=-1)
+ if (target >= num_classes).any():
+ raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
+ "id >= {}, the number of classes.".format(num_classes))
+ else:
+ raise RuntimeError(f"when pred have size:{pred.ndim}, target should have size: {pred.ndim} or "
+ f"{pred.shape[:-1]}, got {target.ndim}.")
+
+ batch_size = pred.shape[0]
+ pred = pred.tolist()
+ target = target.tolist()
+ for i in range(batch_size):
+ pred_tags = pred[i][:int(seq_len[i])]
+ gold_tags = target[i][:int(seq_len[i])]
+
+ pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
+ gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]
+
+ pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
+ gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)
+
+ for span in pred_spans:
+ if span in gold_spans:
+ self._tp[span[0]] += 1
+ gold_spans.remove(span)
+ else:
+ self._fp[span[0]] += 1
+ for span in gold_spans:
+ self._fn[span[0]] += 1
diff --git a/fastNLP/core/metrics/utils.py b/fastNLP/core/metrics/utils.py
new file mode 100644
index 00000000..6d3fd74a
--- /dev/null
+++ b/fastNLP/core/metrics/utils.py
@@ -0,0 +1,79 @@
+__all__ = [
+]
+
+from typing import Any
+from functools import wraps
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
+from fastNLP.envs.utils import _module_available
+
+_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics')
+_IS_ALLENNLP_AVAILABLE = _module_available('allennlp')
+if _IS_ALLENNLP_AVAILABLE:
+ from allennlp.training.metrics import Metric as allennlp_Metric
+
+if _IS_TORCHMETRICS_AVAILABLE:
+ from torchmetrics import Metric as torchmetrics_Metric
+
+if _NEED_IMPORT_PADDLE:
+ from paddle.metric import Metric as paddle_Metric
+
+
+def _is_torchmetrics_metric(metric: Any) -> bool:
+ """
+ 检查输入的对象是否为torchmetrics对象
+
+ :param metric:
+ :return:
+ """
+ if _IS_TORCHMETRICS_AVAILABLE:
+ return isinstance(metric, torchmetrics_Metric)
+ else:
+ return False
+
+
+def _is_allennlp_metric(metric: Any) -> bool:
+ """
+ 检查输入的对象是否为allennlp对象
+
+ :param metric:
+ :return:
+ """
+ if _IS_ALLENNLP_AVAILABLE:
+ return isinstance(metric, allennlp_Metric)
+ else:
+ return False
+
+
+def _is_paddle_metric(metric: Any) -> bool:
+ """
+ 检查输入的对象是否为 paddle 的 Metric 对象
+
+ :param metric:
+ :return:
+ """
+ if _NEED_IMPORT_PADDLE:
+ return isinstance(metric, paddle_Metric)
+ else:
+ return False
+
+
+class AggregateMethodError(BaseException):
+ def __init__(self, should_have_aggregate_method, only_warn=False):
+ super(AggregateMethodError, self).__init__(self)
+ self.should_have_aggregate_method = should_have_aggregate_method
+ self.only_warn = only_warn
+
+
+def _compute_f_pre_rec(beta_square, tp, fn, fp):
+ r"""
+
+ :param tp: int, true positive
+ :param fn: int, false negative
+ :param fp: int, false positive
+ :return: (f, pre, rec)
+ """
+ pre = tp / (fp + tp + 1e-13)
+ rec = tp / (fn + tp + 1e-13)
+ f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
+
+ return f, pre, rec
diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py
deleted file mode 100644
index 8c53176a..00000000
--- a/fastNLP/core/optimizer.py
+++ /dev/null
@@ -1,227 +0,0 @@
-r"""
-optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :class:`~fastNLP.Trainer` 的参数使用。
-
-"""
-__all__ = [
- "Optimizer",
- "SGD",
- "Adam",
- "AdamW"
-]
-
-import math
-
-import torch
-from torch.optim.optimizer import Optimizer as TorchOptimizer
-
-
-class Optimizer(object):
- r"""
- Optimizer
- """
-
- def __init__(self, model_params, **kwargs):
- r"""
-
- :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
- :param kwargs: additional parameters.
- """
- if model_params is not None and not hasattr(model_params, "__next__"):
- raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params)))
- self.model_params = model_params
- self.settings = kwargs
-
- def construct_from_pytorch(self, model_params):
- raise NotImplementedError
-
- @staticmethod
- def _get_require_grads_param(params):
- r"""
- 将params中不需要gradient的删除
-
- :param iterable params: parameters
- :return: list(nn.Parameters)
- """
- return [param for param in params if param.requires_grad]
-
-
-class NullOptimizer(Optimizer):
- r"""
- 当不希望Trainer更新optimizer时,传入本optimizer,但请确保通过callback的方式对参数进行了更新。
-
- """
- def __init__(self):
- super().__init__(None)
-
- def construct_from_pytorch(self, model_params):
- return self
-
- def __getattr__(self, item):
- def pass_func(*args, **kwargs):
- pass
-
- return pass_func
-
-
-class SGD(Optimizer):
- r"""
- SGD
- """
-
- def __init__(self, lr=0.001, momentum=0, model_params=None):
- r"""
- :param float lr: learning rate. Default: 0.01
- :param float momentum: momentum. Default: 0
- :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
- """
- if not isinstance(lr, float):
- raise TypeError("learning rate has to be float.")
- super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
-
- def construct_from_pytorch(self, model_params):
- if self.model_params is None:
- # careful! generator cannot be assigned.
- return torch.optim.SGD(self._get_require_grads_param(model_params), **self.settings)
- else:
- return torch.optim.SGD(self._get_require_grads_param(self.model_params), **self.settings)
-
-
-class Adam(Optimizer):
- r"""
- Adam
- """
-
- def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None):
- r"""
-
- :param float lr: learning rate
- :param float weight_decay:
- :param eps:
- :param amsgrad:
- :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
- """
- if not isinstance(lr, float):
- raise TypeError("learning rate has to be float.")
- super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad,
- weight_decay=weight_decay)
-
- def construct_from_pytorch(self, model_params):
- if self.model_params is None:
- # careful! generator cannot be assigned.
- return torch.optim.Adam(self._get_require_grads_param(model_params), **self.settings)
- else:
- return torch.optim.Adam(self._get_require_grads_param(self.model_params), **self.settings)
-
-
-class AdamW(TorchOptimizer):
- r"""
- 对AdamW的实现,该实现在pytorch 1.2.0版本中已经出现,https://github.com/pytorch/pytorch/pull/21250。
- 这里加入以适配低版本的pytorch
-
- .. todo::
- 翻译成中文
-
- The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
- The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
-
- .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980
-
- .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
-
- .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
- """
-
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=1e-2, amsgrad=False):
- r"""
-
- :param params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- :param lr (float, optional): learning rate (default: 1e-3)
- :param betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.99))
- :param eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- :param weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- algorithm from the paper `On the Convergence of Adam and Beyond`_
- (default: False)
- """
- if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
- if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
- if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
- if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay, amsgrad=amsgrad)
- super(AdamW, self).__init__(params, defaults)
-
- def __setstate__(self, state):
- super(AdamW, self).__setstate__(state)
- for group in self.param_groups:
- group.setdefault('amsgrad', False)
-
- def step(self, closure=None):
- r"""Performs a single optimization step.
-
- :param closure: (callable, optional) A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- loss = closure()
-
- for group in self.param_groups:
- for p in group['params']:
- if p.grad is None:
- continue
-
- # Perform stepweight decay
- p.data.mul_(1 - group['lr'] * group['weight_decay'])
-
- # Perform optimization step
- grad = p.grad.data
- if grad.is_sparse:
- raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
- amsgrad = group['amsgrad']
-
- state = self.state[p]
-
- # State initialization
- if len(state) == 0:
- state['step'] = 0
- # Exponential moving average of gradient values
- state['exp_avg'] = torch.zeros_like(p.data)
- # Exponential moving average of squared gradient values
- state['exp_avg_sq'] = torch.zeros_like(p.data)
- if amsgrad:
- # Maintains max of all exp. moving avg. of sq. grad. values
- state['max_exp_avg_sq'] = torch.zeros_like(p.data)
-
- exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
- if amsgrad:
- max_exp_avg_sq = state['max_exp_avg_sq']
- beta1, beta2 = group['betas']
-
- state['step'] += 1
-
- # Decay the first and second moment running average coefficient
- exp_avg.mul_(beta1).add_(1 - beta1, grad)
- exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
- if amsgrad:
- # Maintains the maximum of all 2nd moment running avg. till now
- torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
- # Use the max. for normalizing running avg. of gradient
- denom = max_exp_avg_sq.sqrt().add_(group['eps'])
- else:
- denom = exp_avg_sq.sqrt().add_(group['eps'])
-
- bias_correction1 = 1 - beta1 ** state['step']
- bias_correction2 = 1 - beta2 ** state['step']
- step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
-
- p.data.addcdiv_(-step_size, exp_avg, denom)
-
- return loss
diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
deleted file mode 100644
index 613a4993..00000000
--- a/fastNLP/core/predictor.py
+++ /dev/null
@@ -1,83 +0,0 @@
-r"""undocumented"""
-
-__all__ = [
- "Predictor"
-]
-
-from collections import defaultdict
-
-import torch
-
-from . import DataSet
-from . import DataSetIter
-from . import SequentialSampler
-from .utils import _build_args, _move_dict_value_to_device, _get_model_device
-
-
-class Predictor(object):
- r"""
- 一个根据训练模型预测输出的预测器(Predictor)
-
- 与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。
- 这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。
- """
-
- def __init__(self, network):
- r"""
-
- :param torch.nn.Module network: 用来完成预测任务的模型
- """
- if not isinstance(network, torch.nn.Module):
- raise ValueError(
- "Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network)))
- self.network = network
- self.batch_size = 1
- self.batch_output = []
-
- def predict(self, data: DataSet, seq_len_field_name=None):
- r"""用已经训练好的模型进行inference.
-
- :param fastNLP.DataSet data: 待预测的数据集
- :param str seq_len_field_name: 表示序列长度信息的field名字
- :return: dict dict里面的内容为模型预测的结果
- """
- if not isinstance(data, DataSet):
- raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
- if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
- raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
-
- prev_training = self.network.training
- self.network.eval()
- network_device = _get_model_device(self.network)
- batch_output = defaultdict(list)
- data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
-
- if hasattr(self.network, "predict"):
- predict_func = self.network.predict
- else:
- predict_func = self.network.forward
-
- with torch.no_grad():
- for batch_x, _ in data_iterator:
- _move_dict_value_to_device(batch_x, _, device=network_device)
- refined_batch_x = _build_args(predict_func, **batch_x)
- prediction = predict_func(**refined_batch_x)
-
- if seq_len_field_name is not None:
- seq_lens = batch_x[seq_len_field_name].tolist()
-
- for key, value in prediction.items():
- value = value.cpu().numpy()
- if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
- batch_output[key].extend(value.tolist())
- else:
- if seq_len_field_name is not None:
- tmp_batch = []
- for idx, seq_len in enumerate(seq_lens):
- tmp_batch.append(value[idx, :seq_len])
- batch_output[key].extend(tmp_batch)
- else:
- batch_output[key].append(value)
-
- self.network.train(prev_training)
- return batch_output
diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py
deleted file mode 100644
index 8ad10e26..00000000
--- a/fastNLP/core/sampler.py
+++ /dev/null
@@ -1,418 +0,0 @@
-r"""
-sampler 子类实现了 fastNLP 所需的各种采样器。
-"""
-__all__ = [
- "Sampler",
- "BucketSampler",
- "SequentialSampler",
- "RandomSampler",
- "SortedSampler",
- "ConstantTokenNumSampler"
-]
-
-from itertools import chain
-
-import numpy as np
-
-
class Sampler(object):
    r"""Base class for all samplers: defines the order in which the elements
    of a dataset are visited.

    Subclasses must implement ``__call__``, which receives a ``DataSet`` and
    returns a list of indices into it.
    """

    def __call__(self, data_set):
        r"""
        :param DataSet data_set: the dataset to sample from
        :return result: list(int) — elements of ``data_set`` are taken in this order
        """
        raise NotImplementedError


class SequentialSampler(Sampler):
    r"""Sampler that yields elements in their original order."""

    def __call__(self, data_set):
        return [index for index in range(len(data_set))]


class RandomSampler(Sampler):
    r"""Sampler that yields elements in a uniformly random order."""

    def __call__(self, data_set):
        permutation = np.random.permutation(len(data_set))
        return list(permutation)
-
-
class BucketSampler(Sampler):
    r"""Random sampler with buckets: samples of similar length are grouped into
    the same batch, which reduces padding.
    """

    def __init__(self, num_buckets=10, batch_size=None, seq_len_field_name='seq_len'):
        r"""

        :param int num_buckets: number of buckets
        :param int batch_size: batch size. Defaults to None; Trainer/Tester set it correctly
            when they invoke the sampler — outside Trainer/Tester it must be passed explicitly.
        :param str seq_len_field_name: name of the field holding sequence lengths
        """
        self.num_buckets = num_buckets
        self.batch_size = batch_size
        self.seq_len_field_name = seq_len_field_name

    def set_batch_size(self, batch_size):
        r"""

        :param int batch_size: size of each batch
        :return:
        """
        self.batch_size = batch_size

    def __call__(self, data_set):
        if self.batch_size is None:
            raise RuntimeError("batch_size is None.")
        seq_lens = data_set.get_all_fields()[self.seq_len_field_name].content
        total_sample_num = len(seq_lens)

        bucket_indexes = []
        assert total_sample_num >= self.num_buckets, "The number of samples is smaller than the number of buckets."
        num_sample_per_bucket = total_sample_num // self.num_buckets
        for i in range(self.num_buckets):
            bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])
        # The last bucket absorbs the remainder when the division is not exact.
        bucket_indexes[-1][1] = total_sample_num

        sorted_seq_lens = list(sorted([(idx, seq_len) for
                                       idx, seq_len in zip(range(total_sample_num), seq_lens)],
                                      key=lambda x: x[1]))

        batchs = []

        left_init_indexes = []
        for b_idx in range(self.num_buckets):
            start_idx = bucket_indexes[b_idx][0]
            end_idx = bucket_indexes[b_idx][1]
            sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
            left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
            num_batch_per_bucket = len(left_init_indexes) // self.batch_size
            np.random.shuffle(left_init_indexes)
            for i in range(num_batch_per_bucket):
                batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size])
            # Carry any leftover indices over into the next bucket.
            left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:]
        # Bug fix: the original tested `(left_init_indexes) != 0`, which compares a
        # list to 0 and is always True, appending an empty batch when nothing was left.
        if len(left_init_indexes) != 0:
            batchs.append(left_init_indexes)
        np.random.shuffle(batchs)

        return list(chain(*batchs))
-
-
class ConstTokenNumSampler(Sampler):
    """
    Tries to keep the total number of input tokens in every batch roughly constant.

    Usage example
        >>> # assume tr_data exists and has a field `seq_len` with each instance's token count
        >>> from fastNLP import DataSetIter, Trainer
        >>> sampler = ConstTokenNumSampler('src_seq_len', max_token=4096)
        >>>
        >>> # pass the sampler straight to Trainer; the batch_size argument is then ignored
        >>> trainer = Trainer(tr_data, model, optimizer=optimizer, loss=TranslationLoss(),
        >>>             batch_size=1, sampler=sampler, drop_last=False, update_every=1)
    """
    def __init__(self, seq_len_field_name, max_token=4096, max_sentence=-1, need_be_multiple_of=1, num_bucket=-1):
        """
        :param List[int] seq_len_field_name: field indicating each sample's length
        :param int max_token: maximum number of tokens per batch
        :param int max_sentence: maximum number of instances per batch; -1 means max_token decides
        :param int need_be_multiple_of: generated batch sizes must be a multiple of this
            (used under DataParallel)
        :param int num_bucket: split the data into this many length buckets; samples are
            combined within a bucket to reduce padding
        """
        assert (max_sentence!=-1 and max_sentence>=need_be_multiple_of) or max_sentence<1
        self.seq_len_field_name = seq_len_field_name
        self.num_bucket = num_bucket
        self.max_token = max_token
        self._max_sentence = max_sentence
        self.need_be_multiple_of = need_be_multiple_of

    def __call__(self, data_set):
        # NOTE(review): the body below is corrupted in this source — several lines
        # appear merged or missing (e.g. the `while` condition concatenates two
        # statements, and `batches`/`max_len`/`batch` are used without being
        # defined). Restore it from upstream fastNLP before relying on it.
        assert len(data_set)>self.num_bucket, "The number of samples should be larger than buckets."
        seq_len = data_set.get_field(self.seq_len_field_name)
        self.seq_len = seq_len
        seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
        seq_len_indice.sort(key=lambda x: x[0])
        indice_in_buckets = []
        if self.num_bucket>0:
            sample_per_bucket = len(seq_len_indice)//self.num_bucket
            i = 0
            while len(indice_in_buckets)self.max_token or len(batch)>=self.max_sentence:
                left_sample = len(batch) % self.need_be_multiple_of
                add_samples = batch.copy()
                cur_max_len =length
                if left_sample!=0:
                    add_samples = add_samples[:-left_sample]
                    batch = batch[-left_sample:]
                    cur_max_len = max(cur_max_len, max(batch))
                else:
                    batch = []
                if len(add_samples)==0:
                    raise RuntimeError(f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
                batches.append(add_samples)
            else:
                cur_max_len = max_len
            batch.append(i)
        if batch:
            left_sample = len(batch) % self.need_be_multiple_of
            add_samples = batch.copy()
            if left_sample != 0:
                add_samples = add_samples[:-left_sample].copy()
            if add_samples:
                batches.append(add_samples)
        np.random.shuffle(batches)
        self.batches = batches

    def __iter__(self):
        for batch in self.batches:
            yield batch
        # NOTE(review): `get_new_order` is not defined anywhere in this excerpt —
        # presumably lost to the same corruption; confirm against upstream.
        self.get_new_order()

    def __len__(self):
        return len(self.batches)
-
-
class ConstantTokenNumSampler:
    """
    Tries to keep the total number of input tokens in every batch roughly constant.

    Usage example
        >>> # assume tr_data exists and has a field `seq_len` with each instance's token count
        >>> from fastNLP import DataSetIter, Trainer
        >>> sampler = ConstantTokenNumSampler(tr_data.get_field('seq_len').content, max_token=4096)
        >>> tr_iter = DataSetIter(tr_data,
        >>>             batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False,
        >>>             drop_last=False, timeout=0, worker_init_fn=None,
        >>>             batch_sampler=sampler)
        >>>
        >>> # pass tr_iter straight to Trainer; the batch_size argument is then ignored
        >>> trainer = Trainer(tr_iter, model, optimizer=optimizer, loss=TranslationLoss(),
        >>>             batch_size=1, sampler=None, drop_last=False, update_every=1)
    """
    def __init__(self, seq_len, max_token=4096, max_sentence=-1, need_be_multiple_of=1, num_bucket=-1):
        """
        :param List[int] seq_len: length of every sample; usually obtained via
            dataset.get_field('seq_len').content
        :param int max_token: maximum number of tokens per batch
        :param int max_sentence: maximum number of instances per batch; -1 means max_token decides
        :param int need_be_multiple_of: generated batch sizes must be a multiple of this
            (used under DataParallel)
        :param int num_bucket: split the data into this many length buckets; samples are
            combined within a bucket to reduce padding
        """
        assert (max_sentence!=-1 and max_sentence>=need_be_multiple_of) or max_sentence<1
        assert len(seq_len)>num_bucket, "The number of samples should be larger than buckets."
        self.seq_len = seq_len
        self.max_token = max_token
        self._max_sentence = max_sentence
        self.need_be_multiple_of = need_be_multiple_of
        seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
        seq_len_indice.sort(key=lambda x: x[0])
        indice_in_buckets = []
        # NOTE(review): the block below is corrupted in this source — several lines
        # appear merged or missing (the `while` condition concatenates two
        # statements, and `batches`/`max_len`/`batch` are used without being
        # defined). Restore it from upstream fastNLP before relying on it.
        if num_bucket>0:
            sample_per_bucket = len(seq_len_indice)//num_bucket
            i = 0
            while len(indice_in_buckets)self.max_token or len(batch)>=self.max_sentence:
                left_sample = len(batch) % self.need_be_multiple_of
                add_samples = batch.copy()
                cur_max_len =length
                if left_sample!=0:
                    add_samples = add_samples[:-left_sample]
                    batch = batch[-left_sample:]
                    cur_max_len = max(cur_max_len, max(batch))
                else:
                    batch = []
                if len(add_samples)==0:
                    raise RuntimeError(f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
                batches.append(add_samples)
            else:
                cur_max_len = max_len
            batch.append(i)
        if batch:
            left_sample = len(batch) % self.need_be_multiple_of
            add_samples = batch.copy()
            if left_sample != 0:
                add_samples = add_samples[:-left_sample].copy()
            if add_samples:
                batches.append(add_samples)
        np.random.shuffle(batches)
        self.batches = batches

    def __iter__(self):
        for batch in self.batches:
            yield batch
        # NOTE(review): `get_new_order` is not defined anywhere in this excerpt —
        # presumably lost to the same corruption; confirm against upstream.
        self.get_new_order()

    def __len__(self):
        return len(self.batches)
-
-
class SortedSampler(Sampler):
    r"""Sort samples by length; mainly used at test time, where it speeds
    evaluation up by reducing padding.
    """
    def __init__(self, seq_len_field_name='seq_len', descending=True):
        """

        :param str seq_len_field_name: field to sort by. If the field's values are numbers
            they are compared directly; otherwise their lengths are used.
        :param bool descending: whether to sort in descending order
        """
        self.seq_len_field_name = seq_len_field_name
        self.descending = descending

    def __call__(self, data_set):
        seq_lens = data_set.get_field(self.seq_len_field_name).content
        try:
            seq_lens = list(map(len, seq_lens))
        except TypeError:
            # The field already holds plain numbers — compare them directly.
            # Bug fix: the original bare `except:` also swallowed
            # KeyboardInterrupt/SystemExit; only a TypeError is expected here.
            pass

        orders = np.argsort(seq_lens).tolist()  # ascending order
        if self.descending:
            orders = orders[::-1]
        return orders
-
-
def simple_sort_bucketing(lengths):
    r"""

    :param lengths: list of int, the lengths of all examples.
    :return data: 2-level list
    ::

        [
            [index_11, index_12, ...],  # bucket 1
            [index_21, index_22, ...],  # bucket 2
            ...
        ]

    """
    indexed = list(enumerate(lengths))
    indexed.sort(key=lambda pair: pair[1])
    # TODO: need to return buckets
    return [index for index, _ in indexed]
-
-
def k_means_1d(x, k, max_iter=100):
    r"""Perform k-means on 1-D data.

    :param x: list of int, representing points in 1-D.
    :param k: the number of clusters required.
    :param max_iter: maximum iteration
    :return centroids: numpy array, centroids of the k clusters
        assignment: numpy array, 1-D, the bucket id assigned to each example.
    """
    sorted_x = sorted(set(x))
    x = np.array(x)
    if len(sorted_x) < k:
        raise ValueError("too few buckets")
    gap = len(sorted_x) / k

    # Spread the initial centroids evenly across the sorted unique values.
    # Bug fix: the original comprehensions shadowed `x` (the data array) and `k`
    # (the cluster count) with their loop variables, which only worked by the
    # accident of comprehension scoping — renamed for safety and clarity.
    centroids = np.array([sorted_x[int(i * gap)] for i in range(k)])
    assign = None

    for _ in range(max_iter):
        # Cluster Assignment step: nearest centroid for every point.
        assign = np.array([np.argmin(np.absolute(centroids - x_i)) for x_i in x])
        # Move centroids step: each centroid becomes the mean of its cluster.
        new_centroids = np.array([x[assign == cluster].mean() for cluster in range(k)])
        if (new_centroids == centroids).all():
            centroids = new_centroids
            break
        centroids = new_centroids
    return np.array(centroids), assign
-
-
def k_means_bucketing(lengths, buckets):
    r"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths.

    :param lengths: list of int, the length of all samples.
    :param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
        threshold for each bucket (This is usually None.).
    :return data: 2-level list
    ::

        [
            [index_11, index_12, ...],  # bucket 1
            [index_21, index_22, ...],  # bucket 2
            ...
        ]

    """
    num_buckets = len(buckets)
    bucket_data = [[] for _ in range(num_buckets)]
    _, assignments = k_means_1d(lengths, num_buckets)

    for idx, bucket_id in enumerate(assignments):
        threshold = buckets[bucket_id]
        # An instance is kept only if its bucket has no threshold or it fits under it.
        if threshold is None or lengths[idx] <= threshold:
            bucket_data[bucket_id].append(idx)
    return bucket_data
diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py
new file mode 100644
index 00000000..53c29689
--- /dev/null
+++ b/fastNLP/core/samplers/__init__.py
@@ -0,0 +1,31 @@
+__all__ = [
+ 'MixSampler',
+ 'DopedSampler',
+ 'MixSequentialSampler',
+ 'PollingSampler',
+
+ 'ReproducibleSampler',
+ 'RandomSampler',
+ "SequentialSampler",
+ "SortedSampler",
+
+ 'UnrepeatedSampler',
+ 'UnrepeatedRandomSampler',
+ "UnrepeatedSortedSampler",
+ "UnrepeatedSequentialSampler",
+
+ "ReproduceBatchSampler",
+ "BucketedBatchSampler",
+ "ReproducibleBatchSampler",
+ "RandomBatchSampler",
+
+ "re_instantiate_sampler"
+]
+
+from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
+from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
+from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
+from .utils import re_instantiate_sampler
+from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
+from .reproducible_batch_sampler import ReproduceBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler
+
diff --git a/fastNLP/core/samplers/conversion_utils.py b/fastNLP/core/samplers/conversion_utils.py
new file mode 100644
index 00000000..7d057bb9
--- /dev/null
+++ b/fastNLP/core/samplers/conversion_utils.py
@@ -0,0 +1,33 @@
+from fastNLP.core.samplers import re_instantiate_sampler
+from fastNLP.core.samplers.reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, \
+ SortedSampler
+from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, \
+ UnrepeatedSequentialSampler, UnrepeatedSortedSampler
+
+
def conversion_between_reproducible_and_unrepeated_sampler(sampler):
    """
    Replace ``sampler`` with its reproducible or unrepeated counterpart. If the input is an
    :class:`~fastNLP.core.samplers.unrepeated_sampler.UnrepeatedSampler` with no matching
    :class:`~fastNLP.core.samplers.reproducible_sampler.ReproducibleSampler` (or vice
    versa), a ``TypeError`` is raised.

    :param sampler: the ``sampler`` to convert;
    :return:
    """
    assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
        "The sampler must be UnrepeatedSampler or ReproducibleSampler"
    if isinstance(sampler, UnrepeatedSampler):
        # Unrepeated -> reproducible.
        if isinstance(sampler, UnrepeatedRandomSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
        elif isinstance(sampler, UnrepeatedSequentialSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
        elif isinstance(sampler, UnrepeatedSortedSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
        # Bug fix: this branch converts TO a reproducible sampler, so the missing
        # counterpart is the reproducible one (the original message said "unrepeated").
        raise TypeError(f"{sampler.__class__} has no reproducible version.")
    else:
        # Reproducible -> unrepeated.
        if isinstance(sampler, RandomSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
        elif isinstance(sampler, SequentialSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
        elif isinstance(sampler, SortedSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
        # Bug fix: swapped counterpart of the message above.
        raise TypeError(f"{sampler.__class__} has no unrepeated version.")
\ No newline at end of file
diff --git a/fastNLP/core/samplers/mix_sampler.py b/fastNLP/core/samplers/mix_sampler.py
new file mode 100644
index 00000000..2c3daa10
--- /dev/null
+++ b/fastNLP/core/samplers/mix_sampler.py
@@ -0,0 +1,518 @@
+import array
+import numpy as np
+from typing import Union, List, Iterable, Dict
+
+__all__ = [
+ 'MixSampler',
+ 'DopedSampler',
+ 'MixSequentialSampler',
+ 'PollingSampler'
+]
+
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+ from torch.utils.data import SequentialSampler, Sampler
+ import torch
+
+
class MixSampler:
    """
    Base class for all mix samplers.

    :param dataset: a dict whose values are data containers implementing ``__getitem__`` and ``__len__``
    :param batch_size: batch size; every dataset uses this same batch size
    :param sampler: instantiated samplers, one per dataset
    :param ds_ratio: controls how the datasets are combined into one mixed dataset; one of
        ``[None, 'truncate_to_least', 'pad_to_most', Dict[str, float]]``:

        * ``None`` — the datasets are used as-is, without rescaling;
        * ``'truncate_to_least'`` — every dataset is (logically) truncated to the length of the
          shortest one; ``MixDataLoader`` resamples accordingly instead of physically cutting data;
        * ``'pad_to_most'`` — every dataset is (logically) extended to the length of the longest
          one; ``MixDataLoader`` resamples accordingly instead of physically copying data;
        * ``Dict[str, float]`` — ``datasets`` must then be ``Dict[str, DataSet]`` with matching
          keys; each value is a non-negative factor by which that dataset is scaled up or down;

    :param drop_last: whether to drop the final batch when it is smaller than ``batch_size``
    :param rank: ``global_rank`` of the current process in distributed training
    :param word_size: total number of processes (world size) in distributed training
        (NOTE(review): presumably a typo for "world_size" — kept for interface compatibility)
    """

    def __init__(self, dataset: Dict, batch_size: int = None,
                 sampler: Union[Dict[str, "Sampler"], None, str] = None,
                 ds_ratio: Union[str, Dict[str, float]] = None,
                 drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None:
        # When sampler is a dict, every dataset key must have a matching sampler.
        if isinstance(sampler, Dict):
            for key in dataset.keys():
                if not sampler[key]:
                    raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!")

        # Bug fix: batch_size defaults to None and `None <= 0` raised a bare
        # TypeError — report the invalid value with a clear ValueError instead.
        if batch_size is None or batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))

        if not isinstance(sampler, str) and (rank >= 0 or word_size >= 0):
            raise ValueError("if rank>=0 and word_size>=0, sampler must be str")

        if sampler is None and (word_size < 0 or rank < 0):
            # Single process, no sampler given: iterate every dataset sequentially.
            self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}

        elif isinstance(sampler, Dict):
            self.sampler = sampler

        else:
            # Single machine / multi-GPU: sampler is a str with word_size > 0, rank > 0;
            # the concrete per-dataset samplers are built lazily in __iter__.
            if isinstance(sampler, str):
                if sampler not in ['seq', 'rand']:
                    raise ValueError(f"sampler is {sampler}, but seq or rand is required")
                self.sampler = sampler

        # Total length of the (possibly rescaled) mixed dataset and the cumulative
        # end offset of each dataset inside it.
        sampler_lens, total_lens, sampler_index = [], 0, []
        if isinstance(self.sampler, Dict):
            if ds_ratio is None:
                sampler_lens = [len(spl) for _, spl in self.sampler.items()]

            elif ds_ratio == 'pad_to_most':
                sampler_len = sum([1 for _ in self.sampler.keys()])
                sampler_lens = [max(len(spl) for _, spl in self.sampler.items())] * sampler_len

            elif ds_ratio == 'truncate_to_least':
                sampler_len = sum([1 for _ in self.sampler.keys()])
                sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len

            elif isinstance(ds_ratio, Dict):
                if not all([item >= 0 for item in ds_ratio.values()]):
                    # Bug fix: the original message talked about batch_size here.
                    raise ValueError("ds_ratio should only contain non-negative values, "
                                     "but got ds_ratio={}".format(ds_ratio))
                sampler_lens = [int(len(spl) * ds_ratio[name]) for name, spl in self.sampler.items()]
            else:
                # Bug fix: the message listed 'pad_to_least', which is not an accepted
                # value, and 'List', which is not an accepted type.
                raise ValueError(f"{ds_ratio} must be pad_to_most or truncate_to_least or None or Dict")
            total_lens = sum(sampler_lens)

        # When sampler is a str, this initialisation is deferred to __iter__.
        if len(sampler_lens) > 0:
            sampler_index = [sampler_lens[0]]
            for idx in sampler_lens[1:]:
                temp = sampler_index[-1]
                sampler_index.append(temp + idx)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.ds_ratio = ds_ratio
        self.rank = rank
        self.word_size = word_size
        self.datasets = dataset
        self.num_samplers = sampler_index
        self.len_samplers = total_lens
        self.epoch = 0

    def __iter__(self):
        pass

    def __len__(self):
        pass

    def set_epoch(self, epoch: int) -> None:
        """
        Used together with DDP to control the random seed.

        :param epoch: the current epoch number
        :return:
        """
        self.epoch = epoch
+
+
class InnerSampler:
    """
    Internal sampler used in the multi-GPU case: replays a fixed list of indices.
    """
    def __init__(self, ds_ind_list: List) -> None:
        self.ds_ind_list = ds_ind_list

    def __iter__(self) -> int:
        yield from self.ds_ind_list

    def __len__(self) -> int:
        return len(self.ds_ind_list)
+
+
class DopedSampler(MixSampler):
    """
    ``BatchSampler`` tailored for :class:`~fastNLP.core.dataloaders.MixDataLoader`: it mixes
    samples from every dataset in the ``datasets`` dict into each returned batch.
    """
    def __init__(self, dataset: Dict, batch_size: int = None,
                 sampler: Union[Dict[str, "Sampler"], str] = None,
                 ds_ratio: Union[str, None, Dict[str, float]] = None,
                 drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None:
        super(DopedSampler, self).__init__(dataset=dataset, batch_size=batch_size,
                                           sampler=sampler, ds_ratio=ds_ratio,
                                           drop_last=drop_last, rank=rank, word_size=word_size)

    def __iter__(self) -> List[int]:
        # sampler given as a str: single machine (possibly multi-GPU) — build the
        # per-dataset samplers here so 'rand' can be seeded with the epoch.
        if isinstance(self.sampler, str):
            if self.sampler == 'seq':
                self.sampler = {}
                for name, per_ds in self.datasets.items():
                    if self.word_size >= 0 and self.rank >= 0:
                        self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
                    else:
                        self.sampler[name] = InnerSampler(list(range(len(per_ds))))
            elif self.sampler == 'rand':
                self.sampler = {}
                for name, per_ds in self.datasets.items():
                    g = torch.Generator()
                    g.manual_seed(self.epoch)
                    indices = torch.randperm(len(per_ds), generator=g).tolist()
                    if self.word_size >= 0 and self.rank >= 0:
                        self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
                    else:
                        self.sampler[name] = InnerSampler(indices)

        # Compute the effective per-dataset lengths according to ds_ratio.
        if isinstance(self.sampler, Dict):
            if self.ds_ratio is None:
                sampler_lens = [len(spl) for _, spl in self.sampler.items()]

            elif self.ds_ratio == 'pad_to_most':
                sampler_len = sum([1 for _ in self.sampler.keys()])
                sampler_lens = [max(len(spl) for _, spl in self.sampler.items())] * sampler_len

            elif self.ds_ratio == 'truncate_to_least':
                sampler_len = sum([1 for _ in self.sampler.keys()])
                sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len

            elif isinstance(self.ds_ratio, Dict):
                # Bug fix: iterating the dict directly yields its str keys, so
                # `item >= 0` raised TypeError — compare the ratio values instead
                # (matches the equivalent check in MixSampler.__init__).
                if not all(item >= 0 for item in self.ds_ratio.values()):
                    raise ValueError("ds_ratio should only contain non-negative values, "
                                     "but got ds_ratio={}".format(self.ds_ratio))
                sampler_lens = [int(len(spl) * self.ds_ratio[name]) for name, spl in self.sampler.items()]
            else:
                # Bug fix: the message listed 'pad_to_least', which is not an accepted value.
                raise ValueError(f"{self.ds_ratio} must be pad_to_most or truncate_to_least or None or Dict")
            total_lens = sum(sampler_lens)
        else:
            raise ValueError("datasets must be dict or list")
        # Cumulative end offset of each dataset inside the virtual mixed index space.
        sampler_index = [sampler_lens[0]]
        for idx in sampler_lens[1:]:
            temp = sampler_index[-1]
            sampler_index.append(temp + idx)
        self.num_samplers = sampler_index
        self.len_samplers = total_lens
        # batch under construction and per-dataset iterator state
        batch_idx, samplers = [], []
        # A single process uses all the data; otherwise shard the indices by rank.
        if self.rank < 0 or self.word_size < 0:
            # Pick unsigned int vs unsigned long depending on the index range.
            if self.len_samplers > 42e8:
                total_index = array.array('L', list(range(self.len_samplers)))
            else:
                total_index = array.array('I', list(range(self.len_samplers)))
        else:
            if (self.len_samplers // self.word_size) > 42e8:
                # Trim so the data divides evenly across the ranks.
                self.len_samplers = self.len_samplers - self.len_samplers % self.word_size
                total_index = array.array('L', list(range(self.len_samplers))[self.rank::self.word_size])
            else:
                total_index = array.array('I', list(range(self.len_samplers))[self.rank::self.word_size])

        start_idx = 0

        # (iterator over the dataset's sampler, dataset name, base offset)
        for idx, (name, spl) in enumerate(self.sampler.items()):
            end_idx = len(spl)
            samplers.append((iter(spl), name, start_idx))
            start_idx += end_idx
        # Seed the global shuffle with the epoch so every rank draws the same order.
        np.random.seed(self.epoch)
        np.random.shuffle(total_index)
        for idx in total_index:
            # Map the virtual index onto the dataset it belongs to.
            ds_index = np.searchsorted(self.num_samplers, idx, side='right')
            spl, name, base_index = samplers[ds_index]
            try:
                batch_idx.append(next(spl) + base_index)
            except StopIteration:
                # Restart an exhausted sampler; it is never empty, so this cannot loop.
                spl = iter(self.sampler[name])
                batch_idx.append(next(spl) + base_index)
            samplers[ds_index] = (spl, name, base_index)
            if len(batch_idx) == self.batch_size:
                yield batch_idx
                batch_idx = []

        if len(batch_idx) > 0 and not self.drop_last:
            yield batch_idx

    def __len__(self) -> int:
        # Bug fix: the original divided self.len_samplers in place on every call
        # (so calling len() twice shrank the result) and used `/`, which makes the
        # return value a float. Compute locally with integer arithmetic instead.
        total = self.len_samplers
        if self.rank >= 0 and self.word_size >= 0:
            # Share of the evenly-divisible data assigned to this rank.
            total = (total - total % self.word_size) // self.word_size
        if self.drop_last:
            return total // self.batch_size
        else:
            return (total + self.batch_size - 1) // self.batch_size
+
+
class MixSequentialSampler(MixSampler):
    """
    ``BatchSampler`` tailored for :class:`~fastNLP.core.dataloaders.MixDataLoader`: it samples
    the ``datasets`` dict sequentially, finishing one dataset before moving on to the next.
    """

    def __init__(self, dataset: Dict, batch_size: int = None,
                 sampler: Union[List["Sampler"], Dict[str, "Sampler"], None, str] = None,
                 ds_ratio: Union[str, List[float], Dict[str, float]] = None,
                 drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None:
        super(MixSequentialSampler, self).__init__(dataset=dataset, batch_size=batch_size,
                                                   sampler=sampler, ds_ratio=ds_ratio,
                                                   drop_last=drop_last, rank=rank, word_size=word_size)

    def __iter__(self) -> Iterable[List[int]]:
        """
        Sample in dataset order and yield the indices one batch at a time.

        :return:
        """
        # sampler given as a str: single machine (possibly multi-GPU) — build the
        # per-dataset samplers here so 'rand' can be seeded with the epoch.
        if isinstance(self.sampler, str):
            if self.sampler == 'seq':
                self.sampler = {}
                for name, per_ds in self.datasets.items():
                    if self.word_size >= 0 and self.rank >= 0:
                        self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
                    else:
                        self.sampler[name] = InnerSampler(list(range(len(per_ds))))
            elif self.sampler == 'rand':
                self.sampler = {}
                for name, per_ds in self.datasets.items():
                    g = torch.Generator()
                    g.manual_seed(self.epoch)
                    indices = torch.randperm(len(per_ds), generator=g).tolist()
                    if self.word_size >= 0 and self.rank >= 0:
                        self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
                    else:
                        self.sampler[name] = InnerSampler(indices)

        # Compute the effective per-dataset lengths according to ds_ratio.
        if isinstance(self.sampler, Dict):
            if self.ds_ratio is None:
                sampler_lens = [len(spl) for _, spl in self.sampler.items()]

            elif self.ds_ratio == 'pad_to_most':
                sampler_len = sum([1 for _ in self.sampler.keys()])
                sampler_lens = [max(len(spl) for _, spl in self.sampler.items())] * sampler_len

            elif self.ds_ratio == 'truncate_to_least':
                sampler_len = sum([1 for _ in self.sampler.keys()])
                sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len

            elif isinstance(self.ds_ratio, Dict):
                # Bug fix: iterating the dict directly yields its str keys, so
                # `item >= 0` raised TypeError — compare the ratio values instead
                # (matches the equivalent check in MixSampler.__init__).
                if not all(item >= 0 for item in self.ds_ratio.values()):
                    raise ValueError("ds_ratio should only contain non-negative values, "
                                     "but got ds_ratio={}".format(self.ds_ratio))
                sampler_lens = [int(len(spl) * self.ds_ratio[name]) for name, spl in self.sampler.items()]
            else:
                # Bug fix: the message listed 'pad_to_least', which is not an accepted value.
                raise ValueError(f"{self.ds_ratio} must be pad_to_most or truncate_to_least or None or Dict")
            total_lens = sum(sampler_lens)
        else:
            raise ValueError("datasets must be dict or list")
        # Cumulative end offset of each dataset inside the virtual mixed index space.
        sampler_index = [sampler_lens[0]]
        for idx in sampler_lens[1:]:
            temp = sampler_index[-1]
            sampler_index.append(temp + idx)
        self.num_samplers = sampler_index
        self.len_samplers = total_lens

        batch_idx, total_index, samplers = [], list(range(self.len_samplers)), []
        start_idx = 0

        # (iterator over the dataset's sampler, dataset name, base offset)
        for idx, (name, spl) in enumerate(self.sampler.items()):
            end_idx = len(spl)
            samplers.append((iter(spl), name, start_idx))
            start_idx += end_idx
        for idx in total_index:
            # Map the virtual index onto the dataset it belongs to.
            ds_index = np.searchsorted(self.num_samplers, idx, side='right')

            spl, name, base_index = samplers[ds_index]
            try:
                batch_idx.append(next(spl) + base_index)
            except StopIteration:
                # Restart an exhausted sampler; it is never empty, so this cannot loop.
                spl = iter(self.sampler[name])
                batch_idx.append(next(spl) + base_index)
            samplers[ds_index] = (spl, name, base_index)
            if len(batch_idx) == self.batch_size:
                yield batch_idx
                batch_idx = []
            # Flush the partial batch as soon as the current dataset is exhausted.
            if self.num_samplers[ds_index] == (idx + 1):
                if len(batch_idx) > 0 and not self.drop_last:
                    yield batch_idx
                batch_idx = []

    def __len__(self) -> int:
        # Recover each dataset's length from the cumulative ends, then sum the
        # per-dataset batch counts.
        lens, index = 0, 0
        num_sampler = []
        for ds_len in self.num_samplers:
            num_sampler.append(ds_len - index)
            index = ds_len

        for ds_len in num_sampler:
            if self.drop_last:
                lens += ds_len // self.batch_size
            else:
                lens += (ds_len + self.batch_size - 1) // self.batch_size
        return lens
+
+
class PollingSampler(MixSampler):
    """
    ``BatchSampler`` tailored for :class:`~fastNLP.core.dataloaders.MixDataLoader`: it polls
    the ``datasets`` in round-robin order, taking one batch from each dataset in turn.
    """

    def __init__(self, dataset: Union[List, Dict], batch_size: int = 16,
                 sampler: Union[List["Sampler"], Dict[str, "Sampler"], str] = None,
                 drop_last: bool = False, ds_ratio="pad_to_most", rank: int = -1,
                 word_size: int = -1) -> None:
        super(PollingSampler, self).__init__(dataset=dataset, batch_size=batch_size,
                                             sampler=sampler, ds_ratio=ds_ratio,
                                             drop_last=drop_last, rank=rank, word_size=word_size)

    def __iter__(self) -> List[int]:
        # sampler given as a str: single machine (possibly multi-GPU) — build the
        # per-dataset samplers here so 'rand' can be seeded with the epoch.
        if isinstance(self.sampler, str):
            if self.sampler == 'seq':
                self.sampler = {}
                for name, per_ds in self.datasets.items():
                    if self.word_size >= 0 and self.rank >= 0:
                        self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
                    else:
                        self.sampler[name] = InnerSampler(list(range(len(per_ds))))
            elif self.sampler == 'rand':
                self.sampler = {}
                for name, per_ds in self.datasets.items():
                    g = torch.Generator()
                    g.manual_seed(self.epoch)
                    indices = torch.randperm(len(per_ds), generator=g).tolist()
                    if self.word_size >= 0 and self.rank >= 0:
                        self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
                    else:
                        self.sampler[name] = InnerSampler(indices)

        # Compute the effective per-dataset lengths according to ds_ratio.
        if isinstance(self.sampler, Dict):
            if self.ds_ratio is None:
                sampler_lens = [len(spl) for _, spl in self.sampler.items()]

            elif self.ds_ratio == 'pad_to_most':
                sampler_len = sum([1 for _ in self.sampler.keys()])
                sampler_lens = [max(len(spl) for _, spl in self.sampler.items())] * sampler_len

            elif self.ds_ratio == 'truncate_to_least':
                sampler_len = sum([1 for _ in self.sampler.keys()])
                sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len

            elif isinstance(self.ds_ratio, Dict):
                # Bug fix: iterating the dict directly yields its str keys, so
                # `item >= 0` raised TypeError — compare the ratio values instead
                # (matches the equivalent check in MixSampler.__init__).
                if not all(item >= 0 for item in self.ds_ratio.values()):
                    raise ValueError("ds_ratio should only contain non-negative values, "
                                     "but got ds_ratio={}".format(self.ds_ratio))
                sampler_lens = [int(len(spl) * self.ds_ratio[name]) for name, spl in self.sampler.items()]
            else:
                # Bug fix: the message listed 'pad_to_least', which is not an accepted value.
                raise ValueError(f"{self.ds_ratio} must be pad_to_most or truncate_to_least or None or Dict")
            total_lens = sum(sampler_lens)
        else:
            raise ValueError("datasets must be dict or list")
        # Cumulative end offset of each dataset inside the virtual mixed index space.
        sampler_index = [sampler_lens[0]]
        for idx in sampler_lens[1:]:
            temp = sampler_index[-1]
            sampler_index.append(temp + idx)
        self.num_samplers = sampler_index
        self.len_samplers = total_lens

        start_idx, samplers, true_start_idx, true_end_idx = 0, [], 0, 0

        # (iterator over the dataset's quota, iterator over its sampler, base offset, name)
        for idx, (name, spl) in enumerate(self.sampler.items()):
            end_idx = len(spl)
            true_end_idx = self.num_samplers[idx]
            samplers.append((iter(range(true_start_idx, true_end_idx)), iter(spl), start_idx, name))
            start_idx += end_idx
            true_start_idx = true_end_idx

        while True:
            # No dataset left with remaining quota — stop.
            if len(samplers) == 0:
                break
            batch_idx, flag = [], False
            ds_total_iter, ds_sampler, ds_base_idx, sampler_idx = samplers.pop(0)
            for _ in range(self.batch_size):
                try:
                    # Consume one unit of this dataset's quota.
                    next(ds_total_iter)
                    # Fetch the actual index; restart the sampler if exhausted.
                    try:
                        batch_idx.append(next(ds_sampler) + ds_base_idx)
                    except StopIteration:
                        ds_sampler = iter(self.sampler[sampler_idx])
                        batch_idx.append(next(ds_sampler) + ds_base_idx)
                except StopIteration:
                    # This dataset's quota is used up — drop it from the queue.
                    flag = True
            # Re-queue the dataset only if it still has quota left.
            if flag is False:
                samplers.append((ds_total_iter, ds_sampler, ds_base_idx, sampler_idx))
            if len(batch_idx) == self.batch_size:
                yield batch_idx
            elif len(batch_idx) > 0 and not self.drop_last:
                yield batch_idx

    def __len__(self) -> int:
        # Recover each dataset's length from the cumulative ends, then sum the
        # per-dataset batch counts.
        lens, index = 0, 0
        num_sampler = []
        for ds_len in self.num_samplers:
            num_sampler.append(ds_len - index)
            index = ds_len

        for ds_len in num_sampler:
            if self.drop_last:
                lens += ds_len // self.batch_size
            else:
                lens += (ds_len + self.batch_size - 1) // self.batch_size
        return lens
+
+
if __name__ == '__main__':
    # Smoke-test the mix samplers on two small in-memory datasets.
    from fastNLP.core.dataset import DataSet
    ds = DataSet({'x': ["x1a", "1ws2", "xa qa", "ax wq", "iu, lk"] * 101, 'y': [1, 0, 1, 0, 1] * 101})
    ds1 = DataSet({'x': ["x12a", "1wzs2", "xa xqa", "aax wq", "iau, lk"] * 101, 'y': ['1', '0', '1', '0', '1'] * 101})
    # Bug fix: MixSampler and its subclasses require a dict of datasets; the
    # original demo passed lists, which crashed in __iter__ on `datasets.items()`.
    datasets = {'ds': ds, 'ds1': ds1}
    sampler = DopedSampler(dataset=datasets, batch_size=6, rank=0, word_size=-2, sampler='seq')
    seqSpl = MixSequentialSampler(dataset=datasets, batch_size=6, rank=0, word_size=2, sampler='seq', drop_last=True)
    polSpl = PollingSampler(dataset=datasets, batch_size=6, rank=1, word_size=2, sampler='seq', drop_last=False)
    for idx, batch in enumerate(polSpl):
        print(idx, batch)
    # print(len(seqSpl))
diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py
new file mode 100644
index 00000000..bd59806b
--- /dev/null
+++ b/fastNLP/core/samplers/reproducible_batch_sampler.py
@@ -0,0 +1,685 @@
+"""
+:class:`ReproducibleBatchSampler` 是 **fastNLP** 提供的一种特殊 BatchSampler,它可以记录采样过程中每一次采样和 epoch 的信息,
+方便在保存-加载后能够从上一次采样结束的地方继续进行采样,实现 **断点重训** 。
+
+.. note::
+
+ DataLoader 中只要存在 :class:`~fastNLP.core.samplers.reproducible_sampler.ReproducibleSampler` 或 :class:`ReproducibleBatchSampler`
+ 中的一个便可以实现断点重训复现的功能。
+
+"""
+
+__all__ = [
+ 'BucketedBatchSampler',
+ "ReproduceBatchSampler",
+ "RandomBatchSampler"
+]
+
+import os
+import math
+from copy import deepcopy
+from typing import Dict, Union, List
+from itertools import chain
+
+import numpy as np
+
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.log import logger
+from .utils import create_array
+from abc import abstractmethod
+
+
class ReproducibleBatchSampler:
    """
    A **reproducible** BatchSampler.

    Every subclass of :class:`ReproducibleBatchSampler` must accept ``**kwargs`` in
    its :meth:`__init__` so the sampler can be re-instantiated when resuming from a
    checkpoint.  By convention, no variable initialized in :meth:`__init__` may start
    with an underscore ``_``, and every variable set outside :meth:`__init__` must
    start with one.
    """
    def __init__(self, **kwargs):
        # Single-process default; set_distributed() implementations overwrite this.
        self.num_replicas = 1

    @abstractmethod
    def set_distributed(self, num_replicas, rank, pad=True):
        raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")

    @abstractmethod
    def __len__(self):
        raise NotImplementedError("Each specific batch_sampler should implement its own `__len__` method.")

    @abstractmethod
    def __iter__(self):
        raise NotImplementedError("Each specific batch_sampler should implement its own `__iter__` method.")

    @abstractmethod
    def state_dict(self):
        """
        Return a serializable dict describing the sampler's current position.

        :return: dict of sampler state used for checkpoint resume.
        """
        raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.")

    @abstractmethod
    def load_state_dict(self, states):
        raise NotImplementedError("Each specific batch_sampler should implement its own `load_state_dict` method.")

    @abstractmethod
    def set_epoch(self, epoch):
        # Optional hook; concrete samplers that shuffle per-epoch override it.
        pass

    @property
    def batch_idx_in_epoch(self):
        raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
+
+
class ReproduceBatchSampler(ReproducibleBatchSampler):
    """
    Wrapper that makes the state of an arbitrary ``batch_sampler`` recoverable.

    :param batch_sampler: an iterable yielding **ints** or **lists of ints**.
        :class:`ReproduceBatchSampler` first iterates it once, caches the yielded
        indices, and then re-emits them in batches of ``batch_size``.
    :param batch_size: size of each batch.
    :param drop_last: whether to drop the final batch when it cannot be filled up
        to ``batch_size`` samples.
    :param kwargs: parameters used internally by fastNLP.
    """
    def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
        super().__init__()

        self.batch_sampler = batch_sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

        # Number of indices already consumed (used for checkpoint resume).
        self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)

        self.index_list = kwargs.get("index_list", self._iterate_sampler())
        self.need_reinitialize = kwargs.get("need_reinitialize", False)

    def _iterate_sampler(self):
        # Flatten whatever the wrapped sampler yields into one flat index list.
        _index_lst = []
        for idx in self.batch_sampler:
            if isinstance(idx, list):
                _index_lst.extend(idx)
            # A bare int means a plain sampler was passed in, which corresponds to a
            # dataloader created without batch_size and without batch_sampler.
            else:
                _index_lst.append(idx)
        _index_lst = create_array(len(_index_lst), _index_lst)
        return _index_lst

    def __iter__(self):
        if self.need_reinitialize:
            self.index_list = self._iterate_sampler()
            self.num_consumed_samples = 0
        else:
            self.need_reinitialize = True

        batch = []
        if self.num_consumed_samples:
            # Resume: skip the indices consumed before the checkpoint.
            index_list = self.index_list[self.num_consumed_samples:]
        else:
            index_list = self.index_list

        # Temporarily disabled. It would remember the consumed_samples of every batch;
        # this is needed because today's dataloaders prefetch, so only together with
        # Trainer's batch_idx_in_epoch can the truly consumed count be determined.
        # The variable would record the real num_consumed_samples at each yield.
        # self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
        #                                                           num_consumed_samples=self.num_consumed_samples)
        for idx in index_list:
            batch.append(idx)
            if len(batch) == self.batch_size:
                self.num_consumed_samples += self.batch_size  # [16, 32, 48, 64,..., ]
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            self.num_consumed_samples += len(batch)
            yield batch
        # Reset to avoid boundary-condition issues.
        self.num_consumed_samples = 0

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.index_list) // self.batch_size
        else:
            return (len(self.index_list) + self.batch_size - 1) // self.batch_size

    def state_dict(self) -> Dict:
        states = {
            "index_list": deepcopy(self.index_list),
            "num_consumed_samples": self.num_consumed_samples,
            'sampler_type': self.__class__.__name__
        }
        return states

    def load_state_dict(self, states: Dict):
        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."

        _index_list = states["index_list"]
        assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \
                                                         "record and current dataset."
        self.index_list = _index_list
        self.num_consumed_samples = states["num_consumed_samples"]
        self.need_reinitialize = False

    def set_distributed(self, num_replicas, rank, pad=True):
        raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")

    def set_epoch(self, epoch):
        # Forward set_epoch to the wrapped sampler when it supports it.
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
            self.batch_sampler.sampler.set_epoch(epoch)

    @property
    def batch_idx_in_epoch(self):
        # Batches already yielded this epoch = total batches minus batches remaining.
        if self.drop_last:
            return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size
        else:
            return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \
                   (len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
+
+
class RandomBatchSampler(ReproducibleBatchSampler):
    """
    A reproducible batch_sampler that yields randomly shuffled batches.

    :param dataset: any data container implementing ``__len__``.
    :param batch_size: size of each batch.
    :param shuffle: whether to shuffle the sample order before batching.
    :param drop_last: whether to drop the final batch when it cannot be filled up
        to ``batch_size`` samples.
    :param seed: random seed.
    :param kwargs: parameters used internally by fastNLP.
    """
    def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
                 drop_last: bool = False, seed: int = 0, **kwargs):
        super().__init__()

        self.dataset = dataset

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = int(seed)

        self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total samples already iterated, including those yielded on other ranks in the multi-card case

        # Distributed-related parameters.
        self.num_replicas = kwargs.get("num_replicas", 1)
        self.rank = kwargs.get("rank", 0)
        self.epoch = kwargs.get("epoch", -1)
        self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card

        # Whether we are inside an iteration; while True, set_distributed() and
        # load_state_dict() must not be called.
        self.during_iter = kwargs.get("during_iter", False)

        # Internal variables used to restore state after a checkpoint.
        self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)

    def set_distributed(self, num_replicas, rank, pad=True):
        """
        Configure the distributed settings; should be called right after this
        BatchSampler is constructed.

        :param num_replicas: total number of processes in distributed training.
        :param rank: ``global_rank`` of the current process.
        :param pad: when the sample count is not divisible by ``num_replicas``,
            whether to pad so every process gets exactly the same number of samples.
        :return: self
        """
        assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
                                          "during an unfinished iteration."
        assert num_replicas > 0 and isinstance(num_replicas, int)
        assert isinstance(rank, int) and 0 <= rank < num_replicas
        # When this is called, all state is assumed to be that of a freshly started epoch.
        self.num_replicas = num_replicas
        self.rank = rank
        self.pad = pad

        return self

    def __iter__(self):
        if self.during_iter:  # a previous iteration never finished; force re-initialization
            self.num_consumed_samples = 0
        self.during_iter = True

        indices = list(range(self.num_samples))

        if self.shuffle:
            if self.num_consumed_samples > 0:  # resume: rebuild the old ordering first, then drop the consumed prefix
                _batches = []
                for _i in range(self.old_num_replicas):
                    _indices = indices[_i:len(indices):self.old_num_replicas]
                    __batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch)
                    _batches.append(__batches)
                batches = list(chain(*[_ for _ in zip(*_batches)]))
                indices = list(chain(*batches))
                indices = indices[self.num_consumed_samples:]
            # Take this rank's share,
            indices = indices[self.rank:len(indices):self.num_replicas]
            batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch)
            batches = list(map(list, batches))
        else:
            indices = indices[self.num_consumed_samples:]
            indices = indices[self.rank:len(indices):self.num_replicas]
            _num_batches = len(indices) // self.batch_size
            if _num_batches == 0:
                batches = [indices]
            else:
                batches = list(map(list, np.array_split(indices[:_num_batches*self.batch_size], _num_batches)))
                if len(indices)%self.batch_size!=0:
                    batches.append(indices[_num_batches*self.batch_size:])

        need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas
        if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
            if len(batches) > 0:
                # NOTE(review): the next five lines are garbled in this patch (invalid
                # syntax `len(batches[-1])self.rank:`) — the pad branch and the
                # truncate branch were fused and text was lost; restore from upstream.
                if len(batches[-1])self.rank:
            if len(batches):
                batches[-1].pop(-1)
                if len(batches[-1])==0:
                    batches.pop(-1)

        assert sum(map(len, batches)) == self.num_left_samples

        if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
            batches = batches[:-1]

        for batch in batches:
            self.num_consumed_samples += self.num_replicas * len(batch)
            yield list(map(int, batch))
        self.during_iter = False
        self.num_consumed_samples = 0
        self.old_batch_size = self.batch_size
        self.old_num_replicas = self.num_replicas
        if self.epoch < 0:  # guard against the user never calling set_epoch, which would make every epoch identical
            self.epoch -= 1

    def batchify(self, indices, batch_size, seed) -> List[List[int]]:
        """
        Split ``indices`` into batches.

        :param indices: List[int]
        :param batch_size: int
        :param seed: int
        :return:
        """
        # Actual bucket size.
        rng = np.random.default_rng(abs(seed))
        rng.shuffle(indices)
        num_samples = 0
        batches = []
        # NOTE(review): the next line is garbled — it fuses batchify's while-loop with
        # the signature of `__len__`; the loop body, the `def __len__(self) -> int:`
        # line and (apparently) the `total_size`/`num_left_samples`/`num_samples`
        # properties referenced elsewhere in this class were lost. Restore from upstream.
        while num_samplesint:
        """
        Return how many more batches this sampler will yield.

        :return:
        """
        num_sampler_per_rank = self.total_size//self.num_replicas
        num_batches = num_sampler_per_rank//self.batch_size if self.drop_last else \
            (num_sampler_per_rank+self.batch_size-1)//self.batch_size
        return num_batches

    def state_dict(self) -> Dict:
        if self.old_batch_size != self.batch_size:
            # NOTE(review): message names BucketedBatchSampler but this class is
            # RandomBatchSampler — likely copy-paste; confirm before changing.
            raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
                               " consumed. ")
        states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
                  'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle,
                  'batch_size': self.batch_size,
                  'num_replicas': self.num_replicas}

        return states

    def load_state_dict(self, states: Dict):
        # If self.during_iter is True, num_consumed_samples must be 0.
        assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                          "during an unfinished iteration."

        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."

        length = states['length']
        assert length == self.num_samples, "The number of samples is different between the checkpoint record " \
                                           "and current dataset."
        self.seed = states['seed']
        self.epoch = states['epoch']
        self.num_consumed_samples = states['num_consumed_samples']
        if self.num_consumed_samples>=length:  # if the last sample had already been reached when saving, reset to 0
            self.num_consumed_samples = 0
        if self.shuffle != states['shuffle']:
            logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
                        f"we use shuffle={states['shuffle']}")
        self.shuffle = states["shuffle"]
        self.old_batch_size = states['batch_size']
        self.old_num_replicas = states['num_replicas']
+
+
class BucketedBatchSampler(ReproducibleBatchSampler):
    """
    First sort the samples by length, then treat every ``batch_size * num_batch_per_bucket``
    samples as one bucket; samples are only combined into batches inside their own
    bucket, so each ``batch`` needs little ``padding`` (lengths inside a bucket are
    close to each other).

    :param dataset: any data container implementing ``__len__``.
    :param length: length of every sample.

        * as ``List[int]``
          must have the same length as ``dataset``, giving the size of each element;
        * as ``str``
          only allowed when ``dataset`` is a :class:`~fastNLP.DataSet`; the string
          names a ``field`` of the dataset.  If the field's elements are ``int``
          they are used as lengths directly, otherwise ``len`` is applied to them.

    :param batch_size: size of each batch.
    :param num_batch_per_bucket: how many batches form one bucket; data is only
        shuffled within its bucket.
    :param shuffle: whether to shuffle; if ``False`` the data is yielded sorted
        from longest to shortest.
    :param drop_last: whether to drop the final batch when it cannot be filled up
        to ``batch_size`` samples.
    :param seed: random seed.
    :param kwargs: reserved for fastNLP internal use.
    """
    def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
                 shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
        super().__init__()
        if isinstance(dataset, DataSet) and isinstance(length, str):
            length = dataset.get_field(length).content
            if not isinstance(length[0], int):
                length = list(map(len, length))
            self.length = np.array(length, dtype=int)
            self.sorted_indices = np.argsort(self.length)[::-1]  # sorted from longest to shortest
        else:
            try:
                self.length = np.array(length, dtype=int)
                self.sorted_indices = np.argsort(length)[::-1]
            except BaseException as e:
                # NOTE(review): the error is logged but not re-raised, so the asserts
                # below would then fail with an unrelated AttributeError — confirm intent.
                logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.")

        assert len(length) == len(dataset), f"The length of `dataset`({len(dataset)}) and " \
                                            f"`length`({len(length)}) should be equal."
        assert len(self.sorted_indices) == len(dataset), "The indices and dataset should have equal length."

        self.dataset = dataset

        self.batch_size = batch_size
        self.num_batch_per_bucket = num_batch_per_bucket
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = int(seed)

        self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total samples already iterated, including those yielded on other ranks in the multi-card case

        # Distributed-related parameters.
        self.num_replicas = kwargs.get("num_replicas", 1)
        self.rank = kwargs.get("rank", 0)
        self.epoch = kwargs.get("epoch", -1)
        self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card

        # Whether we are inside an iteration; while True, set_distributed() and
        # load_state_dict() must not be called.
        self.during_iter = kwargs.get("during_iter", False)

        # Internal variables used to restore state after a checkpoint.
        self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)
        self.old_num_batch_per_bucket = kwargs.get('old_num_batch_per_bucket', self.num_batch_per_bucket)

    def set_distributed(self, num_replicas, rank, pad=True):
        """
        Configure the distributed settings; should be called right after this
        BatchSampler is constructed.

        :param num_replicas: total number of processes in distributed training.
        :param rank: ``global_rank`` of the current process.
        :param pad: when the sample count is not divisible by ``num_replicas``,
            whether to pad so every process gets exactly the same number of samples.
        :return:
        """
        assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
                                          "during an unfinished iteration."
        assert num_replicas > 0 and isinstance(num_replicas, int)
        assert isinstance(rank, int) and 0 <= rank < num_replicas
        # When this is called, all state is assumed to be that of a freshly started epoch.
        self.num_replicas = num_replicas
        self.rank = rank
        self.pad = pad

        # num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \
        #     else len(self.dataset)
        #
        # if self.drop_last:
        #     assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \
        #                                                            "than the number of replicates multiplied " \
        #                                                            "with batch_size when drop_last=True."

        return self

    @property
    def total_size(self):
        """
        Total number of indices this BatchSampler will eventually produce (including
        the other ranks').  Because of ``replica`` and ``pad`` it may be equal to,
        greater than or less than ``len(dataset)``.
        """
        return self.num_consumed_samples + self.num_replicas*self.num_left_samples

    @property
    def num_left_samples(self):
        """
        How many samples remain in the current iteration, for the **current rank**.
        """
        num_consumed_samples = self.num_consumed_samples
        return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \
            self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas))

    @property
    def num_samples(self):
        """
        Total number of samples.
        """
        total_len = getattr(self.dataset, 'total_len', None)
        if not isinstance(total_len, int):
            total_len = len(self.dataset)
        return total_len

    def __len__(self)->int:
        """
        Return how many more batches this sampler will yield.

        :return:
        """
        num_sampler_per_rank = self.total_size//self.num_replicas
        num_batches = num_sampler_per_rank//self.batch_size if self.drop_last else \
            (num_sampler_per_rank+self.batch_size-1)//self.batch_size
        return num_batches

    def __iter__(self):
        if self.during_iter:  # a previous iteration never finished; force re-initialization
            self.num_consumed_samples = 0
        self.during_iter = True

        sorted_indices = deepcopy(self.sorted_indices).tolist()  # sorted from longest to shortest

        if self.shuffle:
            if self.num_consumed_samples > 0:  # resume: rebuild the old ordering first, then drop the consumed prefix
                _batches = []
                for _i in range(self.old_num_replicas):
                    _sorted_indices = sorted_indices[_i:len(sorted_indices):self.old_num_replicas]
                    __batches = self.bucketerize(_sorted_indices, self.old_batch_size, self.old_num_batch_per_bucket,
                                                 seed=self.seed+self.epoch)
                    _batches.append(__batches)
                batches = list(chain(*[_ for _ in zip(*_batches)]))
                sorted_indices = list(chain(*batches))
                sorted_indices = sorted_indices[self.num_consumed_samples:]
                # Re-sort the remainder by length.
                sub_length = self.length[sorted_indices]
                sorted_indices = np.array(sorted_indices)[np.argsort(sub_length)[::-1]]  # sorted from longest to shortest
            # Take this rank's share,
            sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas]
            batches = self.bucketerize(sorted_indices, self.batch_size, self.num_batch_per_bucket,
                                       seed=self.seed+self.epoch)
            batches = list(map(list, batches))
        else:
            sorted_indices = sorted_indices[self.num_consumed_samples:]
            sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas]
            _num_batches = len(sorted_indices) // self.batch_size
            if _num_batches == 0:
                batches = [sorted_indices]
            else:
                batches = list(map(list, np.array_split(sorted_indices[:_num_batches*self.batch_size], _num_batches)))
                if len(sorted_indices)%self.batch_size!=0:
                    batches.append(sorted_indices[_num_batches*self.batch_size:])

        need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas
        if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
            if len(batches) > 0:
                # NOTE(review): the next five lines are garbled in this patch (invalid
                # syntax `len(batches[-1])self.rank:`) — the pad branch and the
                # truncate branch were fused and text was lost; restore from upstream.
                if len(batches[-1])self.rank:
            if len(batches):
                batches[-1].pop(-1)
                if len(batches[-1])==0:
                    batches.pop(-1)

        assert sum(map(len, batches)) == self.num_left_samples

        if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
            batches = batches[:-1]

        # Temporarily disabled.
        # self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
        #                                                           num_consumed_samples=self.num_consumed_samples)
        for batch in batches:
            self.num_consumed_samples += self.num_replicas * len(batch)
            yield list(map(int, batch))
        self.during_iter = False
        self.num_consumed_samples = 0
        self.old_batch_size = self.batch_size
        self.old_num_batch_per_bucket = self.num_batch_per_bucket
        self.old_num_replicas = self.num_replicas
        if self.epoch < 0:  # guard against the user never calling set_epoch, which would make every epoch identical
            self.epoch -= 1

    def bucketerize(self, sorted_indices, batch_size, num_batch_per_bucket, seed) -> List[List[int]]:
        """
        Split ``indices`` into buckets (and batches inside each bucket).

        :param sorted_indices: List[int]
        :param batch_size: int
        :param num_batch_per_bucket: int
        :param seed: int
        :return:
        """
        # Actual bucket size.
        bucket_size = min(len(sorted_indices), batch_size * num_batch_per_bucket)
        rng = np.random.default_rng(abs(seed))
        num_buckets = (len(sorted_indices) + bucket_size - 1) // bucket_size
        batches = []
        batch_indices = []
        for i in range(num_buckets):
            bucket = sorted_indices[i * bucket_size:(i + 1) * bucket_size]
            rng.shuffle(bucket)  # shuffle within the bucket
            _num_batches = len(bucket) // batch_size
            if _num_batches == 0:
                _batches = [bucket]
            else:
                _batches = np.array_split(bucket[:_num_batches*batch_size], _num_batches)
                if len(bucket) % batch_size != 0:
                    _batches.append(bucket[_num_batches*batch_size:])
            batch_indices.extend(list(range(len(batches), len(batches) + len(_batches))))
            batches.extend(_batches)
        last_batches = []
        # The last batch never takes part in the shuffle: on some rank the last batch
        # may hold fewer than batch_size samples, and a short batch must stay at the
        # end, so every rank simply keeps its last batch in place.
        if len(batches) >= 1:
            last_batches = [list(batches[-1])]
        batch_indices = list(batch_indices[:-1])
        rng = np.random.default_rng(abs(seed))  # reseed so differing bucket lengths do not perturb the random state
        rng.shuffle(batch_indices)  # also shuffle across batches; keeps per-card batch lengths similar
        batches = (np.array(batches, dtype=object)[batch_indices]).tolist()
        if last_batches:
            batches = batches + last_batches
        return batches

    def state_dict(self) -> Dict:
        if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket:
            raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
                               " consumed. ")
        states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
                  'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle,
                  'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
                  'num_replicas': self.num_replicas
                  }

        return states

    def load_state_dict(self, states: Dict):
        # If self.during_iter is True, num_consumed_samples must be 0.
        assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                          "during an unfinished iteration."

        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."

        length = states['length']
        assert length == self.num_samples, "The number of samples is different between the checkpoint record " \
                                           "and current dataset."
        self.seed = states['seed']
        self.epoch = states['epoch']
        self.num_consumed_samples = states['num_consumed_samples']
        if self.num_consumed_samples>=length:  # if the last sample had already been reached when saving, reset to 0
            self.num_consumed_samples = 0
        if self.shuffle != states['shuffle']:
            logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
                        f"we use shuffle={states['shuffle']}")
        self.shuffle = states["shuffle"]
        self.old_batch_size = states['batch_size']
        self.old_num_batch_per_bucket = states['num_batch_per_bucket']
        self.old_num_replicas = states['num_replicas']

    def set_epoch(self, epoch):
        self.epoch = epoch

    @property
    def batch_idx_in_epoch(self):
        # Batches already yielded in the current epoch on this rank.
        if self.drop_last:
            return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
        else:
            return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \
                   (self.num_left_samples + self.batch_size - 1) // self.batch_size
\ No newline at end of file
diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
new file mode 100644
index 00000000..1e57fc71
--- /dev/null
+++ b/fastNLP/core/samplers/reproducible_sampler.py
@@ -0,0 +1,375 @@
+"""
+:class:`ReproducibleSampler` 是 **fastNLP** 提供的一种特殊 Sampler,它可以记录采样过程中每一次采样和 epoch 的信息,
+方便在保存-加载后能够从上一次采样结束的地方继续进行采样,实现 **断点重训** 。
+
+.. note::
+
+ DataLoader 中只要存在 :class:`ReproducibleSampler` 或 :class:`~fastNLP.core.samplers.reproducible_batch_sampler.ReproducibleBatchSampler`
+ 中的一个便可以实现断点重训复现的功能。
+
+"""
+
+__all__ = [
+ 'ReproducibleSampler',
+ 'RandomSampler',
+ "SortedSampler",
+ "SequentialSampler"
+]
+
+from typing import Dict, List, Union, Sequence
+import math
+
+import numpy as np
+
+from fastNLP.core.log import logger
+from fastNLP.core.dataset import DataSet
+
+
class ReproducibleSampler:
    """
    A **reproducible** Sampler.

    Every subclass of :class:`ReproducibleSampler` must accept ``**kwargs`` in its
    :meth:`__init__` so the sampler can be re-instantiated when resuming from a
    checkpoint.  Variables initialized in :meth:`__init__` must not start with an
    underscore ``_``; every variable set outside :meth:`__init__` must start with one.

    """
    def __init__(self, **kwargs):
        # Single-process default; set_distributed() implementations overwrite this.
        self.num_replicas = 1

    def set_distributed(self, num_replicas, rank, pad=True):
        raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.")

    def __len__(self):
        raise NotImplementedError("Each specific sampler should implement its own `__len__` method.")

    def __iter__(self):
        raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.")

    def state_dict(self):
        """
        Return a serializable dict describing the sampler's current position.

        :return: dict of sampler state used for checkpoint resume.
        """
        raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.")

    def load_state_dict(self, states):
        raise NotImplementedError("Each specific sampler should implement its own `load_state_dict` method.")

    @property
    def num_left_samples(self):
        raise NotImplementedError("Each specific sampler should implement its own `num_left_samples` method.")

    @property
    def num_samples(self):
        raise NotImplementedError("Each specific sampler should implement its own `num_samples` method.")

    def set_epoch(self, epoch):
        # Optional hook; samplers that shuffle per-epoch override it.
        pass
+
+
class RandomSampler(ReproducibleSampler):
    """
    A sampler that iterates in random order.

    :param dataset: any data container implementing ``__len__``.
    :param shuffle: whether to shuffle the order on every iteration.
    :param seed: random seed.
    :param kwargs: parameters used internally by fastNLP.
    """
    def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
        super(RandomSampler, self).__init__()
        self.dataset = dataset
        self.shuffle = shuffle
        self.seed = int(seed)

        self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total samples already iterated, including those yielded on other ranks in the multi-card case

        # Distributed-related parameters.
        self.num_replicas = kwargs.get("num_replicas", 1)
        self.rank = kwargs.get("rank", 0)
        self.epoch = kwargs.get("epoch", -1)
        self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card

        # Whether we are inside an iteration; while True, set_distributed() and
        # load_state_dict() must not be called.
        self.during_iter = kwargs.get("during_iter", False)

    def __len__(self):
        """
        Number of indices one full iteration of this sampler produces; in the
        multi-card case only the current rank is counted.
        """
        return self.total_size//self.num_replicas

    def __iter__(self):
        r"""
        The current num_consumed_samples bookkeeping breaks when two iterators are
        used alternately.

        Example::

            >>> sampler = RandomSampler()
            >>> iter1 = iter(sampler)
            >>> iter2 = iter(sampler)
            >>> next(iter1)
            >>> next(iter2)  # num_consumed_samples changes here as well
        """

        if self.during_iter:  # a previous iteration never finished; force re-initialization
            self.num_consumed_samples = 0
        self.during_iter = True
        indices = self.generate_indices()

        if self.pad:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]

        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.num_consumed_samples:]
        indices = indices[self.rank:len(indices):self.num_replicas]
        assert len(indices) == self.num_left_samples
        for idx, index in enumerate(indices, start=1):
            self.num_consumed_samples += self.num_replicas
            yield index
        self.during_iter = False
        self.num_consumed_samples = 0

    def generate_indices(self) -> List[int]:
        """
        Generate the (possibly shuffled) index sequence.
        """
        if self.shuffle:
            indices = list(range(self.num_samples))
            seed = self.seed + self.epoch
            rng = np.random.default_rng(abs(seed))
            rng.shuffle(indices)
            if self.epoch < 0:  # if the user forgot set_epoch, at least make each epoch's order differ
                self.epoch -= 1
        else:
            indices = list(range(self.num_samples))
        return indices

    def state_dict(self) -> Dict:
        states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
                  'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle}
        return states

    def load_state_dict(self, states: Dict):
        # If self.during_iter is True, num_consumed_samples must be 0.
        assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                          "during an unfinished iteration."

        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."

        length = states['length']
        assert length == self.num_samples, "The number of samples is different between the checkpoint " \
                                           f"record({length}) and current dataset({self.num_samples})."
        self.seed = states['seed']
        self.epoch = states['epoch']
        self.num_consumed_samples = states['num_consumed_samples']
        if self.num_consumed_samples >= length:  # if the last sample had already been reached when saving, reset to 0
            self.num_consumed_samples = 0
        if self.shuffle != states['shuffle']:
            logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
                        f"we use shuffle={states['shuffle']}")
        self.shuffle = states["shuffle"]

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def set_distributed(self, num_replicas:int, rank:int, pad:bool=True):
        """
        Configure the distributed settings; should be called right after this Sampler
        is constructed.

        :param num_replicas: total number of processes in distributed training.
        :param rank: ``global_rank`` of the current process.
        :param pad: when the sample count is not divisible by ``num_replicas``,
            whether to pad so every process gets exactly the same number of samples.
        :return: self
        """

        assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
                                          "during an unfinished iteration."
        assert num_replicas>0 and isinstance(num_replicas, int)
        # NOTE(review): the next line is garbled — it fuses this assert with the
        # signature of another method (`... -> List[int]:`). The remainder of
        # set_distributed, the whole `SequentialSampler` class header/__init__/
        # __len__/__iter__, and the `def generate_indices` line of that class were
        # lost from this patch; the orphaned docstring/body below belongs to
        # SequentialSampler.generate_indices. Restore from upstream.
        assert isinstance(rank, int) and 0<=rank List[int]:
        """
        Generate the index sequence.

        :return:
        """
        return list(range(self.num_samples))
+
    def state_dict(self) -> Dict:
        # Sequential iteration: only the position and dataset length are needed to resume.
        # NOTE(review): this method is orphaned in this patch — its enclosing class
        # (apparently ``SequentialSampler``) is missing from the visible text.
        states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, 'length': self.num_samples}
        return states
+
    def load_state_dict(self, states: Dict):
        # NOTE(review): orphaned method — its enclosing class (apparently
        # ``SequentialSampler``) is missing from the visible text of this patch.
        # If self.during_iter is True, num_consumed_samples must be 0.
        assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                          "during an unfinished iteration."

        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."

        length = states['length']
        assert length == self.num_samples, "The number of samples is different between the checkpoint " \
                                           f"record({length}) and current dataset({self.num_samples})."
        self.num_consumed_samples = states['num_consumed_samples']
        if self.num_consumed_samples >= length:  # if the last sample had already been reached when saving, reset to 0
            self.num_consumed_samples = 0
+
+
class SortedSampler(SequentialSampler):
    """
    Iterate the samples of ``dataset`` from longest to shortest according to
    ``length``.  In the multi-card case the final ``sample`` may be the longest
    one again because of ``padding``.

    :param dataset: any data container implementing ``__len__``.
    :param length: length of every sample:

        * as ``List[int]``
          must have the same length as ``dataset``, giving the size of each element;
        * as ``str``
          only allowed when ``dataset`` is a :class:`~fastNLP.DataSet`; the string
          names a ``field`` of the dataset.  If the field's elements are ``int``
          they are used as lengths directly, otherwise ``len`` is applied to them;

    :param kwargs: parameters used internally by fastNLP.
    """
    def __init__(self, dataset, length:Union[str, List], **kwargs):
        super().__init__(dataset=dataset, **kwargs)
        if isinstance(dataset, DataSet) and isinstance(length, str):
            length = dataset.get_field(length).content
            if not isinstance(length[0], int):
                length = list(map(len, length))
            self.length = np.array(length, dtype=int)
            self.sorted_indices = np.argsort(self.length)[::-1]  # sorted from longest to shortest
        else:
            try:
                self.length = np.array(length, dtype=int)
                self.sorted_indices = np.argsort(length)[::-1]
            except BaseException as e:
                # NOTE(review): the error is logged but not re-raised, so the asserts
                # below would then fail with an unrelated AttributeError — confirm intent.
                logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.")

        assert len(length) == self.num_samples, f"The length of `dataset`({len(dataset)}) and " \
                                                f"`length`({self.num_samples}) should be equal."
        assert len(self.sorted_indices) == self.num_samples, "The indices and dataset should have equal length."

        # NOTE(review): the next two lines recompute length/sorted_indices that were
        # already set in both branches above (and turn sorted_indices into a list) —
        # apparently redundant; confirm before simplifying.
        self.length = np.array(length, dtype=int)  # lengths, used to produce the long-to-short order
        self.sorted_indices = np.argsort(self.length)[::-1].tolist()  # sorted from longest to shortest

    def generate_indices(self) -> List[int]:
        # Fixed order: longest sample first.
        return self.sorted_indices

    def __iter__(self):
        if self.during_iter:  # a previous iteration never finished; force re-initialization
            self.num_consumed_samples = 0
        self.during_iter = True
        indices = self.generate_indices()

        if self.pad:
            # Add extra samples so the total divides evenly across replicas.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]

        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.num_consumed_samples:]
        indices = indices[self.rank:len(indices):self.num_replicas]
        assert len(indices) == self.num_left_samples

        for idx, index in enumerate(indices, start=1):
            self.num_consumed_samples += self.num_replicas
            yield index
        self.during_iter = False
        self.num_consumed_samples = 0
+
diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py
new file mode 100644
index 00000000..9e9aa6b6
--- /dev/null
+++ b/fastNLP/core/samplers/unrepeated_sampler.py
@@ -0,0 +1,179 @@
+__all__ = [
+ 'UnrepeatedSampler',
+ 'UnrepeatedSortedSampler',
+ 'UnrepeatedRandomSampler',
+ "UnrepeatedSequentialSampler"
+]
+
+from typing import List, Union
+from fastNLP.core.dataset import DataSet
+
+import numpy as np
+
+
+class UnrepeatedSampler:
+ """
+ 在多卡场景下保证 indice 不重复的 Sampler。
+ """
+ pass
+
+
+class UnrepeatedRandomSampler(UnrepeatedSampler):
+ """
+ 考虑在多卡 evaluate 的场景下,不能重复采样。
+
+ :param dataset: 实现了 __len__ 方法的数据容器
+ :param shuffle: 如果为 ``True``,将不进行 shuffle,实际上数据会以从长到短的方式输出
+ :param seed: 设置的随机数种子
+ :param kwargs: fastNLP 内部使用的参数
+ """
+ def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs):
+ self.dataset = dataset
+ self.shuffle = shuffle
+ self.seed = int(seed)
+
+ # 多卡的相关的参数
+ self.num_replicas = kwargs.get('num_replicas', 1)
+ self.rank = kwargs.get('rank', 0)
+ self.epoch = kwargs.get('epoch', -1)
+
+ def __len__(self):
+ """
+ 返回 ``Sampler`` 一次完整的迭代过程会产生多少个 index 。多卡的情况下,只考虑 **当前rank** 。
+ :return:
+ """
+ num_common = self.num_samples//self.num_replicas
+ num_samples = num_common + int(self.rank < (self.num_samples-num_common*self.num_replicas))
+ return num_samples
+
+ def __iter__(self):
+ indices = self.generate_indices()
+
+ # subsample
+ indices = indices[self.rank:len(indices):self.num_replicas]
+ assert len(indices) == len(self)
+
+ for index in indices:
+ yield index
+
+ def generate_indices(self) -> List[int]:
+ """
+ 生成随机序列
+
+ :return:
+ """
+ if self.shuffle:
+ indices = list(range(self.num_samples))
+ seed = self.seed + self.epoch
+ rng = np.random.default_rng(abs(seed))
+ rng.shuffle(indices)
+ if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。
+ self.epoch -= 1
+ else:
+ indices = list(range(self.num_samples))
+ return indices
+
+ def set_epoch(self, epoch: int) -> None:
+ self.epoch = epoch
+
+ def set_distributed(self, num_replicas, rank):
+ """
+ 该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 Sampler 本身后立即被调用。
+
+ :param num_replicas: 分布式训练中的进程总数
+ :param rank: 当前进程的 ``global_rank``
+ :return: 自身
+ """
+ assert num_replicas<=self.num_samples, f"The number of replicas({num_replicas}) should be lesser than the " \
+ f"number of samples({self.num_samples})."
+ assert num_replicas>0 and isinstance(num_replicas, int)
+ assert isinstance(rank, int) and 0<=rank List[int]:
+ return self.sorted_indices
+
+
+class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
+ """
+ 按照顺序读取 dataset。
+
+ :param dataset: 实现了 __len__ 方法的数据容器。
+ :param chunk_dist: 如果为 ``True`` ,当多卡时将不间隔索取数据;为 ``False`` 时则会间隔取数据。假设 dataset 有 10 个 sample ,使用
+ 2 卡,如果为 ``True`` ,卡 **0** 拿 [0, 1, 2, 3, 4], 卡 **1** 拿 [5, 6, 7, 8, 9] ; 如果为 ``False`` ,则卡 **0** 拿 [0, 2, 4, 8, 8],
+ 卡 **1** 拿 [1, 3, 5, 7, 9] 。
+ :param kwargs:
+ """
+ def __init__(self, dataset, chunk_dist=False, **kwargs):
+ kwargs['shuffle'] = False
+ kwargs['seed'] = 0
+ super(UnrepeatedSequentialSampler, self).__init__(dataset, **kwargs)
+ self.chunk_dist = chunk_dist
+
+ def __iter__(self):
+ indices = self.generate_indices()
+ if self.num_replicas>1:
+ if self.chunk_dist:
+ chunk_size = len(indices)//self.num_replicas
+ start = chunk_size * self.rank
+ end = chunk_size * (self.rank + 1)
+ if self.rank == self.num_replicas - 1:
+ end = len(indices)
+ indices = indices[start:end]
+ else:
+ indices = indices[self.rank:len(indices):self.num_replicas]
+ for index in indices:
+ yield index
+
+ def generate_indices(self) -> List[int]:
+ return list(range(self.num_samples))
+
diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py
new file mode 100644
index 00000000..f8535ed2
--- /dev/null
+++ b/fastNLP/core/samplers/utils.py
@@ -0,0 +1,69 @@
+__all__ = [
+ 're_instantiate_sampler'
+]
+from array import array
+from typing import Sequence
+from collections import deque
+
+
+def re_instantiate_sampler(sampler, new_sampler_class=None):
+ all_attributes = vars(sampler)
+ if new_sampler_class is not None:
+ return new_sampler_class(**all_attributes)
+ return type(sampler)(**all_attributes)
+
+
+def create_array(length, fill_value) -> array:
+ """
+ 根据长度自动创建 array ,超过 4294967295 需要使用 'L', 否则使用 'I'
+
+ :param length:
+ :param fill_value:
+ :return:
+ """
+ if not isinstance(fill_value, Sequence):
+ fill_value = [fill_value]*length
+
+ if length > 4294967295:
+ _index_lst = array("L", fill_value)
+ else:
+ _index_lst = array("I", fill_value)
+ return _index_lst
+
+
+class NumConsumedSamplesArray:
+ def __init__(self, buffer_size=2000, num_consumed_samples=0):
+ """
+ 保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少
+
+ Example::
+
+ array = NumConsumedSamplesArray(buffer_size=3)
+ for i in range(10):
+ array.push(i)
+
+ array[9] # 输出为9,表示这个位置真实的 num_consumed_samples 是多少。
+ array[6] # 报错,因为只保留了3个最近的数据,6超过了最大buffer的记录了,即 [7, 8, 9]
+
+ 暂时由于 sampler 的 batch 都是规整的,先保留
+
+ :param buffer_size: 报错多少个历史。
+ :param num_consumed_samples: 第一个 num_consumed_samples 是多少。
+ """
+ self.count = 0
+ self.deque = deque(maxlen=buffer_size)
+ if num_consumed_samples is not None:
+ self.push(num_consumed_samples)
+ self.buffer_size = buffer_size
+
+ def __getitem__(self, item):
+ if len(self.deque) == 0: # 如果没有任何缓存的内容,说明还没有写入,直接返回0
+ return 0
+ assert isinstance(item, int), "Only int index allowed."
+ assert self.count-len(self.deque)<=item` 的1.3部分。
-Tester在验证进行之前会调用model.eval()提示当前进入了evaluation阶段,即会关闭nn.Dropout()等,在验证结束之后会调用model.train()恢复到训练状态。
-
-
-"""
-import time
-
-import torch
-import torch.nn as nn
-
-try:
- from tqdm.auto import tqdm
-except:
- from .utils import _pseudo_tqdm as tqdm
-
-from .batch import BatchIter, DataSetIter
-from .dataset import DataSet
-from .metrics import _prepare_metrics
-from .sampler import SequentialSampler
-from .utils import _CheckError
-from .utils import _build_args
-from .utils import _check_loss_evaluate
-from .utils import _move_dict_value_to_device
-from .utils import _get_func_signature
-from .utils import _get_model_device
-from .utils import _move_model_to_device
-from .utils import _build_fp16_env
-from .utils import _can_use_fp16
-from ._parallel_utils import _data_parallel_wrapper
-from ._parallel_utils import _model_contains_inner_module
-from functools import partial
-from ._logger import logger
-from .sampler import Sampler
-
-__all__ = [
- "Tester"
-]
-
-
-class Tester(object):
- r"""
- Tester是在提供数据,模型以及metric的情况下进行性能测试的类。需要传入模型,数据以及metric进行验证。
- """
-
- def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True,
- fp16=False, **kwargs):
- r"""
-
- :param ~fastNLP.DataSet,~fastNLP.BatchIter data: 需要测试的数据集
- :param torch.nn.Module model: 使用的模型
- :param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: 测试时使用的metrics
- :param int batch_size: evaluation时使用的batch_size有多大。
- :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
- 的计算位置进行管理。支持以下的输入:
-
- 1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中,可见的第一个GPU中,可见的第二个GPU中;
-
- 2. torch.device:将模型装载到torch.device上。
-
- 3. int: 将使用device_id为该值的gpu进行训练
-
- 4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。
-
- 5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。
-
- 如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。
- :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
- :param bool use_tqdm: 是否使用tqdm来显示测试进度; 如果为False,则不会显示任何内容。
- :param bool fp16: 是否使用float16进行验证
- :param kwargs:
- Sampler sampler: 支持传入sampler控制测试顺序
- bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快数据速度。
- """
- super(Tester, self).__init__()
-
- if not isinstance(model, nn.Module):
- raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")
-
- self.metrics = _prepare_metrics(metrics)
-
- self.data = data
- self._model = _move_model_to_device(model, device=device)
- self.batch_size = batch_size
- self.verbose = verbose
- self.use_tqdm = use_tqdm
- self.logger = logger
- self.pin_memory = kwargs.get('pin_memory', True)
-
- if isinstance(data, DataSet):
- sampler = kwargs.get('sampler', None)
- if sampler is None:
- sampler = SequentialSampler()
- elif not isinstance(sampler, (Sampler, torch.utils.data.Sampler)):
- raise ValueError(f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}")
- if hasattr(sampler, 'set_batch_size'):
- sampler.set_batch_size(batch_size)
- self.data_iterator = DataSetIter(dataset=data, batch_size=batch_size, sampler=sampler,
- num_workers=num_workers,
- pin_memory=self.pin_memory)
- elif isinstance(data, BatchIter):
- self.data_iterator = data
- else:
- raise TypeError("data type {} not support".format(type(data)))
-
- # check predict
- if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \
- (_model_contains_inner_module(self._model) and hasattr(self._model.module, 'predict') and
- callable(self._model.module.predict)):
- if isinstance(self._model, nn.DataParallel):
- self._predict_func_wrapper = partial(_data_parallel_wrapper('predict',
- self._model.device_ids,
- self._model.output_device),
- network=self._model.module)
- self._predict_func = self._model.module.predict # 用于匹配参数
- elif isinstance(self._model, nn.parallel.DistributedDataParallel):
- self._predict_func = self._model.module.predict
- self._predict_func_wrapper = self._model.module.predict # 用于调用
- else:
- self._predict_func = self._model.predict
- self._predict_func_wrapper = self._model.predict
- else:
- if _model_contains_inner_module(self._model):
- self._predict_func_wrapper = self._model.forward
- self._predict_func = self._model.module.forward
- else:
- self._predict_func = self._model.forward
- self._predict_func_wrapper = self._model.forward
-
- if fp16:
- _can_use_fp16(model=model, device=device, func=self._predict_func)
- self.auto_cast, _grad_scaler = _build_fp16_env(not fp16)
-
- def test(self):
- r"""开始进行验证,并返回验证结果。
-
- :return Dict[Dict]: dict的二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。
- """
- # turn on the testing mode; clean up the history
- self._model_device = _get_model_device(self._model)
- network = self._model
- self._mode(network, is_test=True)
- data_iterator = self.data_iterator
- eval_results = {}
- try:
- with torch.no_grad():
- if not self.use_tqdm:
- from .utils import _pseudo_tqdm as inner_tqdm
- else:
- inner_tqdm = tqdm
- with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar:
- pbar.set_description_str(desc="Test")
-
- start_time = time.time()
-
- for batch_x, batch_y in data_iterator:
- _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
- non_blocking=self.pin_memory)
- with self.auto_cast():
- pred_dict = self._data_forward(self._predict_func, batch_x)
- if not isinstance(pred_dict, dict):
- raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
- f"must be `dict`, got {type(pred_dict)}.")
- for metric in self.metrics:
- metric(pred_dict, batch_y)
-
- if self.use_tqdm:
- pbar.update()
-
- for metric in self.metrics:
- eval_result = metric.get_metric()
- if not isinstance(eval_result, dict):
- raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
- f"`dict`, got {type(eval_result)}")
- metric_name = metric.get_metric_name()
- eval_results[metric_name] = eval_result
- pbar.close()
- end_time = time.time()
- test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
- if self.verbose >= 0:
- self.logger.info(test_str)
- except _CheckError as e:
- prev_func_signature = _get_func_signature(self._predict_func)
- _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
- check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
- dataset=self.data, check_level=0)
- finally:
- self._mode(network, is_test=False)
- if self.verbose >= 1:
- logger.info("[tester] \n{}".format(self._format_eval_results(eval_results)))
- return eval_results
-
- def _mode(self, model, is_test=False):
- r"""Train mode or Test mode. This is for PyTorch currently.
-
- :param model: a PyTorch model
- :param is_test: bool, whether in test mode or not.
-
- """
- if is_test:
- model.eval()
- else:
- model.train()
-
- def _data_forward(self, func, x):
- r"""A forward pass of the model. """
- x = _build_args(func, **x)
- y = self._predict_func_wrapper(**x)
- return y
-
- def _format_eval_results(self, results):
- r"""Override this method to support more print formats.
-
- :param results: dict, (str: float) is (metrics name: value)
-
- """
- _str = ''
- for metric_name, metric_result in results.items():
- _str += metric_name + ': '
- _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
- _str += '\n'
- return _str[:-1]
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
deleted file mode 100644
index f4f8a093..00000000
--- a/fastNLP/core/trainer.py
+++ /dev/null
@@ -1,1038 +0,0 @@
-r"""
-Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在不同训练任务中重复撰以下步骤的代码
-
- (1) epoch循环;
-
- (2) 将数据分成不同的Batch;
-
- (3) 对Batch进行pad;
-
- (4) 每个epoch结束或一定step后进行验证集验证;
-
- (5) 保存获得更好验证性能的模型。
-
-
-----------------------------
-1. Trainer的基本使用
-----------------------------
-
-下面的例子是使用神经网络来进行预测一个序列中是否有偶数个1。
-
-.. code-block:: python
-
- import numpy as np
- from torch import nn
- import torch
- import torch.nn.functional as F
- from torch.optim import SGD
-
- from fastNLP import DataSet
- from fastNLP import Trainer
- from fastNLP import CrossEntropyLoss
- from fastNLP import AccuracyMetric
- from fastNLP.modules.decoder import MLP
-
- # 模型
- class Model(nn.Module):
- def __init__(self, input_num):
- super().__init__()
- self.fcs = MLP([input_num, 40, 40, 2], 'relu')
-
- def forward(self, x):
- x = self.fcs(x)
- return {'pred': x}
- model = Model(10)
-
- # 生成数据
- def generate_psedo_dataset(num_samples):
- dataset = DataSet()
- data = np.random.randint(2, size=(num_samples, 10))
- label = np.sum(data, axis=1)%2
- dataset = DataSet({'x':data.astype(float), 'label': label})
- dataset.set_input('x')
- dataset.set_target('label')
- return dataset
- tr_dataset = generate_psedo_dataset(1000)
- dev_data = generate_psedo_dataset(100)
-
- # 训练
- trainer = Trainer(tr_dataset, model, loss=CrossEntropyLoss(target='label'),
- optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000,
- dev_data = dev_data, metrics=AccuracyMetric(target='label'))
- trainer.train()
-
-由上面的例子可以看出通过使用Trainer,可以使得训练部分的代码大幅减少。
-使用Trainer需要满足以下几个条件:
-
-1.1 模型
-----------------------------
-
-1 模型的forward()的参数名需要与DataSet中的名字对应。实际上fastNLP在将DataSet中的数据传递给模型forward()时,是
-通过匹配名称实现的。所以上例中,如果Model的forward函数修改为forward(self, data), 则DataSet中的'x'这个field就应该
-改名为'data'。
-
-2 传递给forward()的参数是DataSet中被设置为input的那些field。但如果forward()中没有对应的参数,则不会将数据传递
-给forward()。例如,DataSet中'x1', 'x2'都是input,但是模型的函数为forward(self, x1), 那么'x2'不会传递给forward()。
-
-3 模型的forward()返回值需要为一个dict。
-
-1.2 Loss
-----------------------------
-
-fastNLP中的为了不限制forward函数的返回内容数量(比如一些复杂任务需要返回多个内容,如Dependency Parsing,
-:mod:`Loss` 与 :mod:`Metric` 都使用了通过名称来匹配相应内容的策略。如上面的例子中
-
-.. code-block:: python
-
- trainer = Trainer(tr_dataset, model, loss=CrossEntropyLoss(target='label'),
- optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000,
- dev_data = dev_data, metrics=AccuracyMetric(target='label'))
-
-loss被设置为了 :class:`~fastNLP.CrossEntropyLoss` , 但在初始化的时候传入了target='label'这个参数,
-:class:`~fastNLP.CrossEntropyLoss` 的初始化参数为(pred=None, target=None, padding_idx=-100)。
-
-这里的两个参数分别为计算CrossEntropy时需要使用到的模型的预测值与真实值。
-其中 `pred` 一般来自于模型forward()的返回结果,`target` 一般是来自于DataSet中被设置为target的field。
-由于每个人对真实值或者model的返回值取名并不一样,所以fastNLP的 :mod:`Loss` 提供一种类似于映射的机制来匹配对应的值,
-比如这里 :class:`~fastNLP.CrossEntropyLoss` 将尝试找到名为'label'的内容来作为真实值得到loss;
-而pred=None, 则 :class:`~fastNLP.CrossEntropyLoss` 使用'pred'作为名称匹配预测值,
-正好forward的返回值也叫pred,所以这里不需要申明pred。
-
-尽管fastNLP使用了映射机制来使得loss的计算变得比较灵活,但有些情况下loss必须在模型中进行计算,比如使用了CRF的模型。
-fastNLP中提供了 :class:`~fastNLP.LossInForward` 这个loss。
-这个loss的原理是直接在forward()的返回结果中找到loss_key(默认寻找'loss')指定的那个tensor,并使用它作为loss。
-如果Trainer初始化没有提供loss则默认使用 :class:`~fastNLP.LossInForward` 。
-
-.. todo::
- 补充一个例子 详细例子可以参照
-
-1.3 Metric
-----------------------------
-
-:mod:`Metric` 使用了与上述Loss一样的策略,即使用名称进行匹配。
-AccuracyMetric(target='label')的情况与CrossEntropyLoss 是同理的。
-
-在进行验证时,可能用到的计算与forward()中不太一致,没有办法直接从forward()的结果中得到预测值,这时模型可以提供一个predict()方法,
-如果提供的模型具有predict方法,则在模型验证时将调用predict()方法获取预测结果,
-传入到predict()的参数也是从DataSet中被设置为input的field中选择出来的;
-与forward()一样,返回值需要为一个dict。
-
-.. todo::
- 补充一个例子 具体例子可以参考
-
-----------------------------
-2. Trainer的代码检查
-----------------------------
-
-由于在fastNLP中采取了映射的机制,所以难免可能存在对应出错的情况。Trainer提供一种映射检查机制,可以通过check_code_level来进行控制
-比如下面的例子中,由于各种原因产生的报错
-
-Example2.1
-----------------------------
-
-.. code-block:: python
-
- import numpy as np
- from torch import nn
- import torch
- from torch.optim import SGD
- from fastNLP import Trainer
- from fastNLP import DataSet
-
- class Model(nn.Module):
- def __init__(self):
- super().__init__()
- self.fc = nn.Linear(1, 1)
- def forward(self, x, b):
- loss = torch.mean((self.fc(x)-b)**2)
- return {'loss': loss}
- model = Model()
-
- dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
- dataset.set_input('a', 'b')
-
- trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001))
-
- trainer = Trainer(dataset, model, SGD(model.parameters()))
- # 会报以下的错误
- # input fields after batch(if batch size is 2):
- # a: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- # b: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- # There is no target field.
- # ....
- # NameError:
- # Problems occurred when calling Model.forward(self, x, b)
- # missing param: ['x']
- # unused field: ['a']
- # Suggestion: You need to provide ['x'] in DataSet and set it as input.
-
-这里就是由于在Trainer初始化的时候,fastNLP会尝试使用一个batch_size=2的batch去运行一遍forward()以及backward()。这里有两类
-信息可以为你提供参考
-
-1 'input fields after batch...'这部分显示的是train dataset经过Batch操作后,每个field对应的类型以及进行shape。这里
-因为train dataset没有target所以没有显示。根据这里可以看出是否正确将需要的内容设置为了input或target。
-
-2 NameError,NameError发生在映射出错的情况。这里报错的原因是由于尝试进行forward计算时(可以通过Model.forward(self, x, b)判断
-出当前是在调取forward),却没有获取到forward()函数中需要的'x';在报错信息中同时指出了缺'x',而'a'没有被使用,那么可能
-就是由于field的名称不对。这里将dataset中'a'这个field的名称改为'x',或者model的参数从'x'修改为'a'都可以解决问题。
-
-下面的例子是由于loss计算的时候找不到需要的值
-
-Example2.2
-----------------------------
-
-.. code-block:: python
-
- import numpy as np
- from torch import nn
- from torch.optim import SGD
- from fastNLP import Trainer
- from fastNLP import DataSet
- from fastNLP import L1Loss
- import torch
-
- class Model(nn.Module):
- def __init__(self):
- super().__init__()
- self.fc = nn.Linear(1, 1)
- def forward(self, a):
- return {'pred_b': self.fc(a.unsqueeze(1)).squeeze(1), 'No use':1}
-
- model = Model()
-
- dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})
-
- dataset.set_input('a')
- dataset.set_target('b')
-
- trainer = Trainer(dataset, model, loss=L1Loss(target='label'), optimizer=SGD(model.parameters(), lr=0.001))
- # 报错信息如下
- # input fields after batch(if batch size is 2):
- # a: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
- # target fields after batch(if batch size is 2):
- # b: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
- # ....
- # NameError:
- # Problems occurred when calling L1Loss.get_loss(self, pred, target)
- # missing param: ['pred(assign to `pred` in `L1Loss`)', 'label(assign to `target` in `L1Loss`)']
- # unused field: ['b']
- # unused param: ['pred_b', 'No use']
- # target field: ['b']
- # param from Model.forward(self, a): ['pred_b', 'No use']
- # Suggestion: (1). Check key assignment for `target` when initialize L1Loss. Or provide `label` in DataSet or output of Model.forward(self, a).
- # (2). Check key assignment for `pred` when initialize L1Loss. Or provide `pred` in DataSet or output of Model.forward(self, a).
-
-报错信息也包含两部分:
-
-1 第一部分与上面是一样的
-
-2 这里报错的原因是由于计算loss的时候找不到相应的值(通过L1Loss.get_loss(self, pred, target)判断出来的);
-报错的原因是因为 `pred` 和 `label` (我们在初始化L1Loss时将target指定为了label)都没有找到。
-这里'unused field'是DataSet中出现了,但却没有被设置为input或者target的field;
-'unused param'是forward()中返回且没有被使用到的内容;'target field'是被设置为了target的field;
-'param from Model.forward(self, a)'是forward()返回的所有key。"Suggestion"是关于当前错误处理的建议。
-
-但是在一些情况下,比如forward()返回值只有一个,target也只有一个,fastNLP不会进行匹配,而直接将forward()的结果作为pred,
-将DataSet中的target设置为target。上面的例子在返回值中加入了一个'No use'则只是为了使得Loss去匹配结果。
-
-
-下面是带有dev dataset时如果出现错误会发生的报错,
-
-Example2.3
-----------------------------
-
-.. code-block:: python
-
- import numpy as np
- from torch import nn
- from torch.optim import SGD
- from fastNLP import Trainer
- from fastNLP import DataSet
- from fastNLP import AccuracyMetric
- import torch
-
- class Model(nn.Module):
- def __init__(self):
- super().__init__()
- self.fc = nn.Linear(1, 1)
- def forward(self, a, b):
- loss = torch.mean((self.fc(a.float().unsqueeze(1))-b.float())**2)
- return {'loss': loss}
- def predict(self, a): # 使用predict()进行验证
- return {'output':self.fc(a.float().unsqueeze(1))} #这里return的值不包含'pred'这个key
- model = Model()
-
- dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
- dev_data = DataSet({'a': np.arange(10, 20), 'b':np.arange(10, 20)*2})
-
- dataset.set_input('a', 'b')
- dev_data.set_input('a') # 这里没有设置target
-
- trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001),
- dev_data=dev_data, metrics=AccuracyMetric())
-
- # 报错信息
- # ...
- # NameError:
- # Problems occurred when calling AccuracyMetric.evaluate(self, pred, target, seq_len=None)
- # missing param: ['pred(assign to `pred` in `AccuracyMetric`)', 'target(assign to `target` in `AccuracyMetric`)']
- # unused param: ['output']
- # target field: []
- # param from Model.predict(self, a): ['output']
- # Suggestion: (1). Check key assignment for `pred` when initialize AccuracyMetric. Or provide `pred` in DataSet or output of Model.predict(self, a).
- # (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a).
-
-报错信息和前面都是类似的,但是可以通过'AccuracyMetric.evaluate(self, pred, target, seq_len=None)'看出这里是evaluation
-的时候发生了错误。这样避免了需要在完成一整个epoch的训练才能发现evaluation弄错的情况。这里的修改是通过在初始化metric的时候
-指明通过'output'获取`pred`, 即AccuracyMetric(pred='output')。
-
-可以通过check_code_level调节检查的强度。默认为0,即进行检查。
-
-----------------------------
-3. Trainer与callback
-----------------------------
-
-虽然Trainer本身已经集成了一些功能,但仍然不足以囊括训练过程中可能需要到的功能,比如负采样,learning rate decay, Early Stop等。
-为了解决这个问题fastNLP引入了callback的机制,:class:`~fastNLP.Callback` 是一种在Trainer训练过程中特定阶段会运行的函数集合,
-所有的 :class:`~fastNLP.Callback` 都具有on_*(比如on_train_start, on_backward_begin)等函数。
-如果 Callback 实现了该函数,则Trainer运行至对应阶段,会进行调用,例如::
-
- from fastNLP import Callback, EarlyStopCallback, Trainer, CrossEntropyLoss, AccuracyMetric
- from fastNLP.models import CNNText
-
- start_time = time.time()
-
- class MyCallback(Callback):
- def on_epoch_end(self):
- print('{:d}ms\n\n'.format(round((time.time()-start_time)*1000)))
-
- model = CNNText((len(vocab),50), num_classes=5, padding=2, dropout=0.1)
- trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, loss=CrossEntropyLoss(),
- metrics=AccuracyMetric(), callbacks=[MyCallback(),EarlyStopCallback(10)])
- trainer.train()
-
-这里,我们通过继承 :class:`~fastNLP.Callback` 类定义了自己的 callback 的,并和内置的 :class:`~fastNLP.EarlyStopCallback`
-一起传给了 :class:`~fastNLP.Trainer` ,增强了 :class:`~fastNLP.Trainer` 的功能
-
-fastNLP已经自带了很多callback函数供使用,可以参考 :mod:`fastNLP.core.callback` 。
-
-"""
-__all__ = [
- "Trainer"
-]
-
-import os
-import time
-from datetime import datetime, timedelta
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-try:
- from tqdm.auto import tqdm
-except:
- from .utils import _pseudo_tqdm as tqdm
-import warnings
-from pkg_resources import parse_version
-
-from .batch import DataSetIter, BatchIter
-from .callback import CallbackManager, CallbackException, Callback
-from .dataset import DataSet
-from .losses import _prepare_losser
-from .metrics import _prepare_metrics
-from .optimizer import Optimizer
-from .sampler import Sampler
-from .sampler import RandomSampler, ConstTokenNumSampler
-from .tester import Tester
-from .utils import _CheckError
-from .utils import _build_args
-from .utils import _check_forward_error
-from .utils import _check_loss_evaluate
-from .utils import _move_dict_value_to_device
-from .utils import _get_func_signature
-from .utils import _get_model_device
-from .utils import _move_model_to_device
-from .utils import _build_fp16_env
-from .utils import _can_use_fp16
-from ._parallel_utils import _model_contains_inner_module
-from ._logger import logger
-
-
-class Trainer(object):
- r"""
- Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在不同训练任务中重复撰写
- (1) epoch循环;
- (2) 将数据分成不同的Batch;
- (3) 对Batch进行pad;
- (4) 每个epoch结束或一定step后进行验证集验证;
- (5) 保存获得更好验证性能的模型等。
-
- 详细的介绍参见 :mod:`fastNLP.core.trainer`
- """
-
- def __init__(self, train_data, model, optimizer=None, loss=None,
- batch_size=32, sampler=None, drop_last=False, update_every=1,
- num_workers=0, n_epochs=10, print_every=5,
- dev_data=None, metrics=None, metric_key=None,
- validate_every=-1, save_path=None, use_tqdm=True, device=None,
- callbacks=None, check_code_level=0, fp16=False, **kwargs):
- r"""
- :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型或 :class:`~fastNLP.BatchIter` 的子类
- :param nn.modules model: 待训练的模型
- :param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器
- :param int batch_size: 训练和验证的时候的batch大小。
- :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward`
- :param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler`
- :param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch
- :param num_workers: int, 有多少个线程来进行数据pad处理。
- :param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128
- 会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。
- :param int n_epochs: 需要优化迭代多少次。
- :param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。
- :param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。
- :param metrics: 验证的评估函数。可以只使用一个 :class:`Metric` ,
- 也可以使用多个 :class:`Metric` ,通过列表传入。
- 如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None,
- 则保存当前模型。Metric种类详见 :mod:`metrics模块 ` 。仅在传入dev_data时有效。
- :param str,None metric_key: :class:`Metric` 有时会有多个指标,
- 比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需
- 要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表
- 明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。
- :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。
- :param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存
- 最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
- :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
- :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
- 的计算位置进行管理。支持以下的输入:
-
- 1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中,
- 可见的第二个GPU中;
-
- 2. torch.device:将模型装载到torch.device上。
-
- 3. int: 将使用device_id为该值的gpu进行训练
-
- 4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。
-
- 5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。
-
- 已知可能会出现的问题:Adagrad优化器可能无法正常使用这个参数,请手动管理模型位置。
-
- :param list(callbacks) callbacks: 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以
- 通过callback机制实现。 可使用的callback参见 :mod:`callback模块 `
- :param int check_code_level: 模型检查等级. -1: 不进行检查; 0: 仅出现错误时停止; 1: 如果有field没有被使用,
- 报告警告信息; 2: 有任何field没有被使用都报错. 检查的原理是通过使用很小的batch(默认2个sample)来运行代码,但是
- 这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况;
- (2)模型中存在累加前向计算次数的,可能会多计算1次。以上情况建议将check_code_level设置为-1。
- :param bool fp16: 是否使用fp16进行训练。
- :param kwargs: 支持配置可选参数
- bool test_use_tqdm: 在dev上验证的时候是否开启tqdm
- Sampler test_sampler: 在evaluate的时候使用的sampler
- bool test_use_fp16: evalute的时候是否使用fp16测试,默认与fp16相同的取值。
- bool set_grad_to_none: 在zero_grad的时候是否将gradient设置为None,而不是设置为zero
- GradScaler grad_scaler: 仅在fp16为True时有效,如果不使用torch.cuda.amp.GradScaler的初始化参数,可传入一个已经初始化后的
- grad_scaler。
- bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快数据速度。
- """
- super(Trainer, self).__init__()
- if not isinstance(model, nn.Module):
- raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
-
- # check metrics and dev_data
- if (not metrics) and dev_data is not None:
- raise ValueError("No metric for dev_data evaluation.")
- if metrics and (dev_data is None):
- raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")
-
- # check update every
- assert update_every >= 1, "update_every must be no less than 1."
- self.update_every = int(update_every)
-
- # check save_path
- if not (save_path is None or isinstance(save_path, str)):
- raise ValueError("save_path can only be None or `str`.")
- # prepare evaluate
- metrics = _prepare_metrics(metrics)
-
- # parse metric_key
- # increase_better is True. It means the exp result gets better if the indicator increases.
- # It is true by default.
- self.increase_better = True
- if metric_key is not None:
- self.increase_better = False if metric_key[0] == "-" else True
- self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
- else:
- self.metric_key = None
- # prepare loss
- losser = _prepare_losser(loss)
-
- if isinstance(train_data, BatchIter):
- if sampler is not None:
- warnings.warn("sampler is ignored when train_data is a BatchIter.")
- if num_workers>0:
- warnings.warn("num_workers is ignored when train_data is BatchIter.")
- if drop_last:
- warnings.warn("drop_last is ignored when train_data is BatchIter.")
- # concerning issue from https://github.com/pytorch/pytorch/issues/57273
- self.pin_memory = kwargs.get('pin_memory', False if parse_version(torch.__version__)==parse_version('1.9') else True)
- if isinstance(model, nn.parallel.DistributedDataParallel): # 如果是分布式的
- # device为None
- if device is not None:
- warnings.warn("device is ignored when model is nn.parallel.DistributedDataParallel.")
- device = None
- # Sampler要是分布式的
- if sampler is None:
- sampler = torch.utils.data.DistributedSampler(train_data)
- elif not isinstance(sampler, torch.utils.data.DistributedSampler):
- raise TypeError("When using nn.parallel.DistributedDataParallel, "
- "sampler must be None or torch.utils.data.DistributedSampler.")
- # 不能保存模型
- if save_path:
- raise RuntimeError("Saving model in Distributed situation is not allowed right now.")
- else:
- # sampler check
- if sampler is not None and not isinstance(sampler, (Sampler, torch.utils.data.Sampler)):
- raise ValueError(f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}")
- if sampler is None:
- sampler = RandomSampler()
- elif hasattr(sampler, 'set_batch_size'):
- sampler.set_batch_size(batch_size)
- if isinstance(sampler, ConstTokenNumSampler): # 直接使用固定token数量的Sampler
- assert isinstance(train_data,
- DataSet), f"When sampler is `ConstTokenNumSampler`, the train_data must" \
- f" be `DataSet`."
- sampler(train_data)
- train_data = DataSetIter(train_data,
- batch_size=1, sampler=None, as_numpy=False, num_workers=num_workers,
- pin_memory=self.pin_memory, drop_last=drop_last, timeout=0, worker_init_fn=None,
- batch_sampler=sampler)
-
- if isinstance(train_data, DataSet):
- self.data_iterator = DataSetIter(dataset=train_data, batch_size=batch_size, sampler=sampler,
- num_workers=num_workers, drop_last=drop_last,
- pin_memory=self.pin_memory)
- elif isinstance(train_data, BatchIter):
- self.data_iterator = train_data
- train_data = train_data.dataset
- check_code_level = -1 # 强制跳过校验
- else:
- raise TypeError("train_data type {} not support".format(type(train_data)))
-
- model.train()
- self.model = _move_model_to_device(model, device=device)
- if _model_contains_inner_module(self.model):
- self._forward_func = self.model.module.forward
- else:
- self._forward_func = self.model.forward
-
- self.fp16 = fp16
- self.verbose = kwargs.get('verbose', 0)
-
- # check fp16相关的设置
- self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
- self.grad_scaler = _grad_scaler()
- if self.fp16:
- _can_use_fp16(device=device, model=model, func=self._forward_func)
- grad_scaler = kwargs.get('grad_scaler', None)
- if grad_scaler is not None:
- self.grad_scaler = grad_scaler
- else:
- self.grad_scaler = _grad_scaler()
- self.test_use_fp16 = kwargs.get('test_use_fp16', fp16)
- self.set_grad_to_none = kwargs.get('set_grad_to_none', True)
-
- if check_code_level > -1:
- # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的field名与模型的输入
- # 名是否匹配
- dev_dataset = dev_data
- if isinstance(dev_data, BatchIter):
- dev_dataset = None
- warnings.warn("dev_data is of BatchIter type, ignore validation checking.")
- check_batch_size = min(batch_size, DEFAULT_CHECK_BATCH_SIZE)
- if isinstance(self.model, nn.DataParallel):
- _num_devices = len(self.model.device_ids)
- if batch_size//_num_devices>1: # 如果多卡是每个卡可以分多个数据的,则用每个卡给两个sample
- check_batch_size = max(len(self.model.device_ids)*2, check_batch_size)
- else:
- check_batch_size = max(len(self.model.device_ids), check_batch_size)
- _check_code(dataset=train_data, model=self.model, losser=losser, forward_func=self._forward_func, metrics=metrics,
- dev_data=dev_dataset, metric_key=self.metric_key, check_level=check_code_level,
- batch_size=check_batch_size)
-
- self.train_data = train_data
- self.dev_data = dev_data # If None, No validation.
- self.losser = losser
- self.metrics = metrics
- self.n_epochs = int(n_epochs)
- self.batch_size = int(batch_size)
- self.save_path = save_path
- self.print_every = int(print_every)
- self.validate_every = int(validate_every) if validate_every != 0 else -1
- self.best_metric_indicator = None
- self.best_dev_epoch = None
- self.best_dev_step = None
- self.best_dev_perf = None
- self.n_steps = len(self.data_iterator) * self.n_epochs
-
- if isinstance(optimizer, torch.optim.Optimizer):
- self.optimizer = optimizer
- elif isinstance(optimizer, Optimizer):
- self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
- elif optimizer is None:
- self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
- else:
- if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
- raise TypeError("optimizer must have a callable step() function.")
- else:
- self.optimizer = optimizer
-
- self.logger = logger
-
- self.use_tqdm = use_tqdm
- self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
- self.pbar = None
- self.print_every = abs(self.print_every)
- self.kwargs = kwargs
- if self.dev_data is not None:
- self.tester = Tester(model=self.model,
- data=self.dev_data,
- metrics=self.metrics,
- batch_size=kwargs.get("dev_batch_size", self.batch_size),
- device=None, # 由上面的部分处理device
- verbose=0,
- use_tqdm=self.test_use_tqdm,
- sampler=kwargs.get('test_sampler', None),
- fp16=self.test_use_fp16,
- num_workers=num_workers,
- pin_memory=self.pin_memory)
-
- self.start_time = None # start timestamp
-
- if isinstance(callbacks, Callback):
- callbacks = [callbacks]
-
- self.callback_manager = CallbackManager(env={"trainer": self},
- callbacks=callbacks)
-
- def train(self, load_best_model=True, on_exception='auto', **kwargs):
- r"""
- 使用该函数使Trainer开始训练。
-
- :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现
- 最好的模型参数。
- :param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。
- 支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出;
- 'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception.
- :param kwargs:
- int verbose: 为1时在发生异常时会打印异常发生时batch中的数据在dataset中的index
- :return dict: 返回一个字典类型的数据,
- 内含以下内容::
-
- seconds: float, 表示训练时长
- 以下三个内容只有在提供了dev_data的情况下会有。
- best_eval: Dict of Dict, 表示evaluation的结果。第一层的key为Metric的名称,
- 第二层的key为具体的Metric
- best_epoch: int,在第几个epoch取得的最佳值
- best_step: int, 在第几个step(batch)更新取得的最佳值
-
- """
- results = {}
- verbose = kwargs.get('verbose', 0)
- if self.n_epochs <= 0:
- self.logger.info(f"training epoch is {self.n_epochs}, nothing was done.")
- results['seconds'] = 0.
- return results
- try:
- self._model_device = _get_model_device(self.model)
- self._mode(self.model, is_test=False)
- self._load_best_model = load_best_model
- # 加上millsecond,防止两个太接近的保存
- self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f'))
- start_time = time.time()
- self.logger.info("training epochs started " + self.start_time)
- self.step = 0
- self.epoch = 1
- try:
- self.callback_manager.on_train_begin()
- self._train()
- self.callback_manager.on_train_end()
-
- except BaseException as e:
- self.callback_manager.on_exception(e)
- if verbose>0:
- self.logger.info(f"The data indices for current batch are: {self.data_iterator.cur_batch_indices}.")
- if on_exception == 'auto':
- if not isinstance(e, (CallbackException, KeyboardInterrupt)):
- raise e
- elif on_exception == 'raise':
- raise e
-
- if self.dev_data is not None and self.best_dev_perf is not None and load_best_model:
- model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
- load_succeed = self._load_model(self.model, model_name)
- if load_succeed:
- self.logger.info("Reloaded the best model.")
- else:
- self.logger.info("Fail to reload best model.")
-
- if self.dev_data is None and self.save_path is not None:
- model_name = "_".join([self.model.__class__.__name__, self.start_time])
- self._save_model(self.model, model_name)
-
- finally:
- if self.dev_data is not None and self.best_dev_perf is not None:
- self.logger.info(
- "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step))
- self.logger.info(self.tester._format_eval_results(self.best_dev_perf))
- results['best_eval'] = self.best_dev_perf
- results['best_epoch'] = self.best_dev_epoch
- results['best_step'] = self.best_dev_step
-
- results['seconds'] = round(time.time() - start_time, 2)
-
- return results
-
- def _train(self):
- if not self.use_tqdm:
- from .utils import _pseudo_tqdm as inner_tqdm
- else:
- inner_tqdm = tqdm
- start = time.time()
- with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True,
- initial=self.step) as pbar:
- self.pbar = pbar
- avg_loss = 0
- self.batch_per_epoch = self.data_iterator.num_batches
- for epoch in range(self.epoch, self.n_epochs + 1):
- self.epoch = epoch
- pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
- # early stopping
- self.callback_manager.on_epoch_begin()
- for batch_x, batch_y in self.data_iterator:
- self.step += 1
- _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
- indices = self.data_iterator.get_batch_indices()
- # negative sampling; replace unknown; re-weight batch_y
- self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
- prediction = self._data_forward(self.model, batch_x)
-
- # edit prediction
- self.callback_manager.on_loss_begin(batch_y, prediction)
- with self.auto_cast():
- loss = self._compute_loss(prediction, batch_y).mean()
- loss = loss / self.update_every
- avg_loss += loss.item()
-
- # Is loss NaN or inf? requires_grad = False
- self.callback_manager.on_backward_begin(loss)
- self._grad_backward(loss)
- self.callback_manager.on_backward_end()
-
- self._update()
- self.callback_manager.on_step_end()
-
- if self.step % self.print_every == 0:
- avg_loss = float(avg_loss) / self.print_every
- if self.use_tqdm:
- print_output = "loss:{:<6.5f}".format(avg_loss)
- pbar.update(self.print_every)
- else:
- end = time.time()
- diff = timedelta(seconds=round(end - start))
- print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
- epoch, self.step, avg_loss, diff)
- pbar.set_postfix_str(print_output)
- avg_loss = 0
- self.callback_manager.on_batch_end()
-
- if (self.validate_every > 0 and self.step % self.validate_every == 0) \
- and self.dev_data is not None:
- eval_res = self._do_validation(epoch=epoch, step=self.step)
- eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,
- self.n_steps)
- # pbar.write(eval_str + '\n')
- self.logger.info(eval_str)
- self.logger.info(self.tester._format_eval_results(eval_res)+'\n')
- # ================= mini-batch end ==================== #
- if self.validate_every<0 and self.dev_data is not None: # 在epoch结束之后的evaluate
- eval_res = self._do_validation(epoch=epoch, step=self.step)
- eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,
- self.n_steps)
- # pbar.write(eval_str + '\n')
- self.logger.info(eval_str)
- self.logger.info(self.tester._format_eval_results(eval_res) + '\n')
- # lr decay; early stopping
- self.callback_manager.on_epoch_end()
- # =============== epochs end =================== #
- if self.dev_data is not None and (self.validate_every>0 and self.n_steps%self.validate_every!=0):
- eval_res = self._do_validation(epoch=epoch, step=self.step)
- eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,
- self.n_steps)
- # pbar.write(eval_str + '\n')
- self.logger.info(eval_str)
- self.logger.info(self.tester._format_eval_results(eval_res) + '\n')
- pbar.close()
- self.pbar = None
- # ============ tqdm end ============== #
-
- def _do_validation(self, epoch, step):
- self.callback_manager.on_valid_begin()
- res = self.tester.test()
-
- is_better_eval = False
- if self._better_eval_result(res):
- if self.save_path is not None:
- self._save_model(self.model,
- "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
- elif self._load_best_model:
- self._best_model_states = {name: param.cpu().clone() for name, param in self.model.state_dict().items()}
- self.best_dev_perf = res
- self.best_dev_epoch = epoch
- self.best_dev_step = step
- is_better_eval = True
- # get validation results; adjust optimizer
- self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
- return res
-
- def _mode(self, model, is_test=False):
- r"""Train mode or Test mode. This is for PyTorch currently.
-
- :param model: a PyTorch model
- :param bool is_test: whether in test mode or not.
-
- """
- if is_test:
- model.eval()
- else:
- model.train()
-
- def _update(self):
- r"""Perform weight update on a model.
-
- """
- if self.step % self.update_every == 0:
- self.grad_scaler.step(self.optimizer)
- self.grad_scaler.update()
-
- def _data_forward(self, network, x):
- x = _build_args(self._forward_func, **x)
- with self.auto_cast():
- y = network(**x)
- if not isinstance(y, dict):
- raise TypeError(
- f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
- return y
-
- def _grad_backward(self, loss):
- r"""Compute gradient with link rules.
-
- :param loss: a scalar where back-prop starts
-
- For PyTorch, just do "loss.backward()"
- """
- if (self.step-1) % self.update_every == 0:
- self._clear_grad(self.optimizer, self.set_grad_to_none)
- self.grad_scaler.scale(loss).backward()
-
- def _clear_grad(self, optimizer, set_to_none=True):
- param_groups = optimizer.param_groups
- for group in param_groups:
- for p in group['params']:
- if p.grad is not None:
- if set_to_none:
- p.grad = None
- else:
- if p.grad.grad_fn is not None:
- p.grad.detach_()
- else:
- p.grad.requires_grad_(False)
- p.grad.zero_()
-
- def _compute_loss(self, predict, truth):
- r"""Compute loss given prediction and ground truth.
-
- :param predict: prediction dict, produced by model.forward
- :param truth: ground truth dict, produced by batch_y
- :return: a scalar
- """
- return self.losser(predict, truth)
-
- def _save_model(self, model, model_name, only_param=False):
- r""" 存储不含有显卡信息的state_dict或model
- :param model:
- :param model_name:
- :param only_param:
- :return:
- """
- if self.save_path is not None:
- model_path = os.path.join(self.save_path, model_name)
- if not os.path.exists(self.save_path):
- os.makedirs(self.save_path, exist_ok=True)
- if _model_contains_inner_module(model):
- model = model.module
- if only_param:
- state_dict = model.state_dict()
- for key in state_dict:
- state_dict[key] = state_dict[key].cpu()
- torch.save(state_dict, model_path)
- else:
- model.cpu()
- torch.save(model, model_path)
- model.to(self._model_device)
-
- def _load_model(self, model, model_name, only_param=False):
- # 返回bool值指示是否成功reload模型
- if self.save_path is not None:
- model_path = os.path.join(self.save_path, model_name)
- if only_param:
- states = torch.load(model_path)
- else:
- states = torch.load(model_path).state_dict()
- if _model_contains_inner_module(model):
- model.module.load_state_dict(states)
- else:
- model.load_state_dict(states)
- elif hasattr(self, "_best_model_states"):
- model.load_state_dict(self._best_model_states)
- else:
- return False
- return True
-
- def _better_eval_result(self, metrics):
- r"""Check if the current epoch yields better validation results.
-
- :return bool value: True means current results on dev set is the best.
- """
- indicator, indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
- if self.metric_key is None:
- self.metric_key = indicator
- is_better = True
- if self.best_metric_indicator is None:
- # first-time validation
- self.best_metric_indicator = indicator_val
- else:
- if self.increase_better is True:
- if indicator_val > self.best_metric_indicator:
- self.best_metric_indicator = indicator_val
- else:
- is_better = False
- else:
- if indicator_val < self.best_metric_indicator:
- self.best_metric_indicator = indicator_val
- else:
- is_better = False
- return is_better
-
- @property
- def is_master(self):
- r"""是否是主进程"""
- return True
-
-DEFAULT_CHECK_BATCH_SIZE = 2
-DEFAULT_CHECK_NUM_BATCH = 2
-
-
-def _get_value_info(_dict):
- # given a dict value, return information about this dict's value. Return list of str
- strs = []
- for key, value in _dict.items():
- _str = ''
- if isinstance(value, torch.Tensor):
- _str += "\t{}: (1)type:torch.Tensor (2)dtype:{}, (3)shape:{} ".format(key,
- value.dtype, value.size())
- elif isinstance(value, np.ndarray):
- _str += "\t{}: (1)type:numpy.ndarray (2)dtype:{}, (3)shape:{} ".format(key,
- value.dtype, value.shape)
- else:
- _str += "\t{}: type:{}".format(key, type(value))
- strs.append(_str)
- return strs
-
-
-def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE,
- dev_data=None, metric_key=None, check_level=0):
- # check get_loss 方法
- model_device = _get_model_device(model=model)
- _iter = DataSetIter(dataset, batch_size=batch_size, sampler=None)
-
- for batch_count, (batch_x, batch_y) in enumerate(_iter):
- _move_dict_value_to_device(batch_x, batch_y, device=model_device)
- # forward check
- if batch_count == 0:
- info_str = ""
- input_fields = _get_value_info(batch_x)
- target_fields = _get_value_info(batch_y)
- if len(input_fields) > 0:
- info_str += "input fields after batch(if batch size is {}):\n".format(batch_size)
- info_str += "\n".join(input_fields)
- info_str += '\n'
- else:
- raise RuntimeError("There is no input field.")
- if len(target_fields) > 0:
- info_str += "target fields after batch(if batch size is {}):\n".format(batch_size)
- info_str += "\n".join(target_fields)
- info_str += '\n'
- else:
- info_str += 'There is no target field.'
- logger.info(info_str)
- _check_forward_error(forward_func=forward_func, dataset=dataset,
- batch_x=batch_x, check_level=check_level)
- refined_batch_x = _build_args(forward_func, **batch_x)
- pred_dict = model(**refined_batch_x)
- func_signature = _get_func_signature(forward_func)
- if not isinstance(pred_dict, dict):
- raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")
-
- # loss check
- try:
- loss = losser(pred_dict, batch_y)
- # check loss output
- if batch_count == 0:
- if not isinstance(loss, torch.Tensor):
- raise TypeError(
- f"The return value of {_get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
- f"but got `{type(loss)}`.")
- if len(loss.size()) != 0:
- raise ValueError(
- f"The size of return value of {_get_func_signature(losser.get_loss)} is {loss.size()}, "
- f"should be torch.size([])")
- loss.backward()
- except _CheckError as e:
- # TODO: another error raised if _CheckError caught
- pre_func_signature = _get_func_signature(forward_func)
- _check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
- check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
- dataset=dataset, check_level=check_level)
- model.zero_grad()
- if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
- break
-
- if dev_data is not None:
- tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
- batch_size=batch_size, verbose=-1, use_tqdm=False)
- evaluate_results = tester.test()
- _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)
-
-
-def _check_eval_results(metrics, metric_key, metric_list):
- # metrics: tester返回的结果
- # metric_key: 一个用来做筛选的指标,来自Trainer的初始化
- # metric_list: 多个用来做评价的指标,来自Trainer的初始化
- if isinstance(metrics, tuple):
- loss, metrics = metrics
-
- if isinstance(metrics, dict):
- metric_dict = list(metrics.values())[0] # 取第一个metric
-
- if metric_key is None:
- indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
- else:
- # metric_key is set
- if metric_key not in metric_dict:
- raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
- indicator_val = metric_dict[metric_key]
- indicator = metric_key
- else:
- raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
- return indicator, indicator_val
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
deleted file mode 100644
index a7a286d0..00000000
--- a/fastNLP/core/utils.py
+++ /dev/null
@@ -1,1120 +0,0 @@
-r"""
-utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。
-"""
-
-__all__ = [
- "cache_results",
- "seq_len_to_mask",
- "get_seq_len"
-]
-
-import inspect
-import os
-import warnings
-from collections import Counter, namedtuple
-from typing import List
-
-import _pickle
-import numpy as np
-import torch.nn as nn
-from prettytable import PrettyTable
-
-from ._logger import logger
-from ._parallel_utils import _model_contains_inner_module
-# from .vocabulary import Vocabulary
-import torch
-import contextlib
-from pkg_resources import parse_version
-
-
-_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
- 'varargs'])
-
-
-class ConfusionMatrix:
- r"""a dict can provide Confusion Matrix"""
- def __init__(self, show_result=None,vocab=None, print_ratio=False):
- r"""
- :param show_result: list type, 数据类型需要和target保持一致
- :param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。
- :param print_ratio: 限制print的输出,False只输出数量Confusion Matrix, True还会输出百分比Confusion Matrix, 分别为行/列
- """
- if vocab and not hasattr(vocab, "to_word"):
- raise TypeError(
- f"`vocab` in {_get_func_signature(self.__init__)} must be Fastnlp.core.Vocabulary,"
- f"got {type(vocab)}.")
- self.confusiondict = {} # key: pred index, value:target word ocunt
- self.predcount = {} # key:pred index, value:count
- self.targetcount = {} # key:target index, value:count
- self.show_result = show_result
- self.vocab = vocab
- self.print_ratio = print_ratio
-
- def add_pred_target(self, pred, target): # 一组结果
- r"""
- 通过这个函数向ConfusionMatrix加入一组预测结果
- :param list pred: 预测的标签列表
- :param list target: 真实值的标签列表
- :return ConfusionMatrix
- confusion=ConfusionMatrix()
- pred = [2,1,3]
- target = [2,2,1]
- confusion.add_pred_target(pred, target)
- print(confusion)
-
- target 1 2 3 all
- pred
- 1 0 1 0 1
- 2 0 1 0 1
- 3 1 0 0 1
- all 1 2 0 3
- """
- for p, t in zip(pred, target): #
- self.predcount[p] = self.predcount.get(p, 0) + 1
- self.targetcount[t] = self.targetcount.get(t, 0) + 1
- if p in self.confusiondict:
- self.confusiondict[p][t] = self.confusiondict[p].get(t, 0) + 1
- else:
- self.confusiondict[p] = {}
- self.confusiondict[p][t] = 1
- return self.confusiondict
-
- def clear(self):
- r"""
- 清空ConfusionMatrix,等待再次新加入
- :return:
- """
- self.confusiondict = {}
- self.targetcount = {}
- self.predcount = {}
-
- def get_result(self):
- r"""
- :return list output: ConfusionMatrix content,具体值与汇总统计
- """
- row2idx = {}
- idx2row = {}
- # 已知的所有键/label
- totallabel = sorted(
- list(
- set(self.targetcount.keys()).union(set(
- self.predcount.keys()))))
- lenth = len(totallabel)
-
- for label, idx in zip(totallabel, range(lenth)):
- idx2row[
- label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
- row2idx[
- idx] = label # 建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,...
- output = []
- for i in row2idx.keys(): # 第i行
- p = row2idx[i]
- l = [0 for _ in range(lenth)]
- if self.confusiondict.get(p, None):
- for t, c in self.confusiondict[p].items():
- l[idx2row[t]] = c # 完成一行
- l = [n for n in l] + [sum(l)]
- output.append(l)
- tail = [self.targetcount.get(row2idx[k], 0) for k in row2idx.keys()]
- tail += [sum(tail)]
- output.append(tail)
- return output
-
- def get_percent(self, dim=0):
- r"""
- :param dim int: 0/1, 0 for row,1 for column
- :return list output: ConfusionMatrix content,具体值与汇总统计
- """
- result = self.get_result()
- if dim == 0:
- tmp = np.array(result)
- tmp = tmp / (tmp[:, -1].reshape([len(result), -1]))
- tmp[np.isnan(tmp)] = 0
- tmp = tmp * 100
- elif dim == 1:
- tmp = np.array(result).T
- tmp = tmp / (tmp[:, -1].reshape([len(result), -1]) + 1e-12)
- tmp = tmp.T * 100
- tmp = np.around(tmp, decimals=2)
- return tmp.tolist()
-
- def get_aligned_table(self, data, flag="result"):
- r"""
- :param data: highly recommend use get_percent/ get_result return as dataset here, or make sure data is a n*n list type data
- :param flag: only difference between result and other words is whether "%" is in output string
- :return: an aligned_table ready to print out
- """
- row2idx = {}
- idx2row = {}
- # 已知的所有键/label
- totallabel = sorted(
- list(
- set(self.targetcount.keys()).union(set(
- self.predcount.keys()))))
- lenth = len(totallabel)
- # namedict key :label idx value: str label name/label idx
- namedict = dict([
- (k, str(k if self.vocab == None else self.vocab.to_word(k)))
- for k in totallabel
- ])
- for label, lineidx in zip(totallabel, range(lenth)):
- idx2row[
- label] = lineidx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
- row2idx[
- lineidx] = label # 建立一个临时字典,key: 行列index 0,1,2...->1,3,5,...,value:vocab的index,
- # 这里打印东西
- out = str()
- output = []
- # 表头
- head = (["target"] +
- [str(namedict[row2idx[k]]) for k in row2idx.keys()] + ["all"])
- col_lenths = [len(h) for h in head]
- output.append(head)
- output.append(["pred"])
- # 内容
- for i in row2idx.keys(): # 第i行
- p = row2idx[i]
- h = namedict[p]
- l = [h] + [[str(n) + "%", str(n)][flag == "result"]
- for n in data[i]]
- col_lenths = [
- max(col_lenths[idx], [len(i) for i in l][idx])
- for idx in range(len(col_lenths))
- ]
- output.append(l)
-
- tail = ["all"] + [[str(n) + "%", str(n)][flag == "result"]
- for n in data[-1]]
- col_lenths = [
- max(col_lenths[idx], [len(i) for i in tail][idx])
- for idx in range(len(col_lenths))
- ]
- output.append(tail)
-
- if self.show_result:
- missing_item=[]
- missing_item = [i for i in self.show_result if i not in idx2row]
- self.show_result = [i for i in self.show_result if i in idx2row]
- if missing_item:
- print(f"Noticing label(s) which is/are not in target list appeared, final output string will not contain{str(missing_item)}")
- if self.show_result:
- show_col = [0] + [i + 1 for i in [idx2row[i] for i in self.show_result]]
- show_row = [0]+[i+2 for i in [idx2row[i] for i in self.show_result]]
- output = [[row[col] for col in show_col] for row in [output[row] for row in show_row]]
- output.insert(1,["pred"])
- for line in output:
- for colidx in range(len(line)):
- out += "%*s" % (col_lenths[colidx], line[colidx]) + "\t"
- out += "\n"
- return "\n" + out
-
- def __repr__(self):
- r"""
- :return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。
- """
- result = self.get_result()
- o0 = self.get_aligned_table(result, flag="result")
-
- out = str()
- if self.print_ratio:
- p1 = self.get_percent()
- o1 = "\nNotice the row direction\n" + self.get_aligned_table(
- p1, flag="percent")
- p2 = self.get_percent(dim=1)
- o2 = "\nNotice the column direction\n" + self.get_aligned_table(
- p2, flag="percent")
- out = out + o0 + o1 + o2
- else:
- out = o0
- return out
-
-
-
-class Option(dict):
- r"""a dict can treat keys as attributes"""
-
- def __getattr__(self, item):
- try:
- return self.__getitem__(item)
- except KeyError:
- raise AttributeError(item)
-
- def __setattr__(self, key, value):
- if key.startswith('__') and key.endswith('__'):
- raise AttributeError(key)
- self.__setitem__(key, value)
-
- def __delattr__(self, item):
- try:
- self.pop(item)
- except KeyError:
- raise AttributeError(item)
-
- def __getstate__(self):
- return self
-
- def __setstate__(self, state):
- self.update(state)
-
-
-def _prepare_cache_filepath(filepath):
- r"""
- 检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径
- :param filepath: str.
- :return: None, if not, this function will raise error
- """
- _cache_filepath = os.path.abspath(filepath)
- if os.path.isdir(_cache_filepath):
- raise RuntimeError("The cache_file_path must be a file, not a directory.")
- cache_dir = os.path.dirname(_cache_filepath)
- if not os.path.exists(cache_dir):
- os.makedirs(cache_dir, exist_ok=True)
-
-
-def cache_results(_cache_fp, _refresh=False, _verbose=1):
- r"""
- cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用::
-
- import time
- import numpy as np
- from fastNLP import cache_results
-
- @cache_results('cache.pkl')
- def process_data():
- # 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时
- time.sleep(1)
- return np.random.randint(10, size=(5,))
-
- start_time = time.time()
- print("res =",process_data())
- print(time.time() - start_time)
-
- start_time = time.time()
- print("res =",process_data())
- print(time.time() - start_time)
-
- # 输出内容如下,可以看到两次结果相同,且第二次几乎没有花费时间
- # Save cache to cache.pkl.
- # res = [5 4 9 1 8]
- # 1.0042750835418701
- # Read cache from cache.pkl.
- # res = [5 4 9 1 8]
- # 0.0040721893310546875
-
- 可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理::
-
- # 还是以上面的例子为例,如果需要重新生成另一个cache,比如另一个数据集的内容,通过如下的方式调用即可
- process_data(_cache_fp='cache2.pkl') # 完全不影响之前的‘cache.pkl'
-
- 上面的_cache_fp是cache_results会识别的参数,它将从'cache2.pkl'这里缓存/读取数据,即这里的'cache2.pkl'覆盖默认的
- 'cache.pkl'。如果在你的函数前面加上了@cache_results()则你的函数会增加三个参数[_cache_fp, _refresh, _verbose]。
- 上面的例子即为使用_cache_fp的情况,这三个参数不会传入到你的函数中,当然你写的函数参数名也不可能包含这三个名称::
-
- process_data(_cache_fp='cache2.pkl', _refresh=True) # 这里强制重新生成一份对预处理的cache。
- # _verbose是用于控制输出信息的,如果为0,则不输出任何内容;如果为1,则会提醒当前步骤是读取的cache还是生成了新的cache
-
- :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在
- 函数调用的时候传入_cache_fp这个参数。
- :param bool _refresh: 是否重新生成cache。
- :param int _verbose: 是否打印cache的信息。
- :return:
- """
-
- def wrapper_(func):
- signature = inspect.signature(func)
- for key, _ in signature.parameters.items():
- if key in ('_cache_fp', '_refresh', '_verbose'):
- raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
-
- def wrapper(*args, **kwargs):
- if '_cache_fp' in kwargs:
- cache_filepath = kwargs.pop('_cache_fp')
- assert isinstance(cache_filepath, str), "_cache_fp can only be str."
- else:
- cache_filepath = _cache_fp
- if '_refresh' in kwargs:
- refresh = kwargs.pop('_refresh')
- assert isinstance(refresh, bool), "_refresh can only be bool."
- else:
- refresh = _refresh
- if '_verbose' in kwargs:
- verbose = kwargs.pop('_verbose')
- assert isinstance(verbose, int), "_verbose can only be integer."
- else:
- verbose = _verbose
- refresh_flag = True
-
- if cache_filepath is not None and refresh is False:
- # load data
- if os.path.exists(cache_filepath):
- with open(cache_filepath, 'rb') as f:
- results = _pickle.load(f)
- if verbose == 1:
- logger.info("Read cache from {}.".format(cache_filepath))
- refresh_flag = False
-
- if refresh_flag:
- results = func(*args, **kwargs)
- if cache_filepath is not None:
- if results is None:
- raise RuntimeError("The return value is None. Delete the decorator.")
- _prepare_cache_filepath(cache_filepath)
- with open(cache_filepath, 'wb') as f:
- _pickle.dump(results, f)
- logger.info("Save cache to {}.".format(cache_filepath))
-
- return results
-
- return wrapper
-
- return wrapper_
-
-
-def _save_model(model, model_name, save_dir, only_param=False):
- r""" 存储不含有显卡信息的state_dict或model
- :param model:
- :param model_name:
- :param save_dir: 保存的directory
- :param only_param:
- :return:
- """
- model_path = os.path.join(save_dir, model_name)
- if not os.path.isdir(save_dir):
- os.makedirs(save_dir, exist_ok=True)
- if _model_contains_inner_module(model):
- model = model.module
- if only_param:
- state_dict = model.state_dict()
- for key in state_dict:
- state_dict[key] = state_dict[key].cpu()
- torch.save(state_dict, model_path)
- else:
- _model_device = _get_model_device(model)
- model.cpu()
- torch.save(model, model_path)
- model.to(_model_device)
-
-
-def _move_model_to_device(model, device):
- r"""
- 将model移动到device
-
- :param model: torch.nn.DataParallel or torch.nn.Module. 当为torch.nn.DataParallel, 则只是调用一次cuda。device必须为
- None。
- :param str,int,torch.device,list(int),list(torch.device) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
- 的计算位置进行管理。支持以下的输入:
-
- 1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中,
- 可见的第二个GPU中;
-
- 2. torch.device:将模型装载到torch.device上。
-
- 3. int: 将使用device_id为该值的gpu进行训练
-
- 4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。
-
- 5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。
-
- :return: torch.nn.DataParallel or torch.nn.Module
- """
- # if isinstance(model, torch.nn.parallel.DistributedDataParallel):
- # raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.")
-
- if device is None:
- if isinstance(model, torch.nn.DataParallel):
- model.cuda(model.device_ids[0])
- return model
- else:
- if not torch.cuda.is_available() and ((isinstance(device, str) and device!='cpu') or
- (isinstance(device, torch.device) and device.type != 'cpu')):
- raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")
-
- if isinstance(model, torch.nn.DataParallel):
- raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")
-
- if isinstance(device, int):
- assert device > -1, "device can only be non-negative integer"
- assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format(
- torch.cuda.device_count(),
- device)
- device = torch.device('cuda:{}'.format(device))
- elif isinstance(device, str):
- device = torch.device(device)
- if device.type == 'cuda' and device.index is not None:
- assert device.index < torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
- torch.cuda.device_count(),
- device)
- elif isinstance(device, torch.device):
- if device.type == 'cuda' and device.index is not None:
- assert device.index < torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
- torch.cuda.device_count(),
- device)
- elif isinstance(device, list):
- types = set([type(d) for d in device])
- assert len(types) == 1, "Mixed type in device, only `int` allowed."
- assert list(types)[0] == int, "Only int supported for multiple devices."
- assert len(set(device)) == len(device), "Duplicated device id found in device."
- for d in device:
- assert d > -1, "Only non-negative device id allowed."
- if len(device) > 1:
- output_device = device[0]
- model = nn.DataParallel(model, device_ids=device, output_device=output_device)
- device = torch.device(device[0])
- else:
- raise TypeError("Unsupported device type.")
- model = model.to(device)
- return model
-
-
-def _get_model_device(model):
- r"""
- 传入一个nn.Module的模型,获取它所在的device
-
- :param model: nn.Module
- :return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。
- """
- # TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡
- assert isinstance(model, nn.Module)
-
- parameters = list(model.parameters())
- if len(parameters) == 0:
- return None
- else:
- return parameters[0].device
-
-
-def _build_args(func, **kwargs):
- r"""
- 根据func的初始化参数,从kwargs中选择func需要的参数
-
- :param func: callable
- :param kwargs: 参数
- :return:dict. func中用到的参数
- """
- spect = inspect.getfullargspec(func)
- if spect.varkw is not None:
- return kwargs
- needed_args = set(spect.args)
- defaults = []
- if spect.defaults is not None:
- defaults = [arg for arg in spect.defaults]
- start_idx = len(spect.args) - len(defaults)
- output = {name: default for name, default in zip(spect.args[start_idx:], defaults)}
- output.update({name: val for name, val in kwargs.items() if name in needed_args})
- return output
-
-
-def _map_args(maps: dict, **kwargs):
- # maps: key=old name, value= new name
- output = {}
- for name, val in kwargs.items():
- if name in maps:
- assert isinstance(maps[name], str)
- output.update({maps[name]: val})
- else:
- output.update({name: val})
- for keys in maps.keys():
- if keys not in output.keys():
- pass
- return output
-
-
-def _get_arg_list(func):
- assert callable(func)
- spect = inspect.getfullargspec(func)
- if spect.defaults is not None:
- args = spect.args[: -len(spect.defaults)]
- defaults = spect.args[-len(spect.defaults):]
- defaults_val = spect.defaults
- else:
- args = spect.args
- defaults = None
- defaults_val = None
- varargs = spect.varargs
- kwargs = spect.varkw
- return args, defaults, defaults_val, varargs, kwargs
-
-
-# check args
-def _check_arg_dict_list(func, args):
- if isinstance(args, dict):
- arg_dict_list = [args]
- else:
- arg_dict_list = args
- assert callable(func) and isinstance(arg_dict_list, (list, tuple))
- assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
- spect = inspect.getfullargspec(func)
- all_args = set([arg for arg in spect.args if arg != 'self'])
- defaults = []
- if spect.defaults is not None:
- defaults = [arg for arg in spect.defaults]
- start_idx = len(spect.args) - len(defaults)
- default_args = set(spect.args[start_idx:])
- require_args = all_args - default_args
- input_arg_count = Counter()
- for arg_dict in arg_dict_list:
- input_arg_count.update(arg_dict.keys())
- duplicated = [name for name, val in input_arg_count.items() if val > 1]
- input_args = set(input_arg_count.keys())
- missing = list(require_args - input_args)
- unused = list(input_args - all_args)
- varargs = [] if not spect.varargs else [spect.varargs]
- return _CheckRes(missing=missing,
- unused=unused,
- duplicated=duplicated,
- required=list(require_args),
- all_needed=list(all_args),
- varargs=varargs)
-
-
-def _get_func_signature(func):
- r"""
-
- Given a function or method, return its signature.
- For example:
-
- 1 function::
-
- def func(a, b='a', *args):
- xxxx
- get_func_signature(func) # 'func(a, b='a', *args)'
-
- 2 method::
-
- class Demo:
- def __init__(self):
- xxx
- def forward(self, a, b='a', **args)
- demo = Demo()
- get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)'
-
- :param func: a function or a method
- :return: str or None
- """
- if inspect.ismethod(func):
- class_name = func.__self__.__class__.__name__
- signature = inspect.signature(func)
- signature_str = str(signature)
- if len(signature_str) > 2:
- _self = '(self, '
- else:
- _self = '(self'
- signature_str = class_name + '.' + func.__name__ + _self + signature_str[1:]
- return signature_str
- elif inspect.isfunction(func):
- signature = inspect.signature(func)
- signature_str = str(signature)
- signature_str = func.__name__ + signature_str
- return signature_str
-
-
-def _is_function_or_method(func):
- r"""
-
- :param func:
- :return:
- """
- if not inspect.ismethod(func) and not inspect.isfunction(func):
- return False
- return True
-
-
-def _check_function_or_method(func):
- if not _is_function_or_method(func):
- raise TypeError(f"{type(func)} is not a method or function.")
-
-
-def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
- r"""
-
- move data to model's device, element in *args should be dict. This is a inplace change.
- :param device: torch.device
- :param non_blocking: bool, 是否异步将数据转移到cpu, 需要tensor使用pin_memory()
- :param args:
- :return:
- """
- if not torch.cuda.is_available() or device is None:
- return
-
- if not isinstance(device, torch.device):
- raise TypeError(f"device must be `torch.device`, got `{type(device)}`")
-
- for arg in args:
- if isinstance(arg, dict):
- for key, value in arg.items():
- if isinstance(value, torch.Tensor):
- arg[key] = value.to(device, non_blocking=non_blocking)
- else:
- raise TypeError("Only support `dict` type right now.")
-
-
-class _CheckError(Exception):
- r"""
-
- _CheckError. Used in losses.LossBase, metrics.MetricBase.
- """
-
- def __init__(self, check_res: _CheckRes, func_signature: str):
- errs = [f'Problems occurred when calling `{func_signature}`']
-
- if check_res.varargs:
- errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
- if check_res.missing:
- errs.append(f"\tmissing param: {check_res.missing}")
- if check_res.duplicated:
- errs.append(f"\tduplicated param: {check_res.duplicated}")
- if check_res.unused:
- errs.append(f"\tunused param: {check_res.unused}")
-
- Exception.__init__(self, '\n'.join(errs))
-
- self.check_res = check_res
- self.func_signature = func_signature
-
-
-IGNORE_CHECK_LEVEL = 0
-WARNING_CHECK_LEVEL = 1
-STRICT_CHECK_LEVEL = 2
-
-
-def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: _CheckRes,
- pred_dict: dict, target_dict: dict, dataset, check_level=0):
- errs = []
- unuseds = []
- _unused_field = []
- _unused_param = []
- suggestions = []
- # if check_res.varargs:
- # errs.append(f"\tvarargs: *{check_res.varargs}")
- # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
-
- if check_res.unused:
- for _unused in check_res.unused:
- if _unused in target_dict:
- _unused_field.append(_unused)
- else:
- _unused_param.append(_unused)
- if _unused_field:
- unuseds.append(f"\tunused field: {_unused_field}")
- if _unused_param:
- unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward
-
- module_name = func_signature.split('.')[0]
- if check_res.missing:
- errs.append(f"\tmissing param: {check_res.missing}")
- import re
- mapped_missing = [] # 提供了映射的参数
- unmapped_missing = [] # 没有指定映射的参数
- input_func_map = {}
- for _miss_ in check_res.missing:
- # they shoudl like 'SomeParam(assign to xxx)'
- _miss = _miss_.split('(')[0]
- matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss_)
- if len(matches) == 2:
- fun_arg, module_name = matches
- input_func_map[_miss] = fun_arg
- if fun_arg == _miss:
- unmapped_missing.append(_miss)
- else:
- mapped_missing.append(_miss)
- else:
- unmapped_missing.append(_miss)
-
- for _miss in mapped_missing + unmapped_missing:
- if _miss in dataset:
- suggestions.append(f"Set `{_miss}` as target.")
- else:
- _tmp = ''
- if check_res.unused:
- _tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}."
- if _tmp:
- _tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. '
- else:
- _tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.'
- if not dataset.collater.is_empty():
- _tmp += f'Or you need to add `{_miss}` in the output of your collate_fn. '
- suggestions.append(_tmp)
-
- if check_res.duplicated:
- errs.append(f"\tduplicated param: {check_res.duplicated}.")
- suggestions.append(f"Delete {check_res.duplicated} in the output of "
- f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
-
- if len(errs) > 0:
- errs.extend(unuseds)
- elif check_level == STRICT_CHECK_LEVEL:
- errs.extend(unuseds)
-
- if len(errs) > 0:
- errs.insert(0, f'Problems occurred when calling {func_signature}')
- sugg_str = ""
- if len(suggestions) > 1:
- for idx, sugg in enumerate(suggestions):
- if idx > 0:
- sugg_str += '\t\t\t'
- sugg_str += f'({idx + 1}). {sugg}\n'
- sugg_str = sugg_str[:-1]
- else:
- sugg_str += suggestions[0]
- errs.append(f'\ttarget field: {list(target_dict.keys())}')
- errs.append(f'\tparam from {prev_func_signature}: {list(pred_dict.keys())}')
- err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
- raise NameError(err_str)
- if check_res.unused:
- if check_level == WARNING_CHECK_LEVEL:
- if not module_name:
- module_name = func_signature.split('.')[0]
- _unused_warn = f'{check_res.unused} is not used by {module_name}.'
- warnings.warn(message=_unused_warn)
-
-
-def _check_forward_error(forward_func, batch_x, dataset, check_level):
- check_res = _check_arg_dict_list(forward_func, batch_x)
- func_signature = _get_func_signature(forward_func)
-
- errs = []
- suggestions = []
- _unused = []
-
- # if check_res.varargs:
- # errs.append(f"\tvarargs: {check_res.varargs}")
- # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
- if check_res.missing:
- errs.append(f"\tmissing param: {check_res.missing}")
- _miss_in_dataset = []
- _miss_out_dataset = []
- for _miss in check_res.missing:
- if _miss in dataset:
- _miss_in_dataset.append(_miss)
- else:
- _miss_out_dataset.append(_miss)
- if _miss_in_dataset:
- suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ")
- if _miss_out_dataset:
- _tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. "
- if not dataset.collater.is_empty():
- _tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collate_fn. '
- suggestions.append(_tmp)
-
- if check_res.unused:
- _unused = [f"\tunused field: {check_res.unused}"]
- if len(errs) > 0:
- errs.extend(_unused)
- elif check_level == STRICT_CHECK_LEVEL:
- errs.extend(_unused)
-
- if len(errs) > 0:
- errs.insert(0, f'Problems occurred when calling {func_signature}')
- sugg_str = ""
- if len(suggestions) > 1:
- for idx, sugg in enumerate(suggestions):
- sugg_str += f'({idx + 1}). {sugg}'
- err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
- elif len(suggestions):
- sugg_str += suggestions[0]
- err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
- else:
- err_str = '\n' + '\n'.join(errs)
- raise NameError(err_str)
- if _unused:
- if check_level == WARNING_CHECK_LEVEL:
- _unused_warn = _unused[0] + f' in {func_signature}.'
- warnings.warn(message=_unused_warn)
-
-
-def seq_len_to_mask(seq_len, max_len=None):
- r"""
-
- 将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。
- 转变 1-d seq_len到2-d mask.
-
- .. code-block::
-
- >>> seq_len = torch.arange(2, 16)
- >>> mask = seq_len_to_mask(seq_len)
- >>> print(mask.size())
- torch.Size([14, 15])
- >>> seq_len = np.arange(2, 16)
- >>> mask = seq_len_to_mask(seq_len)
- >>> print(mask.shape)
- (14, 15)
- >>> seq_len = torch.arange(2, 16)
- >>> mask = seq_len_to_mask(seq_len, max_len=100)
- >>>print(mask.size())
- torch.Size([14, 100])
-
- :param np.ndarray,torch.LongTensor seq_len: shape将是(B,)
- :param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有
- 区别,所以需要传入一个max_len使得mask的长度是pad到该长度。
- :return: np.ndarray, torch.Tensor 。shape将是(B, max_length), 元素类似为bool或torch.uint8
- """
- if isinstance(seq_len, np.ndarray):
- assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
- max_len = int(max_len) if max_len else int(seq_len.max())
- broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
- mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
-
- elif isinstance(seq_len, torch.Tensor):
- assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}."
- batch_size = seq_len.size(0)
- max_len = int(max_len) if max_len else seq_len.max().long()
- broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
- mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
- else:
- raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.")
-
- return mask
-
-
-class _pseudo_tqdm:
- r"""
- 当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据
- """
-
- def __init__(self, **kwargs):
- self.logger = logger
-
- def write(self, info):
- self.logger.info(info)
-
- def set_postfix_str(self, info):
- self.logger.info(info)
-
- def __getattr__(self, item):
- def pass_func(*args, **kwargs):
- pass
-
- return pass_func
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- del self
-
-
-def iob2(tags: List[str]) -> List[str]:
- r"""
- 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两者的差异见
- https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
-
- :param tags: 需要转换的tags, 需要为大写的BIO标签。
- """
- for i, tag in enumerate(tags):
- if tag == "O":
- continue
- split = tag.split("-")
- if len(split) != 2 or split[0] not in ["I", "B"]:
- raise TypeError("The encoding schema is not a valid IOB type.")
- if split[0] == "B":
- continue
- elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2
- tags[i] = "B" + tag[1:]
- elif tags[i - 1][1:] == tag[1:]:
- continue
- else: # conversion IOB1 to IOB2
- tags[i] = "B" + tag[1:]
- return tags
-
-
-def iob2bioes(tags: List[str]) -> List[str]:
- r"""
- 将iob的tag转换为bioes编码
- :param tags: List[str]. 编码需要是大写的。
- :return:
- """
- new_tags = []
- for i, tag in enumerate(tags):
- if tag == 'O':
- new_tags.append(tag)
- else:
- split = tag.split('-')[0]
- if split == 'B':
- if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I':
- new_tags.append(tag)
- else:
- new_tags.append(tag.replace('B-', 'S-'))
- elif split == 'I':
- if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I':
- new_tags.append(tag)
- else:
- new_tags.append(tag.replace('I-', 'E-'))
- else:
- raise TypeError("Invalid IOB format.")
- return new_tags
-
-
-def _is_iterable(value):
- # 检查是否是iterable的, duck typing
- try:
- iter(value)
- return True
- except BaseException as e:
- return False
-
-
-def get_seq_len(words, pad_value=0):
- r"""
- 给定batch_size x max_len的words矩阵,返回句子长度
-
- :param words: batch_size x max_len
- :return: (batch_size,)
- """
- mask = words.ne(pad_value)
- return mask.sum(dim=-1)
-
-
-def pretty_table_printer(dataset_or_ins) -> PrettyTable:
- r"""
- :param dataset_or_ins: 传入一个dataSet或者instance
- ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
- +-----------+-----------+-----------------+
- | field_1 | field_2 | field_3 |
- +-----------+-----------+-----------------+
- | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
- +-----------+-----------+-----------------+
- :return: 以 pretty table的形式返回根据terminal大小进行自动截断
- """
- x = PrettyTable()
- try:
- sz = os.get_terminal_size()
- column = sz.columns
- row = sz.lines
- except OSError:
- column = 144
- row = 11
-
- if type(dataset_or_ins).__name__ == "DataSet":
- x.field_names = list(dataset_or_ins.field_arrays.keys())
- c_size = len(x.field_names)
- for ins in dataset_or_ins:
- x.add_row([sub_column(ins[k], column, c_size, k) for k in x.field_names])
- row -= 1
- if row < 0:
- x.add_row(["..." for _ in range(c_size)])
- break
- elif type(dataset_or_ins).__name__ == "Instance":
- x.field_names = list(dataset_or_ins.fields.keys())
- c_size = len(x.field_names)
- x.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in x.field_names])
-
- else:
- raise Exception("only accept DataSet and Instance")
- x.align = "l"
-
- return x
-
-
-def sub_column(string: str, c: int, c_size: int, title: str) -> str:
- r"""
- :param string: 要被截断的字符串
- :param c: 命令行列数
- :param c_size: instance或dataset field数
- :param title: 列名
- :return: 对一个过长的列进行截断的结果
- """
- avg = max(int(c / c_size / 2), len(title))
- string = str(string)
- res = ""
- counter = 0
- for char in string:
- if ord(char) > 255:
- counter += 2
- else:
- counter += 1
- res += char
- if counter > avg:
- res = res + "..."
- break
- return res
-
-
-def _is_function_contains_autocast(func):
- """
- 检查func是否包含autocast,(1)是否使用了autocast的修饰器或, (2)使用使用with autocast()环境
-
- :param func: 待检查的函数
- """
- import re
- source = inspect.getsource(func)
- lines = source.split('\n')
- for line in lines:
- line = line.strip()
- if re.search(r'@[\w\.]*autocast\(\w*\)', line):
- raise RuntimeError("Please do not use `autocast()` decorator, use `with autocast():` instead. Please refer to"
- " https://pytorch.org/docs/stable/notes/amp_examples.html#dataparallel-in-a-single-process ")
- if re.search(r'with [\w\.]*autocast\(\w*\):', line):
- return True
- return False
-
-
-class DummyGradScaler:
- """
- 用于Dummy pytorch的GradScaler对象,防止重复写大量的if判断
-
- """
- def __init__(self, *args, **kwargs):
- pass
-
- def get_scale(self):
- return 1.0
-
- def is_enabled(self):
- return False
-
- def scale(self, outputs):
- return outputs
-
- def step(self, optimizer, *args, **kwargs):
- optimizer.step(*args, **kwargs)
-
- def update(self, new_scale=None):
- pass
-
- def unscale_(self, optimizer):
- pass
-
- def load_state_dict(self, state_dict):
- pass
-
- def state_dict(self):
- return {}
-
-
-def _build_fp16_env(dummy=False):
- if dummy:
- autocast = contextlib.ExitStack
- GradScaler = DummyGradScaler
- else:
- if not torch.cuda.is_available():
- raise RuntimeError("No cuda")
- if torch.cuda.get_device_capability(0)[0] < 7:
- warnings.warn(
- "NOTE: your device does NOT support faster training with fp16, "
- "please switch to FP32 which is likely to be faster"
- )
- try:
- from torch.cuda.amp import autocast, GradScaler
- except ImportError:
- raise RuntimeError("torch version too low (less than 1.6)")
- return autocast, GradScaler
-
-
-def _can_use_fp16(device, model, func):
- if parse_version(torch.__version__) < parse_version('1.6'):
- raise RuntimeError("Pytorch supports float16 after version 1.6, please upgrade your pytorch version.")
- model_device = _get_model_device(model)
- if device is None and model_device is not None and model_device.type != 'cuda':
- raise RuntimeError("You have to run in cuda device to use fp16.")
- if isinstance(device, str):
- if device=='cpu':
- raise RuntimeError("You have to run in cuda device to use fp16.")
- if isinstance(device, torch.device) and device.type=='cpu':
- raise RuntimeError("You have to run in cuda device to use fp16.")
-
- if (_model_contains_inner_module(model) or (isinstance(device, list) and len(device) > 1)):
- # 需要提醒用户
- if not _is_function_contains_autocast(func):
- raise RuntimeError("When use fp16 in Parallel Training, you have to set autocast() in your forward "
- "function as described in "
- "https://pytorch.org/docs/stable/notes/amp_examples.html#dataparallel-in-a-single-process")
diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py
new file mode 100644
index 00000000..d188bc37
--- /dev/null
+++ b/fastNLP/core/utils/__init__.py
@@ -0,0 +1,48 @@
+__all__ = [
+ 'cache_results',
+ 'is_jittor_module',
+ 'is_jittor_dataset',
+ 'jittor_collate_wraps',
+ 'paddle_to',
+ 'paddle_move_data_to_device',
+ 'get_paddle_device_id',
+ 'get_paddle_gpu_str',
+ 'is_in_paddle_dist',
+ 'is_in_fnlp_paddle_dist',
+ 'is_in_paddle_launch_dist',
+ 'is_paddle_module',
+ 'f_rich_progress',
+ 'torch_move_data_to_device',
+ 'is_torch_module',
+ 'get_oneflow_device',
+ 'oneflow_move_data_to_device',
+ 'is_oneflow_module',
+ 'is_in_oneflow_dist',
+ 'get_fn_arg_names',
+ 'auto_param_call',
+ 'check_user_specific_params',
+ 'dataclass_to_dict',
+ 'match_and_substitute_params',
+ 'apply_to_collection',
+ 'nullcontext',
+ 'pretty_table_printer',
+ 'Option',
+ 'deprecated',
+ "flat_nest_dict",
+ "f_tqdm_progress",
+
+ "seq_len_to_mask"
+]
+
+from .cache_results import cache_results
+from .jittor_utils import is_jittor_dataset, jittor_collate_wraps, is_jittor_module
+from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \
+ is_in_fnlp_paddle_dist, is_in_paddle_launch_dist, is_paddle_module
+from .rich_progress import f_rich_progress
+from .torch_utils import torch_move_data_to_device, is_torch_module
+from .oneflow_utils import oneflow_move_data_to_device, is_oneflow_module, is_in_oneflow_dist, get_oneflow_device
+from .utils import *
+from .tqdm_progress import f_tqdm_progress
+from .seq_len_to_mask import seq_len_to_mask
+
+
diff --git a/fastNLP/core/utils/cache_results.py b/fastNLP/core/utils/cache_results.py
new file mode 100644
index 00000000..dc114b48
--- /dev/null
+++ b/fastNLP/core/utils/cache_results.py
@@ -0,0 +1,349 @@
+"""
+:func:`cache_results` 函数是 **fastNLP** 中用于缓存数据的装饰器,通过该函数您可以省去调试代码过程中一些耗时过长程序
+带来的时间开销。比如在加载并处理较大的数据时,每次修改训练参数都需要从头开始执行处理数据的过程,那么 :func:`cache_results`
+便可以跳过这部分漫长的时间。详细的使用方法和原理请参见下面的说明。
+
+.. warning::
+
+ 如果您发现对代码进行修改之后程序执行的结果没有变化,很有可能是这个函数的原因;届时删除掉缓存数据即可。
+
+"""
+
+from datetime import datetime
+import hashlib
+import _pickle
+import functools
+import os
+import re
+from typing import Callable, List, Any, Optional
+import inspect
+import ast
+from collections import deque
+
+__all__ = [
+ 'cache_results'
+]
+
+from fastNLP.core.log.logger import logger
+from fastNLP.core.log.highlighter import ColorHighlighter
+from .utils import _get_fun_msg
+
+
+class FuncCallVisitor(ast.NodeVisitor):
+ # credit to https://gist.github.com/jargnar/0946ab1d985e2b4ab776
+ def __init__(self):
+ self._name = deque()
+
+ @property
+ def name(self):
+ return '.'.join(self._name)
+
+ @name.deleter
+ def name(self):
+ self._name.clear()
+
+ def visit_Name(self, node):
+ self._name.appendleft(node.id)
+
+ def visit_Attribute(self, node):
+ try:
+ self._name.appendleft(node.attr)
+ self._name.appendleft(node.value.id)
+ except AttributeError:
+ self.generic_visit(node)
+
+
+def get_func_calls(tree):
+ func_calls = []
+ for node in ast.walk(tree):
+ if isinstance(node, ast.Call):
+ callvisitor = FuncCallVisitor()
+ callvisitor.visit(node.func)
+ func_calls.append(callvisitor.name)
+ if isinstance(node, ast.FunctionDef):
+ if not (node is tree):
+ func_calls.extend(get_func_calls(node))
+
+ return func_calls
+
+
+def truncate_start_blanks(source:str)->str:
+ """
+ 将source中的每一行按照第一行的indent删掉多余的空格
+
+ :param source:
+ :return:
+ """
+ lines = source.split('\n')
+ num_blank = 0
+ # get the top blank line
+ for line in lines:
+ if line:
+ num_blank = len(line) - len(line.lstrip())
+ new_lines = []
+ for line in lines:
+ i = -1
+ for i in range(min(len(line), num_blank)):
+ if line[i] == ' ':
+ continue
+ else:
+ break
+ line = line[i:]
+ new_lines.append(line)
+ return '\n'.join(new_lines)
+
+
+def _get_func_and_its_called_func_source_code(func) -> List[str]:
+ """
+ 给定一个func,返回在这个函数里面用到的所有函数的源码。
+
+ :param callable func:
+ :return:
+ """
+ last_frame = inspect.currentframe().f_back.f_back.f_back
+ last_frame_f_local = last_frame.f_locals
+ last_frame_loc = {}
+ if 'loc' in last_frame_f_local:
+ last_frame_loc = last_frame_f_local['loc']
+ func_calls = list(set(get_func_calls(ast.parse(truncate_start_blanks(inspect.getsource(func))))))
+ func_calls.sort()
+ sources = []
+ for _func_name in func_calls:
+ try:
+ if _func_name == 'cache_results': # ignore the decorator
+ continue
+ if '.' in _func_name:
+ _funcs = _func_name.split('.')
+ else:
+ _funcs = [_func_name]
+ if _funcs[0] in last_frame_f_local or _funcs[0] in last_frame_loc:
+ tmp = _funcs.pop(0)
+ variable = last_frame_f_local.get(tmp, last_frame_loc.get(tmp))
+ while len(_funcs) or variable is not None:
+ if hasattr(variable, '__class__') and not inspect.isbuiltin(variable.__class__):
+ try:
+ sources.append(inspect.getsource(variable.__class__))
+ except TypeError:
+ pass
+ if callable(variable) or inspect.isclass(variable):
+ sources.append(inspect.getsource(variable))
+ if len(_funcs):
+ tmp = _funcs.pop(0)
+ if hasattr(variable, tmp):
+ variable = getattr(variable, tmp)
+ else:
+ break
+ else:
+ variable = None
+ except:
+ # some failure
+ pass
+ del last_frame #
+ func_source_code = inspect.getsource(func) # 将这个函数中的 cache_results 装饰删除掉。
+ for match in list(re.finditer('@cache_results\(.*\)\\n', func_source_code))[::-1]:
+ func_source_code = func_source_code[:match.start()] + func_source_code[match.end():]
+ sources.append(func_source_code)
+ return sources
+
+
+def _prepare_cache_filepath(filepath:str):
+ r"""
+ 检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径
+
+ :param filepath: str.
+ :return: None, if not, this function will raise error
+ """
+ _cache_filepath = os.path.abspath(filepath)
+ if os.path.isdir(_cache_filepath):
+ raise RuntimeError("The cache_file_path must be a file, not a directory.")
+ cache_dir = os.path.dirname(_cache_filepath)
+ if not os.path.exists(cache_dir):
+ os.makedirs(cache_dir, exist_ok=True)
+
+
+class Hasher:
+ def __init__(self):
+ self.m = hashlib.sha1()
+
+ def update(self, value: Any) -> None:
+ if isinstance(value, str):
+ value = [value]
+ for x in value:
+ self.m.update(x.encode('utf8'))
+
+ def hexdigest(self) -> str:
+ return self.m.hexdigest()
+
+
+def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] = None):
+ if fn_kwargs is None:
+ fn_kwargs = {}
+ hasher = Hasher()
+ if fn is not None:
+ try:
+ sources = _get_func_and_its_called_func_source_code(fn)
+ hasher.update(sources)
+ except:
+ return "can't be hashed"
+ for key in sorted(fn_kwargs):
+ hasher.update(key)
+ try:
+ hasher.update(fn_kwargs[key])
+ except:
+ pass
+ return hasher.hexdigest()
+
+
+def cache_results(_cache_fp: str, _hash_param: bool = True, _refresh: bool = False, _verbose: int = 1, _check_hash: bool = True):
+ r"""
+ :func:`cache_results` 是 **fastNLP** 中用于缓存数据的装饰器。通过下面的例子看一下如何使用::
+
+ import time
+ import numpy as np
+ from fastNLP import cache_results
+
+ @cache_results('cache.pkl')
+ def process_data(second=1):
+ # 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时
+ time.sleep(second)
+ return np.random.randint(10, size=(5,))
+
+ start_time = time.time()
+ print("res =",process_data())
+ print(time.time() - start_time)
+
+ start_time = time.time()
+ print("res =",process_data())
+ print(time.time() - start_time)
+
+ start_time = time.time()
+ print("res =",process_data(second=2))
+ print(time.time() - start_time)
+
+ # 输出内容如下,可以看到前两次结果相同,且第二次几乎没有花费时间。第三次由于参数变化了,所以cache的结果也就自然变化了。
+ # Save cache to 2d145aeb_cache.pkl.
+ # res = [5 4 9 1 8]
+ # 1.0134737491607666
+ # Read cache from 2d145aeb_cache.pkl (Saved on xxxx).
+ # res = [5 4 9 1 8]
+ # 0.0040721893310546875
+ # Save cache to 0ead3093_cache.pkl.
+ # res = [1 8 2 5 1]
+ # 2.0086121559143066
+
+ 可以看到第二次运行的时候,只用了 0.0001s 左右,这是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理。
+ 如果在函数加上了装饰器 ``@cache_results()``,则函数会增加五个参数 ``[_cache_fp, _hash_param, _refresh, _verbose,
+ _check_hash]``。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称。
+
+ :param _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为 ``None`` ,cache_results 没有任何效用,除非在
+ 函数调用的时候传入 _cache_fp 这个参数。实际保存的文件名会受到 ``_hash_param`` 参数的影响,例如传入的名称是 **"caches/cache.pkl"**,
+ 实际保存的文件名会是 **"caches/{hash_param_result}_cache.pkl"**。
+ :param _hash_param: 是否将传入给被装饰函数的 parameter 进行 :func:`str` 之后的 hash 结果加入到 ``_cache_fp`` 中,这样每次函数的
+ parameter 改变的时候,cache 文件就自动改变了。
+ :param _refresh: 强制重新生成新的 cache 。
+ :param _verbose: 是否打印 cache 的信息。
+ :param _check_hash: 如果为 ``True`` 将尝试对比修饰的函数的源码以及该函数内部调用的函数的源码的 hash 值。如果发现保存时的 hash 值
+ 与当前的 hash 值有差异,会报 warning 。但该 warning 可能出现实质上并不影响结果的误报(例如增删空白行);且在修改不涉及源码时,虽然
+ 该修改对结果有影响,但无法做出 warning。
+ :return:
+ """
+
+ def wrapper_(func):
+ signature = inspect.signature(func)
+ for key, _ in signature.parameters.items():
+ if key in ('_cache_fp', "_hash_param", '_refresh', '_verbose', '_check_hash'):
+ raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ # fn_param = kwargs.copy()
+ # if args:
+ # params = [p.name for p in inspect.signature(func).parameters.values()]
+ # fn_param.update(zip(params, args))
+ if '_cache_fp' in kwargs:
+ cache_filepath = kwargs.pop('_cache_fp')
+ assert isinstance(cache_filepath, str), "_cache_fp can only be str."
+ else:
+ cache_filepath = _cache_fp
+ if '_refresh' in kwargs:
+ refresh = kwargs.pop('_refresh')
+ assert isinstance(refresh, bool), "_refresh can only be bool."
+ else:
+ refresh = _refresh
+ if '_verbose' in kwargs:
+ verbose = kwargs.pop('_verbose')
+ assert isinstance(verbose, int), "_verbose can only be integer."
+ else:
+ verbose = _verbose
+
+ if '_check_hash' in kwargs:
+ check_hash = kwargs.pop('_check_hash')
+ else:
+ check_hash = _check_hash
+
+ if '_hash_param' in kwargs:
+ hash_param = kwargs.pop('_hash_param')
+ assert isinstance(hash_param, bool), "_hash_param can only be bool."
+ else:
+ hash_param = _hash_param
+
+ if hash_param and cache_filepath is not None: # 尝试将parameter给hash一下
+ try:
+ params = dict(inspect.getcallargs(func, *args, **kwargs))
+ if inspect.ismethod(func): # 如果是 method 的话第一个参数(一般就是 self )就不考虑了
+ first_key = next(iter(params.items()))
+ params.pop(first_key)
+ if len(params):
+ # sort 一下防止顺序改变
+ params = {k: str(v) for k, v in sorted(params.items(), key=lambda item: item[0])}
+ param_hash = cal_fn_hash_code(None, params)[:8]
+ head, tail = os.path.split(cache_filepath)
+ cache_filepath = os.path.join(head, param_hash + '_' + tail)
+ except BaseException as e:
+ logger.debug(f"Fail to add parameter hash to cache path, because of Exception:{e}")
+
+ refresh_flag = True
+ new_hash_code = None
+ if check_hash:
+ new_hash_code = cal_fn_hash_code(func, None)
+
+ if cache_filepath is not None and refresh is False:
+ # load data
+ if os.path.exists(cache_filepath):
+ cache_filepath = os.path.abspath(cache_filepath)
+ with open(cache_filepath, 'rb') as f:
+ results = _pickle.load(f)
+ old_hash_code = results['hash']
+ save_time = results['save_time']
+ results = results['results']
+ if verbose == 1:
+ logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time))
+ if check_hash and old_hash_code != new_hash_code:
+ logger.warning(f"The function {_get_fun_msg(func)} is different from its last cache (Save on {save_time}). The "
+ f"difference may caused by the sourcecode change.",
+ extra={'highlighter': ColorHighlighter('red')})
+ refresh_flag = False
+
+ if refresh_flag:
+ if new_hash_code is None:
+ new_hash_code = cal_fn_hash_code(func, None)
+ results = func(*args, **kwargs)
+ if cache_filepath is not None:
+ if results is None:
+ raise RuntimeError("The return value is None. Cannot save None results.")
+ cache_filepath = os.path.abspath(cache_filepath)
+ _prepare_cache_filepath(cache_filepath)
+ _dict = {
+ 'results': results,
+ 'hash': new_hash_code,
+ 'save_time': datetime.now(),
+ }
+ with open(cache_filepath, 'wb') as f:
+ _pickle.dump(_dict, f)
+ logger.info("Save cache to {}.".format(cache_filepath))
+
+ return results
+
+ return wrapper
+
+ return wrapper_
\ No newline at end of file
diff --git a/fastNLP/core/utils/dummy_class.py b/fastNLP/core/utils/dummy_class.py
new file mode 100644
index 00000000..afd610ce
--- /dev/null
+++ b/fastNLP/core/utils/dummy_class.py
@@ -0,0 +1,11 @@
+__all__ = []
+
+class DummyClass:
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def __getattr__(self, item):
+ return lambda *args, **kwargs: ...
+
+ def __call__(self, *args, **kwargs):
+ pass
\ No newline at end of file
diff --git a/fastNLP/core/utils/exceptions.py b/fastNLP/core/utils/exceptions.py
new file mode 100644
index 00000000..a052c11b
--- /dev/null
+++ b/fastNLP/core/utils/exceptions.py
@@ -0,0 +1,10 @@
+
+class EarlyStopException(BaseException):
+ r"""
+ 用于 EarlyStop 时从 Trainer 训练循环中跳出。
+
+ """
+
+ def __init__(self, msg):
+ super(EarlyStopException, self).__init__(msg)
+ self.msg = msg
diff --git a/fastNLP/core/utils/jittor_utils.py b/fastNLP/core/utils/jittor_utils.py
new file mode 100644
index 00000000..ea1a86e4
--- /dev/null
+++ b/fastNLP/core/utils/jittor_utils.py
@@ -0,0 +1,71 @@
+__all__ = [
+ 'is_jittor_module',
+ 'is_jittor_dataset',
+ 'jittor_collate_wraps',
+]
+
+from collections.abc import Mapping, Callable
+from functools import wraps
+
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+
+if _NEED_IMPORT_JITTOR:
+ import jittor as jt
+
+from fastNLP.core.dataset import Instance
+
+def is_jittor_module(model) -> bool:
+ """
+ 判断传入的 ``model`` 是否是 :class:`jittor.Module` 类型。
+
+ :param model:
+ :return: 当前模型是否为 ``jittor`` 的模型
+ """
+ try:
+ return isinstance(model, jt.Module)
+ except BaseException:
+ return False
+
+def is_jittor_dataset(dataset) -> bool:
+ """
+ 判断传入的 ``dataset`` 是否是 :class:`jittor.dataset.Dataset` 类型。
+
+ :param dataset:
+ :return: 当前 ``dataset`` 是否为 ``jittor`` 的数据集类型
+ """
+ try:
+ if isinstance(dataset, jt.dataset.Dataset):
+ return True
+ else:
+ return False
+ except BaseException:
+ return False
+
+
+def jittor_collate_wraps(func, auto_collator: Callable):
+ """
+ 对 ``jittor`` 的 ``collate_fn`` 进行 wrap 封装。如果数据集为 :class:`Mapping` 类型,那么采用 ``auto_collator`` ,
+ 否则还是采用 ``jittor`` 的 ``collate_batch``。
+
+ :param func:
+ :param auto_collator:
+ :return:
+ """
+
+ @wraps(func)
+ def wrapper(batch):
+ if isinstance(batch[0], Instance):
+ if auto_collator is not None:
+ result = auto_collator(batch)
+ else:
+ raise ValueError(f"auto_collator is None, but the batch contains fastNLP Instance objects!")
+ elif isinstance(batch[0], Mapping):
+ if auto_collator is not None:
+ result = auto_collator(batch)
+ else:
+ result = func(batch)
+ else:
+ result = func(batch)
+ return result
+
+ return wrapper
diff --git a/fastNLP/core/utils/oneflow_utils.py b/fastNLP/core/utils/oneflow_utils.py
new file mode 100644
index 00000000..6c3026c6
--- /dev/null
+++ b/fastNLP/core/utils/oneflow_utils.py
@@ -0,0 +1,69 @@
+import os
+from typing import Any, Union, Optional
+from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK
+from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
+
+if _NEED_IMPORT_ONEFLOW:
+ import oneflow
+
+__all__ = [
+ 'get_oneflow_device',
+ 'oneflow_move_data_to_device',
+ 'is_oneflow_module',
+ 'is_in_oneflow_dist',
+]
+
+from .utils import apply_to_collection
+
+def get_oneflow_device(device):
+ """
+ 构造一个 :class:`oneflow.device` 实例并返回。
+
+ :param device: 字符串或 gpu 编号
+ :return: :class:`oneflow.device`
+ """
+ if isinstance(device, oneflow.device):
+ return device
+ if isinstance(device, int):
+ return oneflow.device("cuda", device)
+ if isinstance(device, str):
+ return oneflow.device(device)
+ raise RuntimeError(f"Cannot get `oneflow.device` from {device}.")
+
+def oneflow_move_data_to_device(batch: Any, device: Optional[Union[str, "oneflow.device"]] = None) -> Any:
+ r"""
+ 在 **oneflow** 中将数据集合 ``batch`` 传输到给定设备。
+
+ :param batch: 需要迁移的数据
+ :param device: 数据应当迁移到的设备;当该参数的值为 ``None`` 时则不执行任何操作
+ :return: 迁移到新设备上的数据集合
+ """
+ if device is None:
+ return batch
+
+ def batch_to(data: Any) -> Any:
+ data_output = data.to(device)
+ if data_output is not None:
+ return data_output
+ # user wrongly implemented the `TransferableDataType` and forgot to return `self`.
+ return data
+
+ return apply_to_collection(batch, dtype=oneflow.Tensor, function=batch_to)
+
+def is_oneflow_module(model) -> bool:
+ """
+ 判断传入的 ``model`` 是否是 :class:`oneflow.nn.Module` 类型。
+
+ :param model:
+ :return: 当前模型是否为 ``oneflow`` 的模型
+ """
+ try:
+ return isinstance(model, oneflow.nn.Module)
+ except BaseException:
+ return False
+
+def is_in_oneflow_dist() -> bool:
+ """
+ 判断是否处于 **oneflow** 分布式的进程下。
+ """
+ return "GLOG_log_dir" in os.environ
\ No newline at end of file
diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py
new file mode 100644
index 00000000..11ab4834
--- /dev/null
+++ b/fastNLP/core/utils/paddle_utils.py
@@ -0,0 +1,190 @@
+__all__ = [
+ "paddle_to",
+ "paddle_move_data_to_device",
+ "get_paddle_gpu_str",
+ "get_paddle_device_id",
+ "is_in_paddle_dist",
+ "is_in_fnlp_paddle_dist",
+ "is_in_paddle_launch_dist",
+ "is_paddle_module",
+]
+
+import os
+import re
+from typing import Any, Optional, Union
+
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_BACKEND_LAUNCH, USER_CUDA_VISIBLE_DEVICES
+
+if _NEED_IMPORT_PADDLE:
+ import paddle
+
+from .utils import apply_to_collection
+
+
+def _convert_data_device(device: Union[str, int]) -> str:
+ """
+ 用于转换 ``driver`` 的 ``data_device`` 的函数。如果用户设置了 ``FASTNLP_BACKEND=paddle``,那么 **fastNLP** 会将
+ 可见的设备保存在 ``USER_CUDA_VISIBLE_DEVICES`` 中,并且将 ``CUDA_VISIBLE_DEVICES`` 设置为可见的第一张显卡;这是为
+ 了顺利执行 **paddle** 的分布式训练而设置的。
+
+ 在这种情况下,单纯使用 ``driver.data_device`` 是无效的。比如在分布式训练中将设备设置为 ``[0,2,3]`` ,且用户设置了
+ ``CUDA_VISIBLE_DEVICES=3,4,5,6`` ,那么在 ``rank1``的进程中有::
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "5"
+ os.environ["USER_CUDA_VISIBLE_DEVICES"] = "3,4,5,6"
+ driver.data_device = "gpu:2" # 为了向用户正确地反映他们设置的设备减少歧义,因此这里没有设置为 "gpu:5"
+
+ 此时我们便需要通过这个函数将 ``data_device`` 转换为 ``gpu:0``。具体过程便是通过索引 **2** 在 ``USER_CUDA_VISIBLE_DEVICES`` 中
+ 找到设备 **5**,然后在 ``CUDA_VISIBLE_DEVICES`` 中找到设备 **5** 的索引 **0** 返回。
+
+ .. note::
+
+ 在分布式单进程仅支持单卡的情况下中,这个函数实际等同于直接转换为 ``gpu:0`` 返回。
+
+ :param device: 未转化的设备
+ :return: 转化后的设备,格式为 ``gpu:x``
+ """
+ try:
+ user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
+ if device == "cpu" or user_visible_devices is None:
+ # 传入的是 CPU,或者没有设置 USER_CUDA_VISIBLE_DEVICES
+ # 此时不需要进行转换
+ return get_paddle_gpu_str(device)
+
+ idx = get_paddle_device_id(device)
+ idx = user_visible_devices.split(",")[idx]
+ # 此时 CUDA_VISIBLE_DEVICES 一定不是 None
+ cuda_visible_devices_list = os.getenv("CUDA_VISIBLE_DEVICES").split(',')
+ return f"gpu:{cuda_visible_devices_list.index(idx)}"
+ except Exception as e:
+ raise ValueError(f"Can't convert device {device} when USER_CUDA_VISIBLE_DEVICES={user_visible_devices} "
+ f"and CUDA_VISIBLE_DEVICES={os.getenv('CUDA_VISIBLE_DEVICES')}. If this situation happens, please report this bug to us.") from e
+
+
+def paddle_to(data: "paddle.Tensor", device: Union[str, int, 'paddle.fluid.core_avx.Place',
+ 'paddle.CPUPlace', 'paddle.CUDAPlace']) -> "paddle.Tensor":
+ """
+ 将 ``data`` 迁移到指定的 ``device`` 上。:class:`paddle.Tensor` 没有类似 :meth:`torch.Tensor.to` 的函数来迁移张量,
+ 因此该函数只是集成了 :func:`paddle.Tensor.cpu` 和 :func:`paddle.Tensor.cuda` 两个函数。
+
+ :param data: 要迁移的张量;
+ :param device: 目标设备,可以是 ``str`` 或 ``int`` 及 **paddle** 自己的 :class:`paddle.fluid.core_avx.Place`、
+ :class:`paddle.CPUPlace` 和 :class:`paddle.CUDAPlace` 类型;
+ :return: 迁移后的张量;
+ """
+ if isinstance(device, paddle.fluid.core_avx.Place):
+ if device.is_cpu_place():
+ return data.cpu()
+ else:
+ return data.cuda(device.gpu_device_id())
+ elif isinstance(device, paddle.CPUPlace):
+ return data.cpu()
+ elif isinstance(device, paddle.CUDAPlace):
+ return data.gpu(device.get_device_id())
+ elif device == "cpu":
+ return data.cpu()
+ else:
+ return data.cuda(get_paddle_device_id(device))
+
+
+def get_paddle_gpu_str(device: Union[str, int]) -> str:
+ """
+ 获得 ``gpu:x`` 格式的设备名::
+
+ >>> get_paddle_gpu_str(1)
+ 'gpu:1'
+ >>> get_paddle_gpu_str("cuda:1")
+ 'gpu:1'
+
+ :param device: 设备编号或设备名;
+ :return: 对应的 ``gpu:x`` 格式的设备名;
+ """
+ if isinstance(device, str):
+ return device.replace("cuda", "gpu")
+ return f"gpu:{device}"
+
+
+def get_paddle_device_id(device: Union[str, int]) -> int:
+ """
+ 获得 ``device`` 的设备编号::
+
+ >>> get_paddle_device_id("gpu:1")
+ 1
+ >>> get_paddle_device_id("gpu")
+ 0
+
+ 请注意不要向这个函数中传入 ``cpu``。
+
+ :param device: 设备编号或设备名;
+ :return: 设备对应的编号;
+ """
+ if isinstance(device, int):
+ return device
+
+ device = device.lower()
+ if device == "cpu":
+ raise ValueError("Cannot get device id from `cpu`.")
+ elif device == "gpu":
+ return 0
+
+ match_res = re.match(r"gpu:\d+", device)
+ if not match_res:
+ raise ValueError(
+ "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', "
+ f"not '{device}'"
+ )
+ device_id = device.split(':', 1)[1]
+ device_id = int(device_id)
+
+ return device_id
+
+def paddle_move_data_to_device(batch: Any, device: Optional[Union[str, int]]) -> Any:
+ r"""
+ 将 **paddle** 的数据集合传输到给定设备。只有 :class:`paddle.Tensor` 对象会被传输到设备中,其余保持不变。
+
+ :param batch: 需要进行迁移的数据集合;
+ :param device: 目标设备。可以是显卡设备的编号,或是``cpu``, ``gpu`` 或 ``gpu:x`` 格式的字符串;
+ 当这个参数为 ``None`` 时,不会执行任何操作。
+ :return: 迁移到新设备上的数据集合;
+ """
+ if device is None:
+ return batch
+
+ def batch_to(data: Any) -> Any:
+ return paddle_to(data, device)
+
+ return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to)
+
+
+def is_in_paddle_dist() -> bool:
+ """
+ 判断是否处于 **paddle** 分布式的进程下,使用 ``PADDLE_RANK_IN_NODE`` 和 ``FLAGS_selected_gpus`` 判断。
+ """
+ return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ)
+
+
+def is_in_fnlp_paddle_dist() -> bool:
+ """
+ 判断是否处于 **fastNLP** 拉起的 **paddle** 分布式进程中
+ """
+ return FASTNLP_DISTRIBUTED_CHECK in os.environ
+
+
+def is_in_paddle_launch_dist() -> bool:
+ """
+ 判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中
+ """
+ return FASTNLP_BACKEND_LAUNCH in os.environ
+
+def is_paddle_module(model) -> bool:
+ """
+ 判断传入的 ``model`` 是否是 :class:`paddle.nn.Layer` 类型
+
+ :param model: 模型;
+ :return: 当前模型是否为 ``paddle`` 的模型;
+ """
+ try:
+ return isinstance(model, paddle.nn.Layer)
+ except BaseException:
+ return False
\ No newline at end of file
diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py
new file mode 100644
index 00000000..ad539945
--- /dev/null
+++ b/fastNLP/core/utils/rich_progress.py
@@ -0,0 +1,332 @@
+"""
+该文件用于为 **fastNLP** 提供一个统一的 ``progress bar`` 管理,通过共用一个 ``Task`` 对象, :class:`~fastNLP.core.Trainer`
+中的 ``progress bar`` 和 :class:`~fastNLP.core.Evaluator` 中的 ``progress bar`` 才能不冲突
+"""
+import sys
+from typing import Any, Union, Optional
+
+from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live, Text, ProgressSample
+from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn
+
+__all__ = [
+ 'f_rich_progress'
+]
+
+from fastNLP.envs import get_global_rank
+from .utils import is_notebook
+
+
+class Singleton(type):
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+ return cls._instances[cls]
+
+
+# 如果不打印的时候,使得整个 progress 没有任何意义
+class DummyFRichProgress:
+ def __getattr__(self, item):
+ return DummyFRichProgress()
+
+ def __call__(self, *args, **kwargs):
+ # 防止用户通过 DummyFRichProgress.console.print() 这种调用
+ return None
+
+ @property
+ def dummy(self)->bool:
+ """
+ 当前对象是否是 dummy 的 rich 对象。
+
+ :return:
+ """
+ return True
+
+
+class FRichProgress(Progress, metaclass=Singleton):
+ def new_progess(self, *columns: Union[str, ProgressColumn],
+ # 这里将 auto_refresh 关掉是想要避免单独开启线程,同时也是为了避免pdb的时候会持续刷新
+ auto_refresh: bool = False,
+ refresh_per_second: float = 10,
+ speed_estimate_period: float = 30.0,
+ transient: bool = True,
+ redirect_stdout: bool = True,
+ redirect_stderr: bool = True,
+ get_time: Optional[GetTimeCallable] = None,
+ disable: bool = False,
+ expand: bool = False):
+ for task_id in self.task_ids: # 首先移除已有的
+ self.remove_task(task_id)
+
+ assert (
+ refresh_per_second is None or refresh_per_second > 0
+ ), "refresh_per_second must be > 0"
+
+ # stop previous columns
+ self.stop()
+
+ # do not change these variables
+ # self._lock = RLock()
+ # self._tasks: Dict[TaskID, Task] = {}
+ # self._task_index: TaskID = TaskID(0)
+
+ if len(columns) != 0:
+ self.columns = columns
+
+ self.speed_estimate_period = speed_estimate_period
+
+ self.disable = disable
+ self.expand = expand
+
+ self.live = Live(
+ console=get_console(),
+ auto_refresh=auto_refresh,
+ refresh_per_second=refresh_per_second,
+ transient=transient,
+ redirect_stdout=redirect_stdout,
+ redirect_stderr=redirect_stderr,
+ get_renderable=self.get_renderable,
+ )
+ self.get_time = get_time or self.console.get_time
+ self.print = self.console.print
+ self.log = self.console.log
+ self.auto_refresh = auto_refresh
+ self.transient = transient
+ self.redirect_stdout = redirect_stdout
+ self.redirect_stderr = redirect_stderr
+ self.refresh_per_second = refresh_per_second
+ self._need_renew_live = False
+
+ return self
+
+ def set_transient(self, transient: bool = True):
+ """
+ 设置是否在bar运行结束之后不关闭
+
+ :param transient:
+ :return:
+ """
+ self.new_progess(transient=transient)
+
+ def set_disable(self, flag: bool = True):
+ """
+ 设置当前 progress bar 的状态,如果为 True ,则不会显示进度条了。
+
+ :param flag:
+ :return:
+ """
+ self.disable = flag
+
+ def add_task(
+ self,
+ description: str = 'Progress',
+ start: bool = True,
+ total: float = 100.0,
+ completed: int = 0,
+ visible: bool = True,
+ **fields: Any,
+ ) -> TaskID:
+ from .tqdm_progress import f_tqdm_progress
+ assert not f_tqdm_progress.not_empty(), "Cannot use rich before tqdm finish loop."
+
+ # 如果需要替换,应该是由于destroy的时候给换掉了
+ if self._need_renew_live:
+ self.live = Live(
+ console=get_console(),
+ auto_refresh=self.auto_refresh,
+ refresh_per_second=self.refresh_per_second,
+ transient=self.transient,
+ redirect_stdout=self.redirect_stdout,
+ redirect_stderr=self.redirect_stderr,
+ get_renderable=self.get_renderable,
+ )
+ self._need_renew_live = False
+ if not self.live.is_started:
+ self.start()
+ post_desc = fields.pop('post_desc', '')
+ return super().add_task(description=description,
+ start=start,
+ total=total,
+ completed=completed,
+ visible=visible,
+ post_desc=post_desc,
+ **fields)
+
+ def stop_task(self, task_id: TaskID) -> None:
+ if task_id in self._tasks:
+ super().stop_task(task_id)
+
+ def remove_task(self, task_id: TaskID) -> None:
+ if task_id in self._tasks:
+ super().remove_task(task_id)
+
+ def destroy_task(self, task_id: TaskID):
+ if task_id in self._tasks:
+ super().stop_task(task_id)
+ super().remove_task(task_id)
+ self.refresh() # 使得bar不残留
+ if len(self._tasks) == 0:
+ # 这里将这个line函数给hack一下防止stop的时候打印出空行
+ old_line = getattr(self.live.console, 'line')
+ setattr(self.live.console, 'line', lambda *args,**kwargs:...)
+ self.live.stop()
+ setattr(self.live.console, 'line', old_line)
+ # 在 jupyter 的情况下需要替换一下,不然会出不打印的问题。
+ self._need_renew_live = True if is_notebook() else False
+
+ def start(self) -> None:
+ super().start()
+ self.console.show_cursor(show=True)
+
+ def update(
+ self,
+ task_id: TaskID,
+ *,
+ total: Optional[float] = None,
+ completed: Optional[float] = None,
+ advance: Optional[float] = None,
+ description: Optional[str] = None,
+ visible: Optional[bool] = None,
+ refresh: bool = True,
+ **fields: Any,
+ ) -> None:
+ """Update information associated with a task.
+
+ Args:
+ task_id (TaskID): Task id (returned by add_task).
+ total (float, optional): Updates task.total if not None.
+ completed (float, optional): Updates task.completed if not None.
+ advance (float, optional): Add a value to task.completed if not None.
+ description (str, optional): Change task description if not None.
+ visible (bool, optional): Set visible flag if not None.
+ refresh (bool): Force a refresh of progress information. Default is False.
+ **fields (Any): Additional data fields required for rendering.
+ """
+ with self._lock:
+ task = self._tasks[task_id]
+ completed_start = task.completed
+
+ if total is not None and total != task.total:
+ task.total = total
+ task._reset()
+ if advance is not None:
+ task.completed += advance
+ if completed is not None:
+ task.completed = completed
+ if description is not None:
+ task.description = description
+ if visible is not None:
+ task.visible = visible
+ task.fields.update(fields)
+ update_completed = task.completed - completed_start
+
+ current_time = self.get_time()
+ old_sample_time = current_time - self.speed_estimate_period
+ _progress = task._progress
+
+ popleft = _progress.popleft
+ # 这里修改为至少保留一个,防止超长时间的迭代影响判断
+ while len(_progress)>1 and _progress[0].timestamp < old_sample_time:
+ popleft()
+ if update_completed > 0:
+ _progress.append(ProgressSample(current_time, update_completed))
+ if task.completed >= task.total and task.finished_time is None:
+ task.finished_time = task.elapsed
+
+ if refresh:
+ self.refresh()
+
+ @property
+ def dummy(self) -> bool:
+ """
+ 当前对象是否是 dummy 的 rich 对象。
+
+ :return:
+ """
+ return False
+
+ def not_empty(self):
+ return len(self._tasks) != 0
+
+
+class SpeedColumn(ProgressColumn):
+ """
+ 显示 task 的速度。
+
+ """
+ def render(self, task: "Task"):
+ speed = task.speed
+ if speed is None:
+ return Text('-- it./s', style='progress.data.speed')
+ if speed > 0.1:
+ return Text(str(round(speed, 2))+' it./s', style='progress.data.speed')
+ else:
+ return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed')
+
+
+if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and \
+ get_global_rank() == 0:
+ # TODO 是不是应该可以手动关掉,防止一些 debug 问题
+ f_rich_progress = FRichProgress().new_progess(
+ "[progress.description]{task.description}",
+ "[progress.percentage]{task.percentage:>3.0f}%",
+ BarColumn(),
+ SpeedColumn(),
+ TimeElapsedColumn(),
+ "/",
+ TimeRemainingColumn(),
+ TextColumn("{task.fields[post_desc]}", justify="right"),
+ transient=True,
+ disable=False,
+ speed_estimate_period=30
+ )
+else:
+ f_rich_progress = DummyFRichProgress()
+
+
+if __name__ == '__main__':
+ f = DummyFRichProgress()
+ f.console.print('xxx')
+ f.console.print.print('xxx')
+ # 测试创建
+ import time
+
+ n_steps = 10
+
+ task_id = f_rich_progress.add_task(description='test', total=n_steps)
+ for i in range(n_steps):
+ f_rich_progress.update(task_id, description=f'test:{i}', advance=1, refresh=True)
+ print(f"test:{i}")
+ time.sleep(0.3)
+ f_rich_progress.remove_task(task_id)
+
+ # 测试一下 inner/outer
+ n_steps = 5
+ f_rich_progress.start()
+ outer_task_id = f_rich_progress.add_task(description='Outer:', total=n_steps)
+ inner_task_id = f_rich_progress.add_task(description='Inner:', total=n_steps)
+ for i in range(n_steps):
+ f_rich_progress.reset(inner_task_id, total=n_steps)
+ f_rich_progress.update(outer_task_id, description=f'Outer:{i}', advance=1, refresh=True)
+ for j in range(n_steps):
+ f_rich_progress.update(inner_task_id, description=f'Inner:{j}', advance=1, refresh=True,
+ post_desc='Loss: 0.334332323')
+ print(f"Outer:{i}, Inner:{j}")
+ time.sleep(0.3)
+
+ # 测试一下修改bar
+ f_rich_progress = FRichProgress().new_progess(
+ BarColumn(),
+ "[progress.description]{task.description}",
+ "[progress.percentage]{task.percentage:>3.0f}%",
+ TimeElapsedColumn(),
+ transient=True)
+ n_steps = 10
+ task_id = f_rich_progress.add_task(description='test', total=n_steps)
+ for i in range(n_steps):
+ f_rich_progress.update(task_id, description=f'test:{i}', advance=1)
+ print(f"test:{i}")
+ time.sleep(0.3)
+ f_rich_progress.remove_task(task_id)
+ f_rich_progress.stop()
diff --git a/fastNLP/core/utils/seq_len_to_mask.py b/fastNLP/core/utils/seq_len_to_mask.py
new file mode 100644
index 00000000..c3d0f9ec
--- /dev/null
+++ b/fastNLP/core/utils/seq_len_to_mask.py
@@ -0,0 +1,97 @@
+from typing import Optional
+
+import numpy as np
+from ...envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_ONEFLOW
+from .paddle_utils import paddle_to
+
+
+if _NEED_IMPORT_TORCH:
+ import torch
+
+if _NEED_IMPORT_PADDLE:
+ import paddle
+
+if _NEED_IMPORT_JITTOR:
+ import jittor
+
+if _NEED_IMPORT_ONEFLOW:
+ import oneflow
+
+
+def seq_len_to_mask(seq_len, max_len: Optional[int]=None):
+ r"""
+
+ 将一个表示 ``sequence length`` 的一维数组转换为二维的 ``mask`` ,不包含的位置为 **0**。
+
+ .. code-block::
+
+ >>> seq_len = torch.arange(2, 16)
+ >>> mask = seq_len_to_mask(seq_len)
+ >>> print(mask.size())
+ torch.Size([14, 15])
+ >>> seq_len = np.arange(2, 16)
+ >>> mask = seq_len_to_mask(seq_len)
+ >>> print(mask.shape)
+ (14, 15)
+ >>> seq_len = torch.arange(2, 16)
+ >>> mask = seq_len_to_mask(seq_len, max_len=100)
+ >>> print(mask.size())
+ torch.Size([14, 100])
+
+ :param seq_len: 大小为 ``(B,)`` 的长度序列;
+ :param int max_len: 将长度补齐或截断到 ``max_len``。默认情况(为 ``None``)使用的是 ``seq_len`` 中最长的长度;
+ 但在 :class:`torch.nn.DataParallel` 等分布式的场景下可能不同卡的 ``seq_len`` 会有区别,所以需要传入
+ ``max_len`` 使得 ``mask`` 的补齐或截断到该长度。
+ :return: 大小为 ``(B, max_len)`` 的 ``mask``, 元素类型为 ``bool`` 或 ``uint8``
+ """
+ max_len = int(max_len) if max_len is not None else int(seq_len.max())
+
+ if isinstance(seq_len, np.ndarray):
+ assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+ broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
+ mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
+ return mask
+
+ try: # 尝试是否是 torch
+ if isinstance(seq_len, torch.Tensor):
+ assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+ batch_size = seq_len.shape[0]
+ broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
+ mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+ return mask
+ except NameError as e:
+ pass
+
+ try:
+ if isinstance(seq_len, paddle.Tensor):
+ assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+ batch_size = seq_len.shape[0]
+ broad_cast_seq_len = paddle.arange(max_len).expand((batch_size, -1))
+ broad_cast_seq_len = paddle_to(broad_cast_seq_len, device=seq_len.place)
+ mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+ return mask
+ except NameError as e:
+ pass
+
+ try:
+ if isinstance(seq_len, jittor.Var):
+ assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+ batch_size = seq_len.shape[0]
+ broad_cast_seq_len = jittor.arange(max_len).reshape(1, max_len).expand(batch_size, -1)
+ mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+ return mask
+ except NameError as e:
+ pass
+
+ try:
+ if isinstance(seq_len, oneflow.Tensor):
+ assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+ batch_size = seq_len.shape[0]
+ broad_cast_seq_len = oneflow.arange(max_len).expand(batch_size, -1).to(seq_len)
+ mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+ return mask
+ except NameError as e:
+ pass
+
+ raise TypeError("seq_len_to_mask function only supports numpy.ndarray, torch.Tensor, paddle.Tensor, "
+ f"jittor.Var and oneflow.Tensor, but got {type(seq_len)}")
\ No newline at end of file
diff --git a/fastNLP/core/utils/torch_utils.py b/fastNLP/core/utils/torch_utils.py
new file mode 100644
index 00000000..c58715b8
--- /dev/null
+++ b/fastNLP/core/utils/torch_utils.py
@@ -0,0 +1,79 @@
+from abc import ABC
+from typing import Any, Union, Optional
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8
+DEFAULT_TORCH_GROUP = None
+if _NEED_IMPORT_TORCH:
+ import torch
+ if not _TORCH_GREATER_EQUAL_1_8:
+ DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD
+
+__all__ = [
+ 'torch_move_data_to_device',
+ 'is_torch_module',
+]
+
+from .utils import apply_to_collection
+
+
+class TorchTransferableDataType(ABC):
+ """
+ A custom type for data that can be moved to a torch device via `.to(...)`.
+ Example::
+
+ >>> isinstance(dict, TorchTransferableDataType)
+ False
+ >>> isinstance(torch.rand(2, 3), TorchTransferableDataType)
+ True
+ >>> class CustomObject:
+ ... def __init__(self):
+ ... self.x = torch.rand(2, 2)
+ ... def to(self, device):
+ ... self.x = self.x.to(device)
+ ... return self
+ >>> isinstance(CustomObject(), TorchTransferableDataType)
+ True
+ """
+
+ @classmethod
+ def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
+ if cls is TorchTransferableDataType:
+ to = getattr(subclass, "to", None)
+ return callable(to)
+ return NotImplemented
+
+
+def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None,
+ non_blocking: Optional[bool] = True) -> Any:
+ r"""
+ 在 **pytorch** 中将数据集合 ``batch`` 传输到给定设备。任何定义方法 ``to(device)`` 的对象都将被移动并且集合中的所有其他对象将保持不变;
+
+ :param batch: 需要迁移的数据;
+ :param device: 数据应当迁移到的设备;当该参数的值为 ``None`` 时则不执行任何操作;
+ :param non_blocking: **pytorch** 的数据迁移方法 ``to`` 的参数;
+ :return: 迁移到新设备上的数据集合;
+ """
+ if device is None:
+ return batch
+
+ def batch_to(data: Any) -> Any:
+ kwargs = dict(non_blocking=non_blocking) if isinstance(data, torch.Tensor) else {}
+ data_output = data.to(device, **kwargs)
+ if data_output is not None:
+ return data_output
+ # user wrongly implemented the `TransferableDataType` and forgot to return `self`.
+ return data
+
+ dtype = TorchTransferableDataType
+ return apply_to_collection(batch, dtype=dtype, function=batch_to)
+
+def is_torch_module(model) -> bool:
+ """
+ 判断传入的 ``model`` 是否是 :class:`torch.nn.Module` 类型
+
+ :param model: 模型;
+ :return: 当前模型是否为 ``torch`` 的模型;
+ """
+ try:
+ return isinstance(model, torch.nn.Module)
+ except BaseException:
+ return False
\ No newline at end of file
diff --git a/fastNLP/core/utils/tqdm_progress.py b/fastNLP/core/utils/tqdm_progress.py
new file mode 100644
index 00000000..d6e0f9fb
--- /dev/null
+++ b/fastNLP/core/utils/tqdm_progress.py
@@ -0,0 +1,162 @@
+__all__ = [
+ 'f_tqdm_progress'
+]
+
+import uuid
+import sys
+from ...envs.utils import _module_available, _compare_version, _get_version
+
+from ...envs import get_global_rank
+from .utils import is_notebook
+from ..log import logger
+if _module_available('tqdm'):
+ from tqdm.autonotebook import tqdm
+import operator
+
+
+
+class Singleton(type):
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+ return cls._instances[cls]
+
+
+# 如果不打印的时候,使得整个 progress 没有任何意义
+class DummyFTqdmProgress:
+ def __getattr__(self, item):
+ return DummyFTqdmProgress()
+
+ def __call__(self, *args, **kwargs):
+ # 防止用户通过 DummyFRichProgress.console.print() 这种调用
+ return None
+
+ @property
+ def dummy(self)->bool:
+ """
+ 当前对象是否是 dummy 的 tqdm 对象。
+
+ :return:
+ """
+ return True
+
+
+class TqdmProgress(metaclass=Singleton):
+ def __init__(self):
+ self.bars = {}
+
+ def add_task(self, description=None, total=None, leave=False,
+ ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None,
+ ascii=None, visible=True, unit='it', unit_scale=False,
+ dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0,
+ postfix=None, unit_divisor=1000, write_bytes=None,
+ lock_args=None, nrows=None, colour=None, gui=False, **kwargs):
+ """
+ 主要就模仿了 tqdm bar 的创建,为了和 FRichProgress 的接口尽量统一,将 desc 重命名为了 description,以及 disable 转为了
+ visible 。
+
+ :param description:
+ :param total:
+ :param leave:
+ :param ncols:
+ :param mininterval:
+ :param maxinterval:
+ :param miniters:
+ :param ascii:
+ :param visible:
+ :param unit:
+ :param unit_scale:
+ :param dynamic_ncols:
+ :param smoothing:
+ :param bar_format:
+ :param initial:
+ :param postfix:
+ :param unit_divisor:
+ :param write_bytes:
+ :param lock_args:
+ :param nrows:
+ :param colour:
+ :param gui:
+ :param kwargs:
+ :return:
+ """
+ if not _module_available('tqdm'):
+ raise ModuleNotFoundError("Package tqdm is not installed.")
+ elif not _compare_version('tqdm', operator.ge, '4.57'):
+ raise RuntimeError(f"Package tqdm>=4.57 is needed, instead of {_get_version('tqdm')}.")
+
+ from .rich_progress import f_rich_progress
+ assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop."
+
+ if hasattr(self, 'orig_out_err'):
+ file = self.orig_out_err[0]
+ else:
+ file = sys.stdout
+
+ bar = tqdm(iterable=None, desc=description, total=total, leave=leave, file=file,
+ ncols=ncols, mininterval=mininterval, maxinterval=maxinterval, miniters=miniters,
+ ascii=ascii, disable=not visible, unit=unit, unit_scale=unit_scale,
+ dynamic_ncols=dynamic_ncols, smoothing=smoothing, bar_format=bar_format, initial=initial,
+ position=len(self.bars), postfix=postfix, unit_divisor=unit_divisor, write_bytes=write_bytes,
+ lock_args=lock_args, nrows=nrows, colour=colour, gui=gui, **kwargs)
+ _uuid = str(uuid.uuid1())
+ self.bars[_uuid] = bar
+ if not hasattr(self, 'orig_out_err') and not is_notebook():
+ from tqdm.contrib import DummyTqdmFile
+ self.orig_out_err = sys.stdout, sys.stderr
+ sys.stdout, sys.stderr = map(DummyTqdmFile, self.orig_out_err)
+
+ return _uuid
+
+ def update(self, task_id:str, advance:int, refresh=True):
+ self.bars[task_id].update(advance)
+
+ def set_postfix_str(self, task_id, s, refresh=True):
+ self.bars[task_id].set_postfix_str(s=s, refresh=refresh)
+
+ def set_description_str(self, task_id, desc, refresh=True):
+ self.bars[task_id].set_description_str(desc=desc, refresh=refresh)
+
+ def destroy_task(self, task_id):
+ """
+ 关闭 task_id 对应的 tqdm bar 。
+
+ :param task_id:
+ :return:
+ """
+ self.bars[task_id].close()
+ self.bars.pop(task_id)
+ if len(self.bars) == 0 and hasattr(self, 'orig_out_err'):
+ # recover 成正常的 sys.stdout 与 sys.stderr
+ sys.stdout, sys.stderr = self.orig_out_err
+ delattr(self, 'orig_out_err')
+
+ def reset(self, task_id):
+ self.bars[task_id].reset()
+
+ def print(self):
+ tqdm.write('')
+
+ def not_empty(self):
+ return len(self.bars) != 0
+
+ @property
+ def dummy(self) -> bool:
+ """
+ 当前对象是否是 dummy 的 tqdm 对象。
+
+ :return:
+ """
+ return False
+
+
+if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and get_global_rank() == 0:
+ f_tqdm_progress = TqdmProgress()
+else:
+ f_tqdm_progress = DummyFTqdmProgress()
+ logger.debug("Use dummy tqdm...")
+
+
+
diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py
new file mode 100644
index 00000000..6ca07e0d
--- /dev/null
+++ b/fastNLP/core/utils/utils.py
@@ -0,0 +1,679 @@
+import functools
+import inspect
+from inspect import Parameter
+import dataclasses
+from dataclasses import is_dataclass
+from copy import deepcopy
+from collections import defaultdict, OrderedDict
+from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence
+from typing import Tuple, Optional
+from time import sleep
+
+import os
+from contextlib import contextmanager
+from functools import wraps
+from prettytable import PrettyTable
+from pathlib import Path
+
+from fastNLP.core.log import logger
+
+
+__all__ = [
+ 'get_fn_arg_names',
+ 'auto_param_call',
+ 'check_user_specific_params',
+ 'dataclass_to_dict',
+ 'match_and_substitute_params',
+ 'apply_to_collection',
+ 'nullcontext',
+ 'pretty_table_printer',
+ 'Option',
+ 'deprecated',
+ "flat_nest_dict"
+]
+
+
def get_fn_arg_names(fn: Callable) -> List[str]:
    r"""
    Return the names of all parameters of ``fn``::

        >>> def function(a, b=1):
        ...     return a
        ...
        >>> get_fn_arg_names(function)
        ['a', 'b']

    :param fn: the function to inspect;
    :return: list of the parameter names of ``fn``;
    """
    signature = inspect.signature(fn)
    return [name for name in signature.parameters]
+
+
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
                    mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
    r"""
    Call ``fn`` by matching its parameter names against the keys of the dicts
    passed in ``*args``. When the keys of the inputs do not match ``fn``'s
    parameter names, ``mapping`` can translate them: a pair ``(key, value)``
    in ``mapping`` means "the entry found under ``key`` in ``*args`` is passed
    to the parameter named ``value``".

    1. This enables user-facing, string-matching based automatic invocation;
    2. ``mapping`` defaults to ``None``; when it is given, the input keys are
       renamed through it *before* matching, so its correctness is the
       caller's responsibility;
    3. A parameter with a default value uses the default only when no input
       provides a value for it;
    4. A ``partial`` function behaves like case 3: pre-bound arguments act as
       defaults that later inputs may override;

    Examples::

        >>> # 1
        >>> loss_fn = CrossEntropyLoss()  # say its signature is def CrossEntropyLoss(y, pred)
        >>> batch = {"x": 20, "y": 1}
        >>> output = {"pred": 0}
        >>> acc = auto_param_call(loss_fn, batch, output)

        >>> # 2
        >>> def test_fn(x, y, a, b=10):
        >>>     return x + y + a + b
        >>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30}))  # res: 70
        >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20}))  # res: 140
        >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200}))  # res: 240

    :param fn: the function that performs the actual computation; it may have
        parameters with default values;
    :param args: a sequence of dicts from which the actual arguments of ``fn``
        are extracted;
    :param signature_fn: if not ``None``, the signature of this function is
        used (instead of ``fn``'s) to pick the argument values that are then
        passed to ``fn``;
    :param mapping: a dict used to rename the keys of the preceding dicts;

    :return: the result of running ``fn``;
    """

    if signature_fn is not None:
        if not callable(signature_fn):
            raise ValueError(f"Parameter `signature_fn` should be `Callable`.")
        _need_params = OrderedDict(inspect.signature(signature_fn).parameters)
    else:
        _need_params = OrderedDict(inspect.signature(fn).parameters)
    _kwargs = None
    for _name, _param in _need_params.items():
        if _param.kind == Parameter.VAR_POSITIONAL:
            # *args cannot be filled by name matching, so it is rejected outright
            fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
            raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.")
        if _param.kind == Parameter.VAR_KEYWORD:
            _kwargs = (_name, _param)

    if _kwargs is not None:
        # **kwargs itself is not a named parameter; its presence means every
        # input key is acceptable (see the matching loop below)
        _need_params.pop(_kwargs[0])

    _default_params = {}
    for _name, _param in _need_params.items():
        if _param.default != Parameter.empty:
            _default_params[_name] = _param.default

    if mapping is not None:
        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
        assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \
            f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."

    _has_params = {}
    duplicate_names = []
    for arg in args:
        if not isinstance(arg, Dict):
            fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
            raise TypeError(f"Exception happens when calling {fn_msg}. "
                            f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.")
        for _name, _value in arg.items():
            if mapping is not None and _name in mapping:
                _name = mapping[_name]

            if _name not in _has_params:
                if _kwargs is not None or _name in _need_params:
                    _has_params[_name] = _value
            # the same parameter appears in several of the input dicts (with a
            # different value object), which would be ambiguous
            elif _name in _need_params and not (_has_params[_name] is _value):
                duplicate_names.append(_name)
    if duplicate_names:
        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
        raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.")

    # fill in the defaults for parameters the inputs did not override
    for _name, _value in _default_params.items():
        if _name not in _has_params:
            _has_params[_name] = _value

    if len(_has_params) < len(_need_params):
        miss_params = list(set(_need_params.keys()) - set(_has_params.keys()))
        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
        _provided_keys = _get_keys(args)
        raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn_msg} "
                         f"are not found in the input keys({_provided_keys}).")

    return fn(**_has_params)
+
+
+def _get_keys(args:List[Dict]) -> List[List[str]]:
+ """
+ 返回每个 dict 的 keys
+
+ :param args:
+ :return:
+ """
+ _provided_keys = []
+ for arg in args:
+ _provided_keys.append(list(arg.keys()))
+ return _provided_keys
+
+
+def _get_fun_msg(fn, with_fp=True)->str:
+ """
+ 获取函数的基本信息,帮助报错::
+
+ >>>> print(_get_fun_msg(_get_fun_msg))
+ `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py)
+
+ :param callable fn:
+ :param with_fp: 是否包含函数所在的文件信息;
+ :return:
+ """
+ if isinstance(fn, functools.partial):
+ return _get_fun_msg(fn.func)
+ try:
+ fn_name = fn.__qualname__ + str(inspect.signature(fn))
+ except:
+ fn_name = str(fn)
+ if with_fp:
+ try:
+ fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')'
+ except:
+ fp = ''
+ else:
+ fp = ''
+ msg = f'`{fn_name}`' + fp
+ return msg
+
+
+def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
+ """
+ 检查一个函数是否需要 expected_params 参数(检测数量是否匹配)。除掉 self (如果是method),给定默认值的参数等。
+ 如果匹配不上,就会进行报错。
+
+ :param fn: 需要检测的函数,可以是 method 或者 function 。
+ :param expected_params: 期待应该支持的参数。
+ :param fn_name: fn 的名字,当传入的 fn 不是 callable 的时候方便报错。
+ :return:
+ """
+ if fn_name is not None:
+ assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`."
+
+ try:
+ args = []
+ kwargs = {}
+ name = ''
+ if isinstance(fn, functools.partial) and not hasattr(fn, '__name__'):
+ name = 'partial:'
+ f = fn.func
+ while isinstance(f, functools.partial):
+ name += 'partial:'
+ f = f.func
+ fn.__name__ = name + f.__name__
+ inspect.getcallargs(fn, *args, *expected_params, **kwargs)
+ if name: # 如果一开始没有name的,需要给人家删除掉
+ delattr(fn, '__name__')
+
+ except TypeError as e:
+ logger.error(f"The function:{_get_fun_msg(fn)} will be provided with parameters:{expected_params}. "
+ f"The following exception will happen.")
+ raise e
+
+
def check_user_specific_params(user_params: Dict, fn: Callable, fn_name=None):
    """
    Validate a dict of user-supplied arguments against the signature of
    ``fn``: every key that does not match one of ``fn``'s parameter names
    triggers a rank-zero warning. Intended for situations where the user
    cannot call ``fn`` directly and passes its arguments through a dict.

    :param user_params: mapping from parameter name to the value the user
        wants to pass;
    :param fn: the function that will eventually receive these parameters;
    :param fn_name: the name shown in the warning; defaults to ``fn.__name__``;
    :return: ``user_params``, unchanged;
    """
    display_name = fn.__name__ if fn_name is None else fn_name
    valid_names = get_fn_arg_names(fn)
    for arg_name in (name for name in user_params if name not in valid_names):
        logger.rank_zero_warning(f"Notice parameter `{arg_name}` may not be used by `{display_name}`.")
    return user_params
+
+
def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict:
    """
    Convert a ``dataclass`` instance into a plain dict of its fields.

    :param data: the dataclass instance to convert;
    :return: mapping from field name to the corresponding attribute value;
    :raises TypeError: when ``data`` is not a dataclass;
    """
    if not is_dataclass(data):
        raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.")
    return {field_name: getattr(data, field_name) for field_name in data.__dataclass_fields__}
+
+
def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any:
    r"""
    Rename the keys of an input **batch** or an output **outputs** object
    according to ``mapping``; this implements ``input_mapping`` and
    ``output_mapping``.

    * for ``input_mapping`` it is applied right after a batch is fetched in
      :class:`~fastNLP.core.controllers.TrainBatchLoop`;
    * for ``output_mapping`` it is applied right after the result of
      :meth:`~fastNLP.core.Trainer.train_step` (and the corresponding
      :class:`~fastNLP.core.Evaluator` step) is obtained;

    Resolution order:

    1. if ``mapping`` is callable, return ``mapping(data)`` directly;
    2. if ``mapping`` is a dict, ``data`` must be one of ``[Dict, dataclass,
       Sequence]``:

       * a **Dict** has each key ``k`` replaced by ``mapping[k]`` (keys absent
         from ``mapping`` are kept unchanged);
       * a **dataclass** is first converted to a dict of its fields, then
         handled as above;
       * a **Sequence** is first converted to ``{"_0": data[0], "_1": data[1],
         ...}``, then handled as above (unmatched keys keep the ``'_number'``
         form);

    :param mapping: the dict or function used for the conversion; a function
        must itself return a dict;
    :param data: the object to convert;
    :return: the converted result;
    """
    if mapping is None:
        return data
    if callable(mapping):
        # `Trainer.extract_loss_from_outputs` later validates the returned
        # type (only `Dict` and `dataclass` are currently supported there).
        return mapping(data)

    if not isinstance(mapping, Dict):
        raise ValueError(
            f"Parameter `mapping` should be of type `Dict` or `Callable`, not `{type(mapping)}`. This is caused"
            f"by your `input_mapping` or `output_mapping` parameter in your `Trainer` or `Evaluator`.")
    if not isinstance(data, Dict) and not is_dataclass(data) and not isinstance(data, Sequence):
        raise ValueError("Parameter `data` should be type `Dict` or `dataclass` when the other parameter `mapping` is "
                         "type `Dict`.")

    # normalize `data` to a plain dict before renaming its keys
    if is_dataclass(data):
        data = dataclass_to_dict(data)
    elif isinstance(data, Sequence):
        data = {"_" + str(position): item for position, item in enumerate(data)}

    return {mapping.get(key, key): value for key, value in data.items()}
+
+
+def _is_namedtuple(obj: object) -> bool:
+ # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
+ return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
+
+
+def _is_dataclass_instance(obj: object) -> bool:
+ # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
+ return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
+
+
def apply_to_collection(
        data: Any,
        dtype: Union[type, Any, Tuple[Union[type, Any]]],
        function: Callable,
        *args: Any,
        wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
        include_none: bool = True,
        **kwargs: Any,
) -> Any:
    """
    Recursively apply ``function`` to every element of ``data`` whose type
    matches ``dtype``, rebuilding the same container structure around the
    results.

    Implementation adapted from `pytorch-lightning`.

    :param data: the (possibly nested) collection, or a single value;
    :param dtype: only elements that are instances of ``dtype`` are transformed;
    :param function: the transformation applied to matching elements;
    :param args: extra positional arguments forwarded to ``function``;
    :param wrong_dtype: types for which ``function`` must NOT be applied even
        if they also match ``dtype``;
    :param include_none: whether to keep entries whose transformed value is
        ``None``; when ``False`` such entries are dropped from the rebuilt
        container;
    :param kwargs: extra keyword arguments forwarded to ``function``;
    :return: a collection of the same structure with matching elements
        transformed;
    """
    # Breaking condition
    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
        return function(data, *args, **kwargs)

    elem_type = type(data)

    # Recursively apply to collection items
    if isinstance(data, Mapping):
        out = []
        for k, v in data.items():
            v = apply_to_collection(
                v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
            )
            if include_none or v is not None:
                out.append((k, v))
        if isinstance(data, defaultdict):
            # a defaultdict cannot be rebuilt from items alone; keep its factory
            return elem_type(data.default_factory, OrderedDict(out))
        return elem_type(OrderedDict(out))

    is_namedtuple = _is_namedtuple(data)
    is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
    if is_namedtuple or is_sequence:
        out = []
        for d in data:
            v = apply_to_collection(
                d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
            )
            if include_none or v is not None:
                out.append(v)
        # a namedtuple is reconstructed from positional args, a plain sequence
        # from a single iterable
        return elem_type(*out) if is_namedtuple else elem_type(out)

    if _is_dataclass_instance(data):
        # make a deepcopy of the data,
        # but do not deepcopy mapped fields since the computation would
        # be wasted on values that likely get immediately overwritten
        fields = {}
        memo = {}
        for field in dataclasses.fields(data):
            field_value = getattr(data, field.name)
            fields[field.name] = (field_value, field.init)
            memo[id(field_value)] = field_value
        result = deepcopy(data, memo=memo)
        # apply function to each field
        for field_name, (field_value, field_init) in fields.items():
            if field_init:
                v = apply_to_collection(
                    field_value,
                    dtype,
                    function,
                    *args,
                    wrong_dtype=wrong_dtype,
                    include_none=include_none,
                    **kwargs,
                )
            if not field_init or (not include_none and v is None):  # retain old value
                v = getattr(data, field_name)
            setattr(result, field_name, v)
        return result

    # data is neither of dtype, nor a collection
    return data
+
+
@contextmanager
def nullcontext():
    r"""
    A context manager that does nothing on enter or exit.
    """
    yield None
+
+
def sub_column(string: str, c: int, c_size: int, title: str) -> str:
    r"""
    Truncate ``string`` so that a table column fits on a terminal line.

    A character whose code point is above 255 (e.g. CJK) counts as two cells.
    Once the accumulated width exceeds the per-column budget
    ``max(int(c / c_size / 2), len(title))`` the text is cut and ``"..."`` is
    appended.

    :param string: the cell content to shorten;
    :param c: total number of terminal columns;
    :param c_size: the number of fields of the DataSet / Instance being shown;
    :param title: the column header (a column is never narrower than it);
    :return: the possibly truncated cell text;
    """
    budget = max(int(c / c_size / 2), len(title))
    text = str(string)
    pieces = []
    width = 0
    for ch in text:
        width += 2 if ord(ch) > 255 else 1
        pieces.append(ch)
        if width > budget:
            pieces.append("...")
            break
    return "".join(pieces)
+
+
+def _is_iterable(value):
+ # 检查是否是iterable的, duck typing
+ try:
+ iter(value)
+ return True
+ except BaseException as e:
+ return False
+
+
def pretty_table_printer(dataset_or_ins) -> PrettyTable:
    r"""
    Render a DataSet or an Instance as a table for terminal display::

        >>> ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
        +-----------+-----------+-----------------+
        |  field_1  |  field_2  |     field_3     |
        +-----------+-----------+-----------------+
        | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
        +-----------+-----------+-----------------+

    :param dataset_or_ins: the :class:`~fastNLP.core.DataSet` or
        :class:`~fastNLP.core.Instance` instance to display;
    :return: a table whose cells are truncated to fit the terminal size;
    """
    x = PrettyTable()
    try:
        sz = os.get_terminal_size()
        column = sz.columns
        row = sz.lines
    except OSError:
        # not attached to a terminal (e.g. piped output); use fixed defaults
        column = 144
        row = 11

    # dispatch on the type NAME rather than isinstance -- presumably to avoid
    # importing DataSet/Instance here; TODO confirm (circular import?)
    if type(dataset_or_ins).__name__ == "DataSet":
        x.field_names = list(dataset_or_ins.field_arrays.keys())
        c_size = len(x.field_names)
        for ins in dataset_or_ins:
            x.add_row([sub_column(ins[k], column, c_size, k) for k in x.field_names])
            row -= 1
            if row < 0:
                # more rows than fit on screen: emit an ellipsis row and stop
                x.add_row(["..." for _ in range(c_size)])
                break
    elif type(dataset_or_ins).__name__ == "Instance":
        x.field_names = list(dataset_or_ins.fields.keys())
        c_size = len(x.field_names)
        x.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in x.field_names])

    else:
        raise Exception("only accept DataSet and Instance")
    x.align = "l"

    return x
+
+
class Option(dict):
    r"""A ``dict`` subclass whose keys are also reachable as attributes."""

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(item)

    def __setattr__(self, key, value):
        # refuse dunder names so internals such as __dict__ stay untouched
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        self[key] = value

    def __delattr__(self, item):
        try:
            self.pop(item)
        except KeyError:
            raise AttributeError(item)

    def __getstate__(self):
        # pickle the mapping content itself
        return self

    def __setstate__(self, state):
        self.update(state)
+
+
# Hashes of functions whose deprecation warning was already emitted, so each
# deprecated function warns at most once per process.
_emitted_deprecation_warnings = set()
+
+
def deprecated(help_message: Optional[str] = None):
    """
    Decorator marking a function as deprecated; logs a warning the first time
    the decorated function is called (once per process).

    :param help_message: guidance telling the user how to migrate to the
        currently recommended usage;
    """

    def decorator(deprecated_function: Callable):
        global _emitted_deprecation_warnings
        # Fix: the previous conditional-expression grouping evaluated to ""
        # whenever `help_message` was falsy, losing the whole deprecation
        # message; always keep the base sentence and append the optional help.
        # NOTE(review): the text says "datasets" (copied from upstream) --
        # probably should read "fastNLP"; confirm before changing user-facing text.
        warning_msg = (
            f"{deprecated_function.__name__} is deprecated and will be removed "
            "in the next major version of datasets."
        )
        if help_message:
            warning_msg += f" {help_message}"

        @wraps(deprecated_function)
        def wrapper(*args, **kwargs):
            func_hash = hash(deprecated_function)
            if func_hash not in _emitted_deprecation_warnings:
                logger.warning(warning_msg, category=FutureWarning, stacklevel=2)
                _emitted_deprecation_warnings.add(func_hash)
            return deprecated_function(*args, **kwargs)

        wrapper._decorator_name_ = "deprecated"
        return wrapper

    return decorator
+
+
def wait_filepath(path, exist=True):
    """
    Block until the existence of ``path`` equals ``exist``.

    Polls every 10 ms and logs a warning roughly every 10 s of waiting.

    :param path: the path (``str`` or ``Path``) to watch;
    :param exist: ``True`` waits until the path exists; ``False`` waits until
        it no longer exists;
    :return:
    """
    target = Path(path) if isinstance(path, str) else path
    assert isinstance(target, Path)
    count = 0
    while True:
        sleep(0.01)
        if target.exists() == exist:
            return
        count += 1
        if count % 1000 == 0:
            msg = 'create' if exist else 'delete'
            logger.warning(f"Waiting path:{target} to {msg} for {count*0.01} seconds...")
+
+
def get_class_that_defined_method(method):
    """
    Given a (possibly partial-wrapped) method, return the class object that
    defines it, or ``None`` when it cannot be determined.

    :param method: the method to look up;
    :return: the defining class, or ``None``;
    """
    if isinstance(method, functools.partial):
        # unwrap partials and retry on the underlying callable
        return get_class_that_defined_method(method.func)
    if inspect.ismethod(method) or (inspect.isbuiltin(method) and getattr(method, '__self__', None) is not None and getattr(method.__self__, '__class__', None)):
        # bound (possibly builtin) method: walk the MRO of the instance's
        # class to find the class whose __dict__ actually holds the attribute
        for cls in inspect.getmro(method.__self__.__class__):
            if method.__name__ in cls.__dict__:
                return cls
        method = getattr(method, '__func__', method)  # fallback to __qualname__ parsing
    if inspect.isfunction(method):
        # plain function: resolve "Outer.method" via its defining module
        cls = getattr(inspect.getmodule(method),
                      method.__qualname__.split('.', 1)[0].rsplit('.', 1)[0],
                      None)
        if isinstance(cls, type):
            return cls
    return getattr(method, '__objclass__', None)  # handle special descriptor objects
+
+
def is_notebook():
    """
    Return whether the current interpreter runs inside a Jupyter notebook
    kernel (VSCode's notebook frontend is deliberately treated as *not* a
    notebook).

    :return: ``True`` inside a Jupyter notebook, ``False`` otherwise;
    """
    try:
        from IPython import get_ipython

        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
            raise ImportError("console")
        if "VSCODE_PID" in os.environ:  # pragma: no cover
            raise ImportError("vscode")
    except Exception:
        # Fix: a bare `except:` also swallowed KeyboardInterrupt/SystemExit;
        # Exception still covers the deliberate ImportError above and the
        # AttributeError raised when get_ipython() returns None outside IPython.
        return False
    else:  # pragma: no cover
        return True
+
+
def flat_nest_dict(d: Dict, separator: str = '#', compress_none_key: bool = True, top_down: bool = False) -> Dict:
    """
    Flatten a nested dict into a single-level dict, e.g.
    ex::
        d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1}

    :param d: the nested dict to flatten.
    :param separator: string joining the keys of the different nesting levels.
    :param compress_none_key: when a key is ``None``, skip that level in the
        joined key instead of emitting the string "None".
    :param top_down: when ``True`` the outermost key comes first in the joined
        key; by default the innermost key comes first (as in the example).
    :return:
    """
    assert isinstance(d, Dict)
    assert isinstance(separator, str)

    tuple_keyed = {}
    for key, value in d.items():
        prefix = () if key is None else (key,)
        if isinstance(value, Mapping):
            tuple_keyed.update(_flat_nest_dict(value, parent_key=prefix, compress_none_key=compress_none_key))
        else:
            tuple_keyed[prefix] = value

    flat = {}
    for key_tuple, value in tuple_keyed.items():
        ordered = key_tuple if top_down else key_tuple[::-1]
        flat[separator.join(map(str, ordered))] = value
    return flat
+
+
+def _flat_nest_dict(d:Mapping, parent_key:Tuple, compress_none_key:bool):
+ flat_d = {}
+ for k, v in d.items():
+ _key = parent_key
+ if k is not None:
+ _key = _key + (k,)
+ if isinstance(v, Mapping):
+ _d = _flat_nest_dict(v, parent_key=_key, compress_none_key=compress_none_key)
+ flat_d.update(_d)
+ else:
+ flat_d[_key] = v
+
+ return flat_d
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index aef99034..3a6ab650 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -11,15 +11,19 @@ __all__ = [
from collections import Counter
from functools import partial
from functools import wraps
+from typing import List, Callable, Union
-from ._logger import logger
-from .dataset import DataSet
-from .utils import Option
-from .utils import _is_iterable
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.utils.utils import Option
+from fastNLP.core.utils.utils import _is_iterable
+from .log import logger
import io
class VocabularyOption(Option):
+ """
+
+ """
def __init__(self,
max_size=None,
min_freq=None,
@@ -33,8 +37,11 @@ class VocabularyOption(Option):
)
-def _check_build_vocab(func):
- r"""A decorator to make sure the indexing is built before used.
+def _check_build_vocab(func: Callable):
+ r"""
+ A decorator to make sure the indexing is built before used.
+
+ :param func: 传入的callable函数
"""
@@ -48,7 +55,10 @@ def _check_build_vocab(func):
def _check_build_status(func):
- r"""A decorator to check whether the vocabulary updates after the last build.
+ r"""
+ A decorator to check whether the vocabulary updates after the last build.
+
+ :param func: 用户传入要修饰的callable函数
"""
@@ -57,7 +67,7 @@ def _check_build_status(func):
if self.rebuild is False:
self.rebuild = True
if self.max_size is not None and len(self.word_count) >= self.max_size:
- logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
+ logger.warning("Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__))
return func(self, *args, **kwargs)
@@ -69,28 +79,26 @@ class Vocabulary(object):
r"""
用于构建, 存储和使用 `str` 到 `int` 的一一映射::
+ from fastNLP.core import Vocabulary
vocab = Vocabulary()
word_list = "this is a word list".split()
+ # vocab更新自己的字典,输入为list列表
vocab.update(word_list)
vocab["word"] # str to int
vocab.to_word(5) # int to str
+
+ :param max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量
+ 若为 ``None`` , 则不限制大小。
+ :param min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1。
+ 若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录。
+ :param padding: padding的字符. 如果设置为 ``None`` ,
+ 则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为 label 建立 Vocabulary 的情况。
+ :param unknown: unknown的字符,所有未被记录的词在转为 :class:`int` 时将被视为 `unknown` 。
+ 如果设置为 ``None`` ,则 vocabulary 中不考虑 `unknown`, 也不计入词表大小。
+ 为 ``None`` 的情况多在为 label 建立 Vocabulary 的情况
"""
-
- def __init__(self, max_size=None, min_freq=None, padding='', unknown=''):
- r"""
-
- :param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量
- 若为 ``None`` , 则不限制大小. Default: ``None``
- :param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1.
- 若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None``
- :param str optional padding: padding的字符. 如果设置为 ``None`` ,
- 则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况.
- Default: ''
- :param str optional unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown.
- 如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小.
- 为 ``None`` 的情况多在为label建立Vocabulary的情况.
- Default: ''
- """
+
+ def __init__(self, max_size:int=None, min_freq:int=None, padding:str='', unknown:str=''):
self.max_size = max_size
self.min_freq = min_freq
self.word_count = Counter()
@@ -121,45 +129,60 @@ class Vocabulary(object):
self._word2idx = value
@_check_build_status
- def update(self, word_lst, no_create_entry=False):
- r"""依次增加序列中词在词典中的出现频率
-
- :param list word_lst: a list of strings
- :param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。
- 如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独
- 的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新
- 加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这
- 个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的,
- 则这个词将认为是需要创建单独的vector的。
+ def update(self, word_lst: list, no_create_entry:bool=False):
+ r"""
+ 依次增加序列中词在词典中的出现频率
+
+ :param word_lst: 列表形式的词语,如 word_list=['I', 'am', 'a', 'Chinese'],列表中的每个词会计算出现频率并加入到词典中。
+ :param no_create_entry: 如果词语来自于非训练集建议设置为 ``True`` 。
+
+ * 如果为 ``True`` -- 则不会有这个词语创建一个单独的 entry ,它将一直被指向 ```` 的表示;
+ * 如果为 ``False`` -- 为这个词创建一个单独的 entry。如果这个词来自于验证集或测试集,一般设置为True,如果来自于训练集一
+ 般设置为``False``;
+
+ 有以下两种情况: 如果新加入一个 word ,且 ``no_create_entry`` 为 ``True``,但这个词之前已经在 Vocabulary 中且并不是
+ ``no_create_entry`` 的,则还是会为这个词创建一个单独的 vector ; 如果 ``no_create_entry`` 为 ``False`` ,但这个词之
+ 前已经在 Vocabulary 中且并不是 ``no_create_entry的`` ,则这个词将认为是需要创建单独的 vector 的。
+
"""
self._add_no_create_entry(word_lst, no_create_entry)
self.word_count.update(word_lst)
return self
@_check_build_status
- def add(self, word, no_create_entry=False):
+ def add(self, word:str, no_create_entry:bool=False):
r"""
增加一个新词在词典中的出现频率
- :param str word: 新词
- :param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。
- 如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独
- 的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新
- 加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这
- 个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的,
- 则这个词将认为是需要创建单独的vector的。
+ :param word: 要添加进字典的新词, ``word`` 为一个字符串
+ :param no_create_entry: 如果词语来自于非训练集建议设置为 ``True`` 。
+
+ * 如果为 ``True`` -- 则不会有这个词语创建一个单独的 entry ,它将一直被指向 ```` 的表示;
+ * 如果为 ``False`` -- 为这个词创建一个单独的 entry。如果这个词来自于验证集或测试集,一般设置为 ``True`` ,如果来自于训练集一
+ 般设置为 ``False``;
+
+ 有以下两种情况: 如果新加入一个 word ,且 ``no_create_entry`` 为 ``True``,但这个词之前已经在 Vocabulary 中且并不是
+ ``no_create_entry`` 的,则还是会为这个词创建一个单独的 vector ; 如果 ``no_create_entry`` 为 ``False`` ,但这个词之
+ 前已经在 Vocabulary 中且并不是 ``no_create_entry的`` ,则这个词将认为是需要创建单独的 vector 的。
+
"""
self._add_no_create_entry(word, no_create_entry)
self.word_count[word] += 1
return self
- def _add_no_create_entry(self, word, no_create_entry):
+ def _add_no_create_entry(self, word:Union[str, List[str]], no_create_entry:bool):
r"""
在新加入word时,检查_no_create_word的设置。
- :param str List[str] word:
- :param bool no_create_entry:
+ :param word: 要添加的新词或者是 :class:`List`类型的新词,如 word='I' 或者 word=['I', 'am', 'a', 'Chinese'] 均可
+ :param no_create_entry: 如果词语来自于非训练集建议设置为 ``True`` 。
+
+ * 如果为 ``True`` -- 则不会有这个词语创建一个单独的 entry ,它将一直被指向 ```` 的表示;
+ * 如果为 ``False`` -- 为这个词创建一个单独的 entry。如果这个词来自于验证集或测试集,一般设置为 ``True`` ,如果来自于训练集一
+ 般设置为 ``False``;
+
:return:
+
"""
if isinstance(word, str) or not _is_iterable(word):
word = [word]
@@ -170,41 +193,48 @@ class Vocabulary(object):
self._no_create_word.pop(w)
@_check_build_status
- def add_word(self, word, no_create_entry=False):
+ def add_word(self, word:str, no_create_entry:bool=False):
r"""
增加一个新词在词典中的出现频率
- :param str word: 新词
- :param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。
- 如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独
- 的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新
- 加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这
- 个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的,
- 则这个词将认为是需要创建单独的vector的。
+ :param word: 要添加进字典的新词, ``word`` 为一个字符串
+ :param no_create_entry: 如果词语来自于非训练集建议设置为 ``True`` 。
+
+ * 如果为 ``True`` -- 则不会有这个词语创建一个单独的 entry ,它将一直被指向 ```` 的表示;
+ * 如果为 ``False`` -- 为这个词创建一个单独的 entry。如果这个词来自于验证集或测试集,一般设置为 ``True`` ,如果来自于训练集一
+ 般设置为 ``False``;
+
+ 有以下两种情况: 如果新加入一个 word ,且 ``no_create_entry`` 为 ``True``,但这个词之前已经在 Vocabulary 中且并不是
+ ``no_create_entry`` 的,则还是会为这个词创建一个单独的 vector ; 如果 ``no_create_entry`` 为 ``False`` ,但这个词之
+ 前已经在 Vocabulary 中且并不是 ``no_create_entry的`` ,则这个词将认为是需要创建单独的 vector 的。
+
"""
self.add(word, no_create_entry=no_create_entry)
@_check_build_status
- def add_word_lst(self, word_lst, no_create_entry=False):
+ def add_word_lst(self, word_lst: List[str], no_create_entry:bool=False):
r"""
依次增加序列中词在词典中的出现频率
- :param list[str] word_lst: 词的序列
- :param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。
- 如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独
- 的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新
- 加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这
- 个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的,
- 则这个词将认为是需要创建单独的vector的。
+ :param word_lst: 需要添加的新词的 list 序列,如 word_lst=['I', 'am', 'a', 'Chinese'] 。
+ :param no_create_entry: 如果词语来自于非训练集建议设置为 ``True`` 。
+
+ * 如果为 ``True`` -- 则不会有这个词语创建一个单独的 entry ,它将一直被指向 ```` 的表示;
+ * 如果为 ``False`` -- 为这个词创建一个单独的 entry。如果这个词来自于验证集或测试集,一般设置为 ``True`` ,如果来自于训练集一
+ 般设置为 ``False``;
+
+ 有以下两种情况: 如果新加入一个 word ,且 ``no_create_entry`` 为 ``True``,但这个词之前已经在 Vocabulary 中且并不是
+ ``no_create_entry`` 的,则还是会为这个词创建一个单独的 vector ; 如果 ``no_create_entry`` 为 ``False`` ,但这个词之
+ 前已经在 Vocabulary 中且并不是 ``no_create_entry的`` ,则这个词将认为是需要创建单独的 vector 的。
+
"""
self.update(word_lst, no_create_entry=no_create_entry)
return self
def build_vocab(self):
r"""
- 根据已经出现的词和出现频率构建词典. 注意: 重复构建可能会改变词典的大小,
- 但已经记录在词典中的词, 不会改变对应的 `int`
-
+ 根据已经出现的词和出现频率构建词典。注意:重复构建可能会改变词典的大小,
+ 但已经记录在词典中的词,不会改变对应的 :class:`int`
"""
if self._word2idx is None:
self._word2idx = {}
@@ -238,7 +268,7 @@ class Vocabulary(object):
return len(self._word2idx)
@_check_build_vocab
- def __contains__(self, item):
+ def __contains__(self, item:str):
r"""
检查词是否被记录
@@ -247,7 +277,7 @@ class Vocabulary(object):
"""
return item in self._word2idx
- def has_word(self, w):
+ def has_word(self, w:str):
r"""
检查词是否被记录::
@@ -255,7 +285,7 @@ class Vocabulary(object):
# equals to
has_abc = 'abc' in vocab
- :param item: the word
+ :param item: 输入的str类型的词
:return: ``True`` or ``False``
"""
return self.__contains__(w)
@@ -263,7 +293,7 @@ class Vocabulary(object):
@_check_build_vocab
def __getitem__(self, w):
r"""
- To support usage like::
+ 支持从字典中直接得到词语的index,例如::
vocab[w]
"""
@@ -275,18 +305,18 @@ class Vocabulary(object):
raise ValueError("word `{}` not in vocabulary".format(w))
@_check_build_vocab
- def index_dataset(self, *datasets, field_name, new_field_name=None):
+ def index_dataset(self, *datasets, field_name:Union[List, str], new_field_name:Union[List, str, None]=None):
r"""
- 将DataSet中对应field的词转为数字,Example::
+ 将 ``DataSet`` 中对应 field 的词转为数字,例如::
# remember to use `field_name`
vocab.index_dataset(train_data, dev_data, test_data, field_name='words')
- :param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集
- :param list,str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field.
- 目前支持 ``str`` , ``List[str]``
- :param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field.
- Default: ``None``.
+ :param datasets: 其类型为 :class:`~fastNLP.core.dataset.DataSet` 或者 :class:`List` [ :class:`~fastNLP.core.dataset.DataSet` ],
+ 即需要处理的一个或多个数据集
+ :param field_name: 需要转为 index 的 field, 若有多个 DataSet, 每个 DataSet 都必须有此 field.
+ 目前支持 :class:`str` , :class:`List` [ :class:`str` ]
+ :param new_field_name: 保存结果的 field_name。 若为 ``None`` , 将覆盖原 field。
"""
def index_instance(field):
@@ -319,39 +349,44 @@ class Vocabulary(object):
for idx, dataset in enumerate(datasets):
if isinstance(dataset, DataSet):
- try:
- for f_n, n_f_n in zip(field_name, new_field_name):
- dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n)
- except Exception as e:
- logger.info("When processing the `{}` dataset, the following error occurred.".format(idx))
- raise e
+ ds_lst = [dataset]
+ elif _is_iterable(dataset):
+ ds_lst = list(dataset)
else:
- raise RuntimeError("Only DataSet type is allowed.")
+ raise TypeError(f"Only DataSet type is allowed, instead of {type(dataset)}.")
+ try:
+ for ds in ds_lst:
+ for f_n, n_f_n in zip(field_name, new_field_name):
+ ds.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n, progress_bar=None)
+ except Exception as e:
+ logger.error("When processing the `{}` dataset, the following error occurred.".format(idx))
+ raise e
return self
@property
def _no_create_word_length(self):
return len(self._no_create_word)
- def from_dataset(self, *datasets, field_name, no_create_entry_dataset=None):
+ def from_dataset(self, *datasets, field_name:Union[str,List[str]], no_create_entry_dataset=None):
r"""
使用dataset的对应field中词构建词典::
# remember to use `field_name`
- vocab.from_dataset(train_data1, train_data2, field_name='words')
-
- :param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集
- :param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` .
- 构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field. 目前支持的field结构
- : ``str`` , ``List[str]``
- :param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认), 建议直接将非训练数据都传入到这个参数。该选项用在接下来的模型会使用pretrain
- 的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev
- 中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。
- 如果一个词出现在了train中,但是没在预训练模型中,embedding会为它用unk初始化,但它是单独的一个vector,如果
- finetune embedding的话,这个词在更新之后可能会有更好的表示; 而如果这个词仅出现在了dev或test中,那么就不能为它们单独建立vector,
- 而应该让它指向unk这个vector的值。所以只位于no_create_entry_dataset中的token,将首先从预训练的词表中寻找它的表示,
- 如果找到了,就使用该表示; 如果没有找到,则认为该词的表示应该为unk的表示。
- :return self:
+ vocab.from_dataset(train_data1, train_data2, field_name='words', no_create_entry_dataset=[test_data1, test_data2])
+
+ :param datasets: 其类型为 :class:`~fastNLP.core.dataset.DataSet` 或者 List[:class:`~fastNLP.core.dataset.DataSet`]。
+ :param field_name: 构建词典所使用的 field(s), 支持一个或多个 field,若有多个 DataSet, 每个 DataSet 都必须有这些 field.
+ 目前支持的field结构: ``str`` , ``List[str]``
+ :param no_create_entry_dataset: 可以传入 :class:`~fastNLP.core.dataset.DataSet`, :class:`List` [ :class:`~fastNLP.core.dataset.DataSet` ] 或者
+ ``None`` (默认),建议直接将非训练数据都传入到这个参数。该选项用于接下来的模型会使用预训练的 embedding (包括 ``glove``, ``word2vec`` ,
+ ``elmo`` 与 ``bert`` )且会 finetune 的情况。如果仅使用来自于训练集的数据建立词表,会导致测试集与验证集中的数据无法充分利用到来自于预训练
+ embedding 的信息,所以在建立词表的时候将测试集与验证集考虑进来会使得最终的结果更好。
+ 如果一个词出现在了训练集中,但是没在预训练模型中, embedding 会为它用 ``unk`` 初始化;但如果它是单独的一个 vector ,并且 finetune embedding
+ 的话,这个词在更新之后可能会有更好的表示;而如果这个词仅出现在了验证集或者测试集中,那么就不能为它们单独建立 vector,而应该让它指向 ``unk`` 这个
+ vector 的值。所以只位于 ``no_create_entry_dataset`` 中的 token 将首先从预训练的词表中寻找它的表示,如果找到了,就使用该表示; 如果没有找到,则认
+ 为该词的表示应该为 ``unk`` 的表示。
+ :return: Vocabulary 自身
+
"""
if isinstance(field_name, str):
field_name = [field_name]
@@ -376,43 +411,49 @@ class Vocabulary(object):
for idx, dataset in enumerate(datasets):
if isinstance(dataset, DataSet):
- try:
- dataset.apply(construct_vocab)
- except BaseException as e:
- logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
- raise e
+ ds_lst = [dataset]
+ elif _is_iterable(dataset):
+ ds_lst = list(dataset)
else:
- raise TypeError("Only DataSet type is allowed.")
+ raise TypeError(f"Only DataSet type is allowed, instead of {type(dataset)}.")
+
+ try:
+ for ds in ds_lst:
+ ds.apply(construct_vocab, progress_bar=None)
+ except BaseException as e:
+ logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
+ raise e
if no_create_entry_dataset is not None:
partial_construct_vocab = partial(construct_vocab, no_create_entry=True)
if isinstance(no_create_entry_dataset, DataSet):
- no_create_entry_dataset.apply(partial_construct_vocab)
+ no_create_entry_dataset.apply(partial_construct_vocab, progress_bar=None)
elif isinstance(no_create_entry_dataset, list):
for dataset in no_create_entry_dataset:
if not isinstance(dataset, DataSet):
raise TypeError("Only DataSet type is allowed.")
- dataset.apply(partial_construct_vocab)
+ dataset.apply(partial_construct_vocab, progress_bar=None)
return self
- def _is_word_no_create_entry(self, word):
+ def _is_word_no_create_entry(self, word:str):
r"""
判断当前的word是否是不需要创建entry的,具体参见from_dataset的说明
- :param word: str
- :return: bool
+
+ :param word: 输入的str类型的词语
+ :return: bool值的判断结果
"""
return word in self._no_create_word
- def to_index(self, w):
+ def to_index(self, w:str):
r"""
- 将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 ``ValueError`` ::
+ 将词转为数字。 若词不在词典中被记录, 将视为 `unknown`, 若 ``unknown=None`` , 将抛出 ``ValueError`` ::
index = vocab.to_index('abc')
# equals to
index = vocab['abc']
- :param str w: a word
- :return int index: the number
+ :param w: 需要输入的词语
+ :return: 词语 ``w`` 对应的 :class:`int`类型的 index
"""
return self.__getitem__(w)
@@ -420,7 +461,7 @@ class Vocabulary(object):
@_check_build_vocab
def unknown_idx(self):
r"""
- unknown 对应的数字.
+ 获得 ``unknown`` 对应的数字.
"""
if self.unknown is None:
return None
@@ -430,27 +471,27 @@ class Vocabulary(object):
@_check_build_vocab
def padding_idx(self):
r"""
- padding 对应的数字
+ 获得 ``padding`` 对应的数字
"""
if self.padding is None:
return None
return self._word2idx[self.padding]
@_check_build_vocab
- def to_word(self, idx):
+ def to_word(self, idx: int):
r"""
给定一个数字, 将其转为对应的词.
- :param int idx: the index
- :return str word: the word
+ :param idx: 需要转换为词的数字 index
+ :return: ``idx`` 对应的词
"""
return self._idx2word[idx]
def clear(self):
r"""
- 删除Vocabulary中的词表数据。相当于重新初始化一下。
删除 :class:`Vocabulary` 中的词表数据。相当于重新初始化一下。
- :return:
+ :return: 自身
"""
self.word_count.clear()
self._word2idx = None
@@ -460,7 +501,8 @@ class Vocabulary(object):
return self
def __getstate__(self):
- r"""Use to prepare data for pickle.
+ r"""
+ 为 pickle 保存准备数据(返回用于序列化的 state)
"""
len(self) # make sure vocab has been built
@@ -470,7 +512,8 @@ class Vocabulary(object):
return state
def __setstate__(self, state):
- r"""Use to restore state from pickle.
+ r"""
+ 从 pickle 的 state 中恢复词表状态
"""
self.__dict__.update(state)
@@ -485,11 +528,11 @@ class Vocabulary(object):
for index in range(len(self._word2idx)):
yield self.to_word(index), index
- def save(self, filepath):
+ def save(self, filepath: Union[str, io.StringIO]):
r"""
+ 保存当前词表。
- :param str,io.StringIO filepath: Vocabulary的储存路径
- :return:
+ :param filepath: 词表储存路径
"""
if isinstance(filepath, io.IOBase):
assert filepath.writable()
@@ -500,7 +543,7 @@ class Vocabulary(object):
except Exception as e:
raise e
else:
- raise TypeError("Illegal `filepath`.")
+ raise TypeError("Illegal `path`.")
f.write(f'max_size\t{self.max_size}\n')
f.write(f'min_freq\t{self.min_freq}\n')
@@ -521,11 +564,12 @@ class Vocabulary(object):
f.close()
@staticmethod
- def load(filepath):
+ def load(filepath: Union[str,io.StringIO]):
r"""
+ 从文件路径中加载数据
- :param str,io.StringIO filepath: Vocabulary的读取路径
- :return: Vocabulary
+ :param filepath: 词表的读取路径
+ :return: 读取的 :class:`Vocabulary`
"""
if isinstance(filepath, io.IOBase):
assert filepath.writable()
@@ -536,7 +580,7 @@ class Vocabulary(object):
except Exception as e:
raise e
else:
- raise TypeError("Illegal `filepath`.")
+ raise TypeError("Illegal `path`.")
vocab = Vocabulary()
for line in f:
diff --git a/fastNLP/doc_utils.py b/fastNLP/doc_utils.py
deleted file mode 100644
index 3f7889e4..00000000
--- a/fastNLP/doc_utils.py
+++ /dev/null
@@ -1,54 +0,0 @@
-r"""undocumented
-用于辅助生成 fastNLP 文档的代码
-"""
-
-__all__ = []
-
-import inspect
-import sys
-
-
-def doc_process(m):
- for name, obj in inspect.getmembers(m):
- if inspect.isclass(obj) or inspect.isfunction(obj):
- if obj.__module__ != m.__name__:
- if obj.__doc__ is None:
- # print(name, obj.__doc__)
- pass
- else:
- module_name = obj.__module__
-
- # 识别并标注类和函数在不同层次中的位置
-
- while 1:
- defined_m = sys.modules[module_name]
- try:
- if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__:
- obj.__doc__ = r"别名 :class:`" + m.__name__ + "." + name + "`" \
- + " :class:`" + module_name + "." + name + "`\n" + obj.__doc__
- break
- module_name = ".".join(module_name.split('.')[:-1])
- if module_name == m.__name__:
- # print(name, ": not found defined doc.")
- break
- except:
- print("Warning: Module {} lacks `__doc__`".format(module_name))
- break
-
- # 识别并标注基类,只有基类也在 fastNLP 中定义才显示
-
- if inspect.isclass(obj):
- for base in obj.__bases__:
- if base.__module__.startswith("fastNLP"):
- parts = base.__module__.split(".") + []
- module_name, i = "fastNLP", 1
- for i in range(len(parts) - 1):
- defined_m = sys.modules[module_name]
- try:
- if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__:
- obj.__doc__ = r"基类 :class:`" + defined_m.__name__ + "." + base.__name__ + "` \n\n" + obj.__doc__
- break
- module_name += "." + parts[i + 1]
- except:
- print("Warning: Module {} lacks `__doc__`".format(module_name))
- break
diff --git a/fastNLP/embeddings/__init__.py b/fastNLP/embeddings/__init__.py
index dae75995..e69de29b 100644
--- a/fastNLP/embeddings/__init__.py
+++ b/fastNLP/embeddings/__init__.py
@@ -1,46 +0,0 @@
-r"""
-embeddings 模块主要用于从各种预训练的模型中获取词语的分布式表示,目前支持的预训练模型包括word2vec, glove, ELMO, BERT等。这里所有
-embedding的forward输入都是形状为 ``(batch_size, max_len)`` 的torch.LongTensor,输出都是 ``(batch_size, max_len, embedding_dim)`` 的
-torch.FloatTensor。所有的embedding都可以使用 `self.num_embedding` 获取最大的输入index范围, 用 `self.embeddig_dim` 或 `self.embed_size` 获取embedding的
-输出维度。
-"""
-
-__all__ = [
- "Embedding",
- "TokenEmbedding",
- "StaticEmbedding",
- "ElmoEmbedding",
- "BertEmbedding",
- "BertWordPieceEncoder",
-
- "RobertaEmbedding",
- "RobertaWordPieceEncoder",
-
- "TransformersEmbedding",
- "TransformersWordPieceEncoder",
-
- "GPT2Embedding",
- "GPT2WordPieceEncoder",
-
- "StackEmbedding",
- "LSTMCharEmbedding",
- "CNNCharEmbedding",
-
- "get_embeddings",
- "get_sinusoid_encoding_table"
-]
-
-from .embedding import Embedding, TokenEmbedding
-from .static_embedding import StaticEmbedding
-from .elmo_embedding import ElmoEmbedding
-from .bert_embedding import BertEmbedding, BertWordPieceEncoder
-from .roberta_embedding import RobertaEmbedding, RobertaWordPieceEncoder
-from .transformers_embedding import TransformersEmbedding, TransformersWordPieceEncoder
-from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding
-from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding
-from .stack_embedding import StackEmbedding
-from .utils import get_embeddings, get_sinusoid_encoding_table
-
-import sys
-from ..doc_utils import doc_process
-doc_process(sys.modules[__name__])
\ No newline at end of file
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
deleted file mode 100644
index 01e646a7..00000000
--- a/fastNLP/embeddings/bert_embedding.py
+++ /dev/null
@@ -1,658 +0,0 @@
-r"""
-.. todo::
- doc
-"""
-
-__all__ = [
- "BertEmbedding",
- "BertWordPieceEncoder"
-]
-
-import os
-import warnings
-from itertools import chain
-from functools import partial
-import json
-import numpy as np
-import torch
-from torch import nn
-
-from .contextual_embedding import ContextualEmbedding
-from ..core import logger
-from ..core.vocabulary import Vocabulary
-from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR
-from ..modules.encoder.bert import BertModel
-from ..modules.tokenizer import BertTokenizer
-
-# TODO 需要重新修改,使得encoder可以直接读取embedding的权重
-VOCAB_NAME = 'vocab.txt'
-BERT_EMBED_HYPER = 'bert_hyper.json'
-BERT_EMBED_FOLDER = 'bert'
-BERT_ENCODER_HYPER = 'bert_hyper.json'
-BERT_ENCODER_FOLDER = 'bert'
-
-
-class BertEmbedding(ContextualEmbedding):
- r"""
- 使用BERT对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于
- 预训练的bert模型长度限制为512个token,而因为输入的word是未进行word piece分割的(word piece的分割有BertEmbedding在输入word
- 时切分),在分割之后长度可能会超过最大长度限制。
-
- BertEmbedding可以支持自动下载权重,当前支持的模型:
- en: base-cased
- en-base-uncased:
- en-large-cased-wwm:
- en-large-cased:
- en-large-uncased:
- en-large-uncased-wwm
- cn: 中文BERT wwm by HIT
- cn-base: 中文BERT base-chinese
- cn-wwm-ext: 中文BERT wwm by HIT with extra data pretrain.
- multi-base-cased: multilingual cased
- multi-base-uncased: multilingual uncased
-
- Example::
-
- >>> import torch
- >>> from fastNLP import Vocabulary
- >>> from fastNLP.embeddings import BertEmbedding
- >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
- >>> embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1')
- >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
- >>> outputs = embed(words)
- >>> outputs.size()
- >>> # torch.Size([1, 5, 2304])
- """
-
- def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1',
- pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
- pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False, **kwargs):
- r"""
-
- :param ~fastNLP.Vocabulary vocab: 词表
- :param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件(以.txt作为后缀名),
- 权重文件(以.bin作为文件后缀名), 配置文件(以.json作为后缀名)。
- :param str layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,层的序号是
- 从0开始,可以以负数去索引倒数几层。 layer=0为embedding层(包括wordpiece embedding,
- position embedding和segment embedding)
- :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces
- 中计算得到它对应的表示。支持 ``last`` , ``first`` , ``avg`` , ``max``。
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样
- 会使得word embedding的结果比输入的结果长两个token。如果该值为True,则在使用 :class::StackEmbedding 可能会与其它类型的
- embedding长度不匹配。
- :param bool pooled_cls: 返回的[CLS]是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取[CLS]做预测,
- 一般该值为True。
- :param bool requires_grad: 是否需要gradient以更新Bert的权重。
- :param bool auto_truncate: 当句子words拆分为word pieces长度超过bert最大允许长度(一般为512), 自动截掉拆分后的超过510个
- word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS]
- 来进行分类的任务将auto_truncate置为True。
- :param kwargs:
- int min_freq: 小于该次数的词会被unk代替, 默认为1
- """
- super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- if word_dropout > 0:
- assert vocab.unknown != None, "When word_drop>0, Vocabulary must contain the unknown token."
-
- if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
- if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'):
- logger.warning("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
- " faster speed.")
- warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
- " faster speed.")
-
- self._word_sep_index = -100
- if '[SEP]' in vocab:
- self._word_sep_index = vocab['[SEP]']
- self._word_cls_index = -100
- if '[CLS]' in vocab:
- self._word_cls_index = vocab['[CLS]']
-
- min_freq = kwargs.pop('min_freq', 1)
- self._min_freq = min_freq
- self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
- pool_method=pool_method, include_cls_sep=include_cls_sep,
- pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate,
- **kwargs)
-
- self.requires_grad = requires_grad
- self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
-
- def _delete_model_weights(self):
- del self.model
-
- def forward(self, words):
- r"""
- 计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要
- 删除这两个token的表示。
-
- :param torch.LongTensor words: [batch_size, max_len]
- :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
- """
- words = self.drop_word(words)
- outputs = self._get_sent_reprs(words)
- if outputs is not None:
- return self.dropout(outputs)
- outputs = self.model(words)
- outputs = torch.cat([*outputs], dim=-1)
-
- return self.dropout(outputs)
-
- def drop_word(self, words):
- r"""
- 按照设定随机将words设置为unknown_index。
-
- :param torch.LongTensor words: batch_size x max_len
- :return:
- """
- if self.word_dropout > 0 and self.training:
- with torch.no_grad():
- mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
- mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
- pad_mask = words.ne(self._word_pad_index)
- mask = pad_mask.__and__(mask) # pad的位置不为unk
- if self._word_sep_index!=-100:
- not_sep_mask = words.ne(self._word_sep_index)
- mask = mask.__and__(not_sep_mask)
- if self._word_cls_index!=-100:
- not_cls_mask = words.ne(self._word_cls_index)
- mask = mask.__and__(not_cls_mask)
- words = words.masked_fill(mask, self._word_unk_index)
- return words
-
- def save(self, folder):
- """
- 将embedding保存到folder这个目录下,将会保存三个文件vocab.txt, bert_embed_hyper.txt, bert_embed/, 其中bert_embed下包含
- config.json,pytorch_model.bin,vocab.txt三个文件(该folder下的数据也可以直接被BERTModel读取)
-
- :param str folder:
- :return:
- """
- os.makedirs(folder, exist_ok=True)
-
- self.get_word_vocab().save(os.path.join(folder, VOCAB_NAME))
-
- hyper = {}
- hyper['min_freq'] = self._min_freq
- hyper['layers'] = ','.join(map(str, self.model.layers))
- hyper['pool_method'] = self.model.pool_method
- hyper['dropout'] = self.dropout_layer.p
- hyper['word_dropout'] = self.word_dropout
- hyper['include_cls_sep'] = self.model.include_cls_sep
- hyper['pooled_cls'] = self.model.pooled_cls
- hyper['auto_truncate'] = self.model.auto_truncate
- hyper['requires_grad'] = bool(self.requires_grad)
-
- with open(os.path.join(folder, BERT_EMBED_HYPER), 'w', encoding='utf-8') as f:
- json.dump(hyper, f, indent=2)
-
- os.makedirs(os.path.join(folder, BERT_EMBED_FOLDER), exist_ok=True)
- self.model.save(os.path.join(folder, BERT_EMBED_FOLDER))
- logger.debug(f"BERTEmbedding has been saved in {folder}")
-
- @classmethod
- def load(cls, folder):
- """
- 给定一个folder, 需要包含以下三个内容vocab.txt, bert_embed_hyper.txt, bert_embed/
-
- :param str folder:
- :return:
- """
- for name in [VOCAB_NAME, BERT_EMBED_FOLDER, BERT_EMBED_HYPER]:
- assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."
-
- vocab = Vocabulary.load(os.path.join(folder, VOCAB_NAME))
-
- with open(os.path.join(folder, BERT_EMBED_HYPER), 'r', encoding='utf-8') as f:
- hyper = json.load(f)
-
- model_dir_or_name = os.path.join(os.path.join(folder, BERT_EMBED_FOLDER))
-
- bert_embed = cls(vocab=vocab, model_dir_or_name=model_dir_or_name, **hyper)
- return bert_embed
-
-
-class BertWordPieceEncoder(nn.Module):
- r"""
- 读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。
-
- BertWordPieceEncoder可以支持自动下载权重,当前支持的模型:
- en: base-cased
- en-large-cased-wwm:
- en-large-cased:
- en-large-uncased:
- en-large-uncased-wwm
- cn: 中文BERT wwm by HIT
- cn-base: 中文BERT base-chinese
- cn-wwm-ext: 中文BERT wwm by HIT with extra data pretrain.
- multi-base-cased: multilingual cased
- multi-base-uncased: multilingual uncased
-
- """
-
- def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
- word_dropout=0, dropout=0, requires_grad: bool = True, **kwargs):
- r"""
-
- :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased``
- :param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层。layer=0为embedding层(包括wordpiece embedding,
- position embedding和segment embedding)
- :param bool pooled_cls: 返回的句子开头的[CLS]是否使用预训练中的BertPool映射一下。如果下游任务取[CLS]做预测,一般该值为True。
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param bool requires_grad: 是否需要gradient。
- """
- super().__init__()
-
- self.model = _BertWordPieceModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
- self._sep_index = self.model._sep_index
- self._cls_index = self.model._cls_index
- self._wordpiece_pad_index = self.model._wordpiece_pad_index
- self._wordpiece_unk_index = self.model._wordpiece_unknown_index
- self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
- self.requires_grad = requires_grad
- self.word_dropout = word_dropout
- self.dropout_layer = nn.Dropout(dropout)
-
- @property
- def embed_size(self):
- return self._embed_size
-
- @property
- def embedding_dim(self):
- return self._embed_size
-
- @property
- def num_embedding(self):
- return self.model.encoder.config.vocab_size
-
- def index_datasets(self, *datasets, field_name, add_cls_sep=True):
- r"""
- 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了
- bert的pad value。
-
- :param ~fastNLP.DataSet datasets: DataSet对象
- :param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。
- :param bool add_cls_sep: 如果首尾不是[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP]。
- :return:
- """
-
- self.model.index_datasets(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)
-
- def forward(self, word_pieces, token_type_ids=None):
- r"""
- 计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。
-
- :param words: batch_size x max_len
- :param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话. 如果不传入,则自动生成(大部分情况,都不需要输入),
- 第一个[SEP]及之前为0, 第二个[SEP]及到第一个[SEP]之间为1; 第三个[SEP]及到第二个[SEP]之间为0,依次往后推。
- :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
- """
- if token_type_ids is None:
- with torch.no_grad():
- sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len
- sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
- token_type_ids = sep_mask_cumsum.fmod(2)
- token_type_ids = token_type_ids[:, :1].__xor__(token_type_ids) # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
-
- word_pieces = self.drop_word(word_pieces)
- outputs = self.model(word_pieces, token_type_ids)
- outputs = torch.cat([*outputs], dim=-1)
-
- return self.dropout_layer(outputs)
-
- def drop_word(self, words):
- r"""
- 按照设定随机将words设置为unknown_index。
-
- :param torch.LongTensor words: batch_size x max_len
- :return:
- """
- if self.word_dropout > 0 and self.training:
- with torch.no_grad():
- not_sep_mask = words.ne(self._sep_index)
- not_cls_mask = words.ne(self._cls_index)
- replaceable_mask = not_sep_mask.__and__(not_cls_mask)
- mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
- mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
- pad_mask = words.ne(self._wordpiece_pad_index)
- mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad的位置不为unk
- words = words.masked_fill(mask, self._wordpiece_unk_index)
- return words
-
- def save(self, folder):
- """
- 会在folder下创建两个文件bert_encoder_hyper.json与bert_encoder/, bert_encoder下包含三个文件config.json,
- pytorch_model.bin,vocab.txt三个文件(该folder下的数据也可以直接被BERTModel读取)
-
- :param str folder:
- :return:
- """
- os.makedirs(folder, exist_ok=True)
-
- hyper = {}
- hyper['layers'] = ','.join(map(str, self.model.layers))
- hyper['dropout'] = self.dropout_layer.p
- hyper['word_dropout'] = self.word_dropout
- hyper['pooled_cls'] = self.model.pooled_cls
- hyper['requires_grad'] = bool(self.requires_grad)
-
- with open(os.path.join(folder, BERT_ENCODER_HYPER), 'w', encoding='utf-8') as f:
- json.dump(hyper, f, indent=2)
-
- os.makedirs(os.path.join(folder, BERT_ENCODER_FOLDER), exist_ok=True)
- self.model.save(os.path.join(folder, BERT_ENCODER_FOLDER))
- logger.debug(f"BertWordPieceEncoder has been saved in {folder}")
-
- @classmethod
- def load(cls, folder):
- """
- 会在folder下创建两个文件bert_encoder_hyper.json与bert_encoder/, bert_encoder下包含三个文件
-
- :param folder:
- :return:
- """
- for name in [BERT_ENCODER_HYPER, BERT_ENCODER_FOLDER]:
- assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."
-
- with open(os.path.join(folder, BERT_ENCODER_HYPER), 'r', encoding='utf-8') as f:
- hyper = json.load(f)
-
- model_dir_or_name = os.path.join(os.path.join(folder, BERT_ENCODER_FOLDER))
-
- bert_encoder = cls(model_dir_or_name=model_dir_or_name, **hyper)
- return bert_encoder
-
-
-class _BertWordModel(nn.Module):
- def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
- include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
- **kwargs):
- super().__init__()
-
- if isinstance(layers, list):
- self.layers = [int(l) for l in layers]
- elif isinstance(layers, str):
- if layers.lower() == 'all':
- self.layers = None
- else:
- self.layers = list(map(int, layers.split(',')))
- else:
- raise TypeError("`layers` only supports str or list[int]")
-
- neg_num_output_layer = -16384
- pos_num_output_layer = 0
- if self.layers is None:
- neg_num_output_layer = -1
- else:
- for layer in self.layers:
- if layer < 0:
- neg_num_output_layer = max(layer, neg_num_output_layer)
- else:
- pos_num_output_layer = max(layer, pos_num_output_layer)
-
- self.tokenizer = BertTokenizer.from_pretrained(model_dir_or_name)
- self.encoder = BertModel.from_pretrained(model_dir_or_name,
- neg_num_output_layer=neg_num_output_layer,
- pos_num_output_layer=pos_num_output_layer,
- **kwargs)
- self._max_position_embeddings = self.encoder.config.max_position_embeddings
- # 检查encoder_layer_number是否合理
- encoder_layer_number = len(self.encoder.encoder.layer)
- if self.layers is None:
- self.layers = [idx for idx in range(encoder_layer_number + 1)]
- logger.info(f'Bert Model will return {len(self.layers)} layers (layer-0 '
- f'is embedding result): {self.layers}')
- assert len(self.layers) > 0, "There is no layer selected!"
- for layer in self.layers:
- if layer < 0:
- assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a bert model with {encoder_layer_number} layers."
- else:
- assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a bert model with {encoder_layer_number} layers."
-
- assert pool_method in ('avg', 'max', 'first', 'last')
- self.pool_method = pool_method
- self.include_cls_sep = include_cls_sep
- self.pooled_cls = pooled_cls
- self.auto_truncate = auto_truncate
-
- # 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP]
- self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids
-
- word_to_wordpieces = []
- word_pieces_lengths = []
- for word, index in vocab:
- if index == vocab.padding_idx: # pad是个特殊的符号
- word = '[PAD]'
- elif index == vocab.unknown_idx:
- word = '[UNK]'
- elif vocab.word_count[word] < min_freq:
- word = '[UNK]'
- word_pieces = self.tokenizer.wordpiece_tokenizer.tokenize(word)
- word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces)
- word_to_wordpieces.append(word_pieces)
- word_pieces_lengths.append(len(word_pieces))
- self._cls_index = self.tokenizer.vocab['[CLS]']
- self._sep_index = self.tokenizer.vocab['[SEP]']
- self._word_pad_index = vocab.padding_idx
- self._wordpiece_pad_index = self.tokenizer.vocab['[PAD]'] # 需要用于生成word_piece
- self.word_to_wordpieces = np.array(word_to_wordpieces, dtype=object)
- self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
- logger.debug("Successfully generate word pieces.")
-
- def forward(self, words):
- r"""
-
- :param words: torch.LongTensor, batch_size x max_len
- :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size
- """
- with torch.no_grad():
- batch_size, max_word_len = words.size()
- word_mask = words.ne(self._word_pad_index) # 为1的地方有word
- seq_len = word_mask.sum(dim=-1)
- batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
- 0) # batch_size x max_len
- word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size
- max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # 表示word piece的长度(包括padding)
- if max_word_piece_length + 2 > self._max_position_embeddings:
- if self.auto_truncate:
- word_pieces_lengths = word_pieces_lengths.masked_fill(
- word_pieces_lengths + 2 > self._max_position_embeddings,
- self._max_position_embeddings - 2)
- else:
- raise RuntimeError(
- "After split words into word pieces, the lengths of word pieces are longer than the "
- f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set "
- f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
-
- # +2是由于需要加入[CLS]与[SEP]
- word_pieces = words.new_full((batch_size, min(max_word_piece_length + 2, self._max_position_embeddings)),
- fill_value=self._wordpiece_pad_index)
- attn_masks = torch.zeros_like(word_pieces)
- # 1. 获取words的word_pieces的id,以及对应的span范围
- word_indexes = words.cpu().numpy()
- for i in range(batch_size):
- word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]]))
- if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings - 2:
- word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
- word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
- attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
- # 添加[cls]和[sep]
- word_pieces[:, 0].fill_(self._cls_index)
- batch_indexes = torch.arange(batch_size).to(words)
- word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
- if self._has_sep_in_vocab: # 但[SEP]在vocab中出现应该才会需要token_ids
- sep_mask = word_pieces.eq(self._sep_index).long() # batch_size x max_len
- sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
- token_type_ids = sep_mask_cumsum.fmod(2)
- token_type_ids = token_type_ids[:, :1].__xor__(token_type_ids) # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
- else:
- token_type_ids = torch.zeros_like(word_pieces)
- # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算
- # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
- bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids,
- attention_mask=attn_masks,
- output_all_encoded_layers=True)
- # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
-
- if self.include_cls_sep:
- s_shift = 1
- outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
- bert_outputs[-1].size(-1))
-
- else:
- s_shift = 0
- outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
- bert_outputs[-1].size(-1))
- batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
- batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len
-
- if self.pool_method == 'first':
- batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
- _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
- elif self.pool_method == 'last':
- batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
- _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
-
- for l_index, l in enumerate(self.layers):
- output_layer = bert_outputs[l]
- real_word_piece_length = output_layer.size(1) - 2
- if max_word_piece_length > real_word_piece_length: # 如果实际上是截取出来的
- paddings = output_layer.new_zeros(batch_size,
- max_word_piece_length - real_word_piece_length,
- output_layer.size(2))
- output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
- # 从word_piece collapse到word的表示
- truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size
- if self.pool_method == 'first':
- tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
- tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
- outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
-
- elif self.pool_method == 'last':
- tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
- tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
- outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
- elif self.pool_method == 'max':
- for i in range(batch_size):
- for j in range(seq_len[i]):
- start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
- outputs[l_index, i, j + s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2)
- else:
- for i in range(batch_size):
- for j in range(seq_len[i]):
- start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
- outputs[l_index, i, j + s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2)
- if self.include_cls_sep:
- if l in (len(bert_outputs) - 1, -1) and self.pooled_cls:
- outputs[l_index, :, 0] = pooled_cls
- else:
- outputs[l_index, :, 0] = output_layer[:, 0]
- outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, word_pieces_lengths + s_shift]
-
- # 3. 最终的embedding结果
- return outputs
-
- def save(self, folder):
- """
- 给定一个folder保存pytorch_model.bin, config.json, vocab.txt
-
- :param str folder:
- :return:
- """
- self.tokenizer.save_pretrained(folder)
- self.encoder.save_pretrained(folder)
-
-
-class _BertWordPieceModel(nn.Module):
- r"""
- 这个模块用于直接计算word_piece的结果.
-
- """
-
- def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False):
- super().__init__()
-
- self.tokenizer = BertTokenizer.from_pretrained(model_dir_or_name)
- self.encoder = BertModel.from_pretrained(model_dir_or_name)
- # 检查encoder_layer_number是否合理
- encoder_layer_number = len(self.encoder.encoder.layer)
-
- if isinstance(layers, list):
- self.layers = [int(l) for l in layers]
- elif isinstance(layers, str):
- self.layers = list(map(int, layers.split(',')))
- else:
- raise TypeError("`layers` only supports str or list[int]")
-
- for layer in self.layers:
- if layer < 0:
- assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a bert model with {encoder_layer_number} layers."
- else:
- assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a bert model with {encoder_layer_number} layers."
-
- self._cls_index = self.tokenizer.cls_index
- self._sep_index = self.tokenizer.sep_index
- self._wordpiece_unknown_index = self.tokenizer.unk_index
- self._wordpiece_pad_index = self.tokenizer.pad_index # 需要用于生成word_piece
- self.pooled_cls = pooled_cls
-
- def index_datasets(self, *datasets, field_name, add_cls_sep=True):
- r"""
- 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是
- [CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。
-
- :param datasets: DataSet对象
- :param field_name: 基于哪一列index
- :return:
- """
-
- encode_func = partial(self.tokenizer.encode, add_special_tokens=add_cls_sep)
-
- for index, dataset in enumerate(datasets):
- try:
- dataset.apply_field(encode_func, field_name=field_name, new_field_name='word_pieces',
- is_input=True)
- dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
- except Exception as e:
- logger.error(f"Exception happens when processing the {index} dataset.")
- raise e
-
- def forward(self, word_pieces, token_type_ids=None):
- r"""
-
- :param word_pieces: torch.LongTensor, batch_size x max_len
- :param token_type_ids: torch.LongTensor, batch_size x max_len
- :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size
- """
- batch_size, max_len = word_pieces.size()
-
- attn_masks = word_pieces.ne(self._wordpiece_pad_index)
- bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
- output_all_encoded_layers=True)
- # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
- outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1)))
- for l_index, l in enumerate(self.layers):
- bert_output = bert_outputs[l]
- if l in (len(bert_outputs)-1, -1) and self.pooled_cls:
- bert_output[:, 0] = pooled_cls
- outputs[l_index] = bert_output
- return outputs
-
- def save(self, folder):
- """
- 给定一个folder保存pytorch_model.bin, config.json, vocab.txt
-
- :param folder:
- :return:
- """
- self.tokenizer.save_pretrained(folder)
- self.encoder.save_pretrained(folder)
diff --git a/fastNLP/embeddings/char_embedding.py b/fastNLP/embeddings/char_embedding.py
deleted file mode 100644
index a2996ae2..00000000
--- a/fastNLP/embeddings/char_embedding.py
+++ /dev/null
@@ -1,284 +0,0 @@
-r"""
-该文件中主要包含的是character的Embedding,包括基于CNN与LSTM的character Embedding。与其它Embedding一样,这里的Embedding输入也是
-词的index而不需要使用词语中的char的index来获取表达。
-"""
-
-__all__ = [
- "CNNCharEmbedding",
- "LSTMCharEmbedding"
-]
-
-from typing import List
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .embedding import TokenEmbedding
-from .static_embedding import StaticEmbedding
-from .utils import _construct_char_vocab_from_vocab
-from .utils import get_embeddings
-from ..core import logger
-from ..core.vocabulary import Vocabulary
-from ..modules.encoder.lstm import LSTM
-
-
-class CNNCharEmbedding(TokenEmbedding):
- r"""
- 使用CNN生成character embedding。CNN的结构为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout.
- 不同的kernel大小的fitler结果是concat起来然后通过一层fully connected layer, 然后输出word的表示。
-
- Example::
-
- >>> import torch
- >>> from fastNLP import Vocabulary
- >>> from fastNLP.embeddings import CNNCharEmbedding
- >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
- >>> embed = CNNCharEmbedding(vocab, embed_size=50)
- >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
- >>> outputs = embed(words)
- >>> outputs.size()
- >>> # torch.Size([1, 5,50])
-
- """
-
- def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0,
- dropout: float = 0, filter_nums: List[int] = (40, 30, 20), kernel_sizes: List[int] = (5, 3, 1),
- pool_method: str = 'max', activation='relu', min_char_freq: int = 2, pre_train_char_embed: str = None,
- requires_grad:bool=True, include_word_start_end:bool=True):
- r"""
-
- :param vocab: 词表
- :param embed_size: 该CNNCharEmbedding的输出维度大小,默认值为50.
- :param char_emb_size: character的embed的维度。character是从vocab中生成的。默认值为50.
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率drop分布式表示与char embedding的输出。
- :param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20].
- :param kernel_sizes: kernel的大小. 默认值为[5, 3, 1].
- :param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'.
- :param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数.
- :param min_char_freq: character的最少出现次数。默认值为2.
- :param pre_train_char_embed: 可以有两种方式调用预训练好的character embedding:第一种是传入embedding文件夹
- (文件夹下应该只有一个以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,
- 没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding.
- :param requires_grad: 是否更新权重
- :param include_word_start_end: 是否在每个word开始的character前和结束的character增加特殊标示符号;
- """
- super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- for kernel in kernel_sizes:
- assert kernel % 2 == 1, "Only odd kernel is allowed."
-
- assert pool_method in ('max', 'avg')
- self.pool_method = pool_method
- # activation function
- if isinstance(activation, str):
- if activation.lower() == 'relu':
- self.activation = F.relu
- elif activation.lower() == 'sigmoid':
- self.activation = F.sigmoid
- elif activation.lower() == 'tanh':
- self.activation = F.tanh
- elif activation is None:
- self.activation = lambda x: x
- elif callable(activation):
- self.activation = activation
- else:
- raise Exception(
- "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
-
- logger.info("Start constructing character vocabulary.")
- # 建立char的词表
- self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq,
- include_word_start_end=include_word_start_end)
- self.char_pad_index = self.char_vocab.padding_idx
- logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
- # 对vocab进行index
- max_word_len = max(map(lambda x: len(x[0]), vocab))
- if include_word_start_end:
- max_word_len += 2
- self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len),
- fill_value=self.char_pad_index, dtype=torch.long))
- self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
- for word, index in vocab:
- # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的也是同一个embed
- if include_word_start_end:
- word = [''] + list(word) + ['']
- self.words_to_chars_embedding[index, :len(word)] = \
- torch.LongTensor([self.char_vocab.to_index(c) for c in word])
- self.word_lengths[index] = len(word)
- # self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
- if pre_train_char_embed:
- self.char_embedding = StaticEmbedding(self.char_vocab, model_dir_or_name=pre_train_char_embed)
- else:
- self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
-
- self.convs = nn.ModuleList([nn.Conv1d(
- self.char_embedding.embedding_dim, filter_nums[i], kernel_size=kernel_sizes[i], bias=True,
- padding=kernel_sizes[i] // 2)
- for i in range(len(kernel_sizes))])
- self._embed_size = embed_size
- self.fc = nn.Linear(sum(filter_nums), embed_size)
- self.requires_grad = requires_grad
-
- def forward(self, words):
- r"""
- 输入words的index后,生成对应的words的表示。
-
- :param words: [batch_size, max_len]
- :return: [batch_size, max_len, embed_size]
- """
- words = self.drop_word(words)
- batch_size, max_len = words.size()
- chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
- word_lengths = self.word_lengths[words] # batch_size x max_len
- max_word_len = word_lengths.max()
- chars = chars[:, :, :max_word_len]
- # 为1的地方为mask
- chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了
- chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
- chars = self.dropout(chars)
- reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
- reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M
- conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)
- for conv in self.convs]
- conv_chars = torch.cat(conv_chars, dim=-1).contiguous() # B x max_len x max_word_len x sum(filters)
- conv_chars = self.activation(conv_chars)
- if self.pool_method == 'max':
- conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
- chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters)
- else:
- conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
- chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()
- chars = self.fc(chars)
- return self.dropout(chars)
-
-
-class LSTMCharEmbedding(TokenEmbedding):
- r"""
- 使用LSTM的方式对character进行encode. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool -> Dropout
-
- Example::
-
- >>> import torch
- >>> from fastNLP import Vocabulary
- >>> from fastNLP.embeddings import LSTMCharEmbedding
- >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
- >>> embed = LSTMCharEmbedding(vocab, embed_size=50)
- >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
- >>> outputs = embed(words)
- >>> outputs.size()
- >>> # torch.Size([1, 5,50])
-
- """
-
- def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0,
- dropout: float = 0, hidden_size=50, pool_method: str = 'max', activation='relu',
- min_char_freq: int = 2, bidirectional=True, pre_train_char_embed: str = None,
- requires_grad:bool=True, include_word_start_end:bool=True):
- r"""
-
- :param vocab: 词表
- :param embed_size: LSTMCharEmbedding的输出维度。默认值为50.
- :param char_emb_size: character的embedding的维度。默认值为50.
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param dropout: 以多大概率drop character embedding的输出以及最终的word的输出。
- :param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50.
- :param pool_method: 支持'max', 'avg'。
- :param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数.
- :param min_char_freq: character的最小出现次数。默认值为2.
- :param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。
- :param pre_train_char_embed: 可以有两种方式调用预训练好的character embedding:第一种是传入embedding文件夹
- (文件夹下应该只有一个以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,
- 没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding.
- :param requires_grad: 是否更新权重
- :param include_word_start_end: 是否在每个word开始的character前和结束的character增加特殊标示符号;
- """
- super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- assert hidden_size % 2 == 0, "Only even kernel is allowed."
-
- assert pool_method in ('max', 'avg')
- self.pool_method = pool_method
- # activation function
- if isinstance(activation, str):
- if activation.lower() == 'relu':
- self.activation = F.relu
- elif activation.lower() == 'sigmoid':
- self.activation = F.sigmoid
- elif activation.lower() == 'tanh':
- self.activation = F.tanh
- elif activation is None:
- self.activation = lambda x: x
- elif callable(activation):
- self.activation = activation
- else:
- raise Exception(
- "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
-
- logger.info("Start constructing character vocabulary.")
- # 建立char的词表
- self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq,
- include_word_start_end=include_word_start_end)
- self.char_pad_index = self.char_vocab.padding_idx
- logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
- # 对vocab进行index
- max_word_len = max(map(lambda x: len(x[0]), vocab))
- if include_word_start_end:
- max_word_len += 2
- self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len),
- fill_value=self.char_pad_index, dtype=torch.long))
- self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
- for word, index in vocab:
- # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 修改为不区分pad与否
- if include_word_start_end:
- word = [''] + list(word) + ['']
- self.words_to_chars_embedding[index, :len(word)] = \
- torch.LongTensor([self.char_vocab.to_index(c) for c in word])
- self.word_lengths[index] = len(word)
- if pre_train_char_embed:
- self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
- else:
- self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
-
- self.fc = nn.Linear(hidden_size, embed_size)
- hidden_size = hidden_size // 2 if bidirectional else hidden_size
-
- self.lstm = LSTM(self.char_embedding.embedding_dim, hidden_size, bidirectional=bidirectional, batch_first=True)
- self._embed_size = embed_size
- self.bidirectional = bidirectional
- self.requires_grad = requires_grad
-
- def forward(self, words):
- r"""
- 输入words的index后,生成对应的words的表示。
-
- :param words: [batch_size, max_len]
- :return: [batch_size, max_len, embed_size]
- """
- words = self.drop_word(words)
- batch_size, max_len = words.size()
- chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
- word_lengths = self.word_lengths[words] # batch_size x max_len
- max_word_len = word_lengths.max()
- chars = chars[:, :, :max_word_len]
- # 为mask的地方为1
- chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了
- chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
- chars = self.dropout(chars)
- reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
- char_seq_len = chars_masks.eq(False).sum(dim=-1).reshape(batch_size * max_len)
- lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
- # B x M x M x H
-
- lstm_chars = self.activation(lstm_chars)
- if self.pool_method == 'max':
- lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
- chars, _ = torch.max(lstm_chars, dim=-2) # batch_size x max_len x H
- else:
- lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
- chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()
-
- chars = self.fc(chars)
-
- return self.dropout(chars)
diff --git a/fastNLP/embeddings/contextual_embedding.py b/fastNLP/embeddings/contextual_embedding.py
deleted file mode 100644
index d3ae6b4e..00000000
--- a/fastNLP/embeddings/contextual_embedding.py
+++ /dev/null
@@ -1,113 +0,0 @@
-r"""
-.. todo::
- doc
-"""
-
-__all__ = [
- "ContextualEmbedding"
-]
-
-from abc import abstractmethod
-
-import torch
-
-from .embedding import TokenEmbedding
-from ..core import logger
-from ..core.batch import DataSetIter
-from ..core.dataset import DataSet
-from ..core.sampler import SequentialSampler
-from ..core.utils import _move_model_to_device, _get_model_device
-from ..core.vocabulary import Vocabulary
-
-
-class ContextualEmbedding(TokenEmbedding):
- r"""
- ContextualEmbedding组件. BertEmbedding与ElmoEmbedding的基类
- """
- def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):
- super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool = True):
- r"""
- 由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。
-
- :param datasets: DataSet对象
- :param batch_size: int, 生成cache的sentence表示时使用的batch的大小
- :param device: 参考 :class::fastNLP.Trainer 的device
- :param delete_weights: 似乎在生成了cache之后删除权重,在不需要finetune动态模型的情况下,删除权重会大量减少内存占用。
- :return:
- """
- for index, dataset in enumerate(datasets):
- try:
- assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed."
- assert 'words' in dataset.get_input_name(), "`words` field has to be set as input."
- except Exception as e:
- logger.error(f"Exception happens at {index} dataset.")
- raise e
-
- sent_embeds = {}
- _move_model_to_device(self, device=device)
- device = _get_model_device(self)
- pad_index = self._word_vocab.padding_idx
- logger.info("Start to calculate sentence representations.")
- with torch.no_grad():
- for index, dataset in enumerate(datasets):
- try:
- batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
- for batch_x, batch_y in batch:
- words = batch_x['words'].to(device)
- words_list = words.tolist()
- seq_len = words.ne(pad_index).sum(dim=-1)
- max_len = words.size(1)
- # 因为有些情况可能包含CLS, SEP, 从后面往前计算比较安全。
- seq_len_from_behind = (max_len - seq_len).tolist()
- word_embeds = self(words).detach().cpu().numpy()
- for b in range(words.size(0)):
- length = seq_len_from_behind[b]
- if length == 0:
- sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b]
- else:
- sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length]
- except Exception as e:
- logger.error(f"Exception happens at {index} dataset.")
- raise e
- logger.info("Finish calculating sentence representations.")
- self.sent_embeds = sent_embeds
- if delete_weights:
- self._delete_model_weights()
-
- def _get_sent_reprs(self, words):
- r"""
- 获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None
-
- :param words: torch.LongTensor
- :return:
- """
- if hasattr(self, 'sent_embeds'):
- words_list = words.tolist()
- seq_len = words.ne(self._word_pad_index).sum(dim=-1)
- _embeds = []
- for b in range(len(words)):
- words_i = tuple(words_list[b][:seq_len[b]])
- embed = self.sent_embeds[words_i]
- _embeds.append(embed)
- max_sent_len = max(map(len, _embeds))
- embeds = words.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float,
- device=words.device)
- for i, embed in enumerate(_embeds):
- embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device)
- return embeds
- return None
-
- @abstractmethod
- def _delete_model_weights(self):
- r"""删除计算表示的模型以节省资源"""
- raise NotImplementedError
-
- def remove_sentence_cache(self):
- r"""
- 删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。
-
- :return:
- """
- del self.sent_embeds
diff --git a/fastNLP/embeddings/elmo_embedding.py b/fastNLP/embeddings/elmo_embedding.py
deleted file mode 100644
index 39cd4b30..00000000
--- a/fastNLP/embeddings/elmo_embedding.py
+++ /dev/null
@@ -1,335 +0,0 @@
-r"""
-.. todo::
- doc
-"""
-
-__all__ = [
- "ElmoEmbedding"
-]
-
-import codecs
-import json
-import os
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .contextual_embedding import ContextualEmbedding
-from ..core import logger
-from ..core.vocabulary import Vocabulary
-from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR
-from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder
-
-
-class ElmoEmbedding(ContextualEmbedding):
- r"""
- 使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。
- 当前支持的使用名称初始化的模型:
-
- .. code::
-
- en: 即en-medium hidden_size 1024; output_size 12
- en-medium: hidden_size 2048; output_size 256
- en-origial: hidden_size 4096; output_size 512
- en-original-5.5b: hidden_size 4096; output_size 512
- en-small: hidden_size 1024; output_size 128
-
- Example::
-
- >>> import torch
- >>> from fastNLP import Vocabulary
- >>> from fastNLP.embeddings import ElmoEmbedding
- >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
- >>> # 使用不同层的concat的结果
- >>> embed = ElmoEmbedding(vocab, model_dir_or_name='en', layers='1,2', requires_grad=False)
- >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
- >>> outputs = embed(words)
- >>> outputs.size()
- >>> # torch.Size([1, 5, 2048])
-
- >>> # 使用不同层的weighted sum。
- >>> embed = ElmoEmbedding(vocab, model_dir_or_name='en', layers='mix', requires_grad=False)
- >>> embed.set_mix_weights_requires_grad() # 使得weighted的权重是可以学习的,但ELMO的LSTM部分是不更新
-
- """
-
- def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '2', requires_grad: bool = True,
- word_dropout=0.0, dropout=0.0, cache_word_reprs: bool = False):
- r"""
-
- :param vocab: 词表
- :param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo所在文件夹,该文件夹下面应该有两个文件,
- 其中一个是以json为后缀的配置文件,另一个是以pkl为后缀的权重文件;第二种是传入ELMo版本的名称,将自动查看缓存中是否存在该模型,
- 没有的话将自动下载并缓存。
- :param layers: str, 指定返回的层数(从0开始), 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果
- 按照这个顺序concat起来,默认为'2'。'mix'会使用可学习的权重结合不同层的表示(权重是否可训练与requires_grad保持一致,
- 初始化权重对三层结果进行mean-pooling, 可以通过ElmoEmbedding.set_mix_weights_requires_grad()方法只将mix weights设置为可学习。)
- :param requires_grad: bool, 该层是否需要gradient, 默认为False.
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param cache_word_reprs: 可以选择对word的表示进行cache; 设置为True的话,将在初始化的时候为每个word生成对应的embedding,
- 并删除character encoder,之后将直接使用cache的embedding。默认为False。
- """
- super(ElmoEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- # 根据model_dir_or_name检查是否存在并下载
- if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
- model_url = _get_embedding_url('elmo', model_dir_or_name.lower())
- model_dir = cached_path(model_url, name='embedding')
- # 检查是否存在
- elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
- model_dir = model_dir_or_name
- else:
- raise ValueError(f"Cannot recognize {model_dir_or_name}.")
- self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs)
- num_layers = self.model.encoder.num_layers
-
- if layers == 'mix':
- self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1),
- requires_grad=requires_grad)
- self.gamma = nn.Parameter(torch.ones(1), requires_grad=requires_grad)
- self._get_outputs = self._get_mixed_outputs
- self._embed_size = self.model.config['lstm']['projection_dim'] * 2
- else:
- layers = list(map(int, layers.split(',')))
- assert len(layers) > 0, "Must choose at least one output, but got None."
- for layer in layers:
- assert 0 <= layer <= num_layers, f"Layer index should be in range [0, {num_layers}], but got {layer}."
- self.layers = layers
- self._get_outputs = self._get_layer_outputs
- self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2
-
- self.requires_grad = requires_grad
-
- def _get_mixed_outputs(self, outputs):
- # outputs: num_layers x batch_size x max_len x hidden_size
- # return: batch_size x max_len x hidden_size
- weights = F.softmax(self.layer_weights + 1 / len(outputs), dim=0).to(outputs)
- outputs = torch.einsum('l,lbij->bij', weights, outputs)
- return self.gamma.to(outputs) * outputs
-
- def set_mix_weights_requires_grad(self, flag=True):
- r"""
- 当初始化ElmoEmbedding时layers被设置为mix时,可以通过调用该方法设置mix weights是否可训练。如果layers不是mix,调用
- 该方法没有用。
-
- :param bool flag: 混合不同层表示的结果是否可以训练。
- :return:
- """
- if hasattr(self, 'layer_weights'):
- self.layer_weights.requires_grad = flag
- self.gamma.requires_grad = flag
-
- def _get_layer_outputs(self, outputs):
- if len(self.layers) == 1:
- outputs = outputs[self.layers[0]]
- else:
- outputs = torch.cat(tuple([*outputs[self.layers]]), dim=-1)
-
- return outputs
-
- def forward(self, words: torch.LongTensor):
- r"""
- 计算words的elmo embedding表示。根据elmo文章中介绍的ELMO实际上是有2L+1层结果,但是为了让结果比较容易拆分,token的
- 被重复了一次,使得实际上layer=0的结果是[token_embedding;token_embedding], 而layer=1的结果是[forward_hiddens;
- backward_hiddens].
-
- :param words: batch_size x max_len
- :return: torch.FloatTensor. batch_size x max_len x (512*len(self.layers))
- """
- words = self.drop_word(words)
- outputs = self._get_sent_reprs(words)
- if outputs is not None:
- return self.dropout(outputs)
- outputs = self.model(words)
- outputs = self._get_outputs(outputs)
- return self.dropout(outputs)
-
- def _delete_model_weights(self):
- for name in ['layers', 'model', 'layer_weights', 'gamma']:
- if hasattr(self, name):
- delattr(self, name)
-
-
-class _ElmoModel(nn.Module):
- r"""
- 该Module是ElmoEmbedding中进行所有的heavy lifting的地方。做的工作,包括
- (1) 根据配置,加载模型;
- (2) 根据vocab,对模型中的embedding进行调整. 并将其正确初始化
- (3) 保存一个words与chars的对应转换,获取时自动进行相应的转换
- (4) 设计一个保存token的embedding,允许缓存word的表示。
-
- """
-
- def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False):
- super(_ElmoModel, self).__init__()
- self.model_dir = model_dir
- dir = os.walk(self.model_dir)
- config_file = None
- weight_file = None
- config_count = 0
- weight_count = 0
- for path, dir_list, file_list in dir:
- for file_name in file_list:
- if file_name.__contains__(".json"):
- config_file = file_name
- config_count += 1
- elif file_name.__contains__(".pkl"):
- weight_file = file_name
- weight_count += 1
- if config_count > 1 or weight_count > 1:
- raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.")
- elif config_count == 0 or weight_count == 0:
- raise Exception(f"No config file or weight file found in {model_dir}")
- with open(os.path.join(model_dir, config_file), 'r') as config_f:
- config = json.load(config_f)
- self.weight_file = os.path.join(model_dir, weight_file)
- self.config = config
-
- OOV_TAG = ''
- PAD_TAG = ''
- BOS_TAG = ''
- EOS_TAG = ''
- BOW_TAG = ''
- EOW_TAG = ''
-
- # For the model trained with character-based word encoder.
- char_lexicon = {}
- with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi:
- for line in fpi:
- tokens = line.strip().split('\t')
- if len(tokens) == 1:
- tokens.insert(0, '\u3000')
- token, i = tokens
- char_lexicon[token] = int(i)
-
- # 做一些sanity check
- for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]:
- assert special_word in char_lexicon, f"{special_word} not found in char.dic."
-
- # 从vocab中构建char_vocab
- char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG)
- # 需要保证与在里面
- char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG])
-
- for word, index in vocab:
- char_vocab.add_word_lst(list(word))
-
- self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab) + 1, vocab.padding_idx
- # 根据char_lexicon调整, 多设置一位,是预留给word padding的(该位置的char表示为全0表示)
- char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']),
- padding_idx=len(char_vocab))
-
- # 读入预训练权重 这里的elmo_model 包含char_cnn和 lstm 的 state_dict
- elmo_model = torch.load(os.path.join(self.model_dir, weight_file), map_location='cpu')
-
- char_embed_weights = elmo_model["char_cnn"]['char_emb_layer.weight']
-
- found_char_count = 0
- for char, index in char_vocab: # 调整character embedding
- if char in char_lexicon:
- index_in_pre = char_lexicon.get(char)
- found_char_count += 1
- else:
- index_in_pre = char_lexicon[OOV_TAG]
- char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]
-
- logger.info(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
- # 生成words到chars的映射
- max_chars = config['char_cnn']['max_characters_per_token']
- self.register_buffer('words_to_chars_embedding', torch.full((len(vocab) + 2, max_chars),
- fill_value=len(char_vocab),
- dtype=torch.long))
- for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab) + 1)]:
- if len(word) + 2 > max_chars:
- word = word[:max_chars - 2]
- if index == self._pad_index:
- continue
- elif word == BOS_TAG or word == EOS_TAG:
- char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [
- char_vocab.to_index(EOW_TAG)]
- char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
- else:
- char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [
- char_vocab.to_index(EOW_TAG)]
- char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
- self.words_to_chars_embedding[index] = torch.LongTensor(char_ids)
-
- self.char_vocab = char_vocab
-
- self.token_embedder = ConvTokenEmbedder(
- config, self.weight_file, None, char_emb_layer)
- elmo_model["char_cnn"]['char_emb_layer.weight'] = char_emb_layer.weight
- self.token_embedder.load_state_dict(elmo_model["char_cnn"])
-
- self.output_dim = config['lstm']['projection_dim']
-
- # lstm encoder
- self.encoder = ElmobiLm(config)
- self.encoder.load_state_dict(elmo_model["lstm"])
-
- if cache_word_reprs:
- if config['char_cnn']['embedding']['dim'] > 0: # 只有在使用了chars的情况下有用
- logger.info("Start to generate cache word representations.")
- batch_size = 320
- # bos eos
- word_size = self.words_to_chars_embedding.size(0)
- num_batches = word_size // batch_size + \
- int(word_size % batch_size != 0)
-
- self.cached_word_embedding = nn.Embedding(word_size,
- config['lstm']['projection_dim'])
- with torch.no_grad():
- for i in range(num_batches):
- words = torch.arange(i * batch_size,
- min((i + 1) * batch_size, word_size)).long()
- chars = self.words_to_chars_embedding[words].unsqueeze(1) # batch_size x 1 x max_chars
- word_reprs = self.token_embedder(words.unsqueeze(1),
- chars).detach() # batch_size x 1 x config['encoder']['projection_dim']
- self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1)
-
- logger.info("Finish generating cached word representations. Going to delete the character encoder.")
- del self.token_embedder, self.words_to_chars_embedding
- else:
- logger.info("There is no need to cache word representations, since no character information is used.")
-
- def forward(self, words):
- r"""
-
- :param words: batch_size x max_len
- :return: num_layers x batch_size x max_len x hidden_size
- """
- # 扩展,
- batch_size, max_len = words.size()
- expanded_words = words.new_zeros(batch_size, max_len + 2) # 因为pad一定为0,
- seq_len = words.ne(self._pad_index).sum(dim=-1)
- expanded_words[:, 1:-1] = words
- expanded_words[:, 0].fill_(self.bos_index)
- expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index
- seq_len = seq_len + 2
- zero_tensor = expanded_words.new_zeros(expanded_words.shape)
- mask = (expanded_words == zero_tensor).unsqueeze(-1)
- if hasattr(self, 'cached_word_embedding'):
- token_embedding = self.cached_word_embedding(expanded_words)
- else:
- if hasattr(self, 'words_to_chars_embedding'):
- chars = self.words_to_chars_embedding[expanded_words]
- else:
- chars = None
- token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim
-
- encoder_output = self.encoder(token_embedding, seq_len)
- if encoder_output.size(2) < max_len + 2:
- num_layers, _, output_len, hidden_size = encoder_output.size()
- dummy_tensor = encoder_output.new_zeros(num_layers, batch_size,
- max_len + 2 - output_len, hidden_size)
- encoder_output = torch.cat((encoder_output, dummy_tensor), 2)
- sz = encoder_output.size() # 2, batch_size, max_len, hidden_size
- token_embedding = token_embedding.masked_fill(mask, 0)
- token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3])
- encoder_output = torch.cat((token_embedding, encoder_output), dim=0)
-
- # 删除, . 这里没有精确地删除,但应该也不会影响最后的结果了。
- encoder_output = encoder_output[:, :, 1:-1]
- return encoder_output
diff --git a/fastNLP/embeddings/embedding.py b/fastNLP/embeddings/embedding.py
deleted file mode 100644
index 9b6a1a7f..00000000
--- a/fastNLP/embeddings/embedding.py
+++ /dev/null
@@ -1,212 +0,0 @@
-r"""
-该模块中的Embedding主要用于随机初始化的embedding(更推荐使用 :class:`fastNLP.embeddings.StaticEmbedding` ),或按照预训练权重初始化Embedding。
-
-"""
-
-__all__ = [
- "Embedding",
- "TokenEmbedding"
-]
-
-from abc import abstractmethod
-
-import torch
-import torch.nn as nn
-
-from .utils import get_embeddings
-
-
-class Embedding(nn.Module):
- r"""
- 词向量嵌入,支持输入多种方式初始化. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度.
-
- Example::
-
- >>> import numpy as np
- >>> from fastNLP.embeddings import Embedding
- >>> init_embed = (2000, 100)
- >>> embed = Embedding(init_embed) # 随机初始化一个具有2000个词,每个词表示为100维的词向量
- >>> init_embed = np.zeros((2000, 100))
- >>> embed = Embedding(init_embed) # 使用numpy.ndarray的值作为初始化值初始化一个Embedding
-
- """
-
- def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None):
- r"""
-
- :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: 支持传入Embedding的大小(传入tuple(int, int),
- 第一个int为vocab_zie, 第二个int为embed_dim); 或传入Tensor, Embedding, numpy.ndarray等则直接使用该值初始化Embedding;
- :param float word_dropout: 按照一定概率随机将word设置为unk_index,这样可以使得unk这个token得到足够的训练, 且会对网络有
- 一定的regularize的作用。设置该值时,必须同时设置unk_index
- :param float dropout: 对Embedding的输出的dropout。
- :param int unk_index: drop word时替换为的index。fastNLP的Vocabulary的unk_index默认为1。
- """
- super(Embedding, self).__init__()
-
- self.embed = get_embeddings(init_embed)
-
- self.dropout = nn.Dropout(dropout)
- if not isinstance(self.embed, TokenEmbedding):
- if hasattr(self.embed, 'embed_size'):
- self._embed_size = self.embed.embed_size
- elif hasattr(self.embed, 'embedding_dim'):
- self._embed_size = self.embed.embedding_dim
- else:
- self._embed_size = self.embed.weight.size(1)
- if word_dropout > 0 and not isinstance(unk_index, int):
- raise ValueError("When drop word is set, you need to pass in the unk_index.")
- else:
- self._embed_size = self.embed.embed_size
- unk_index = self.embed.get_word_vocab().unknown_idx
- self.unk_index = unk_index
- self.word_dropout = word_dropout
-
- def forward(self, words):
- r"""
- :param torch.LongTensor words: [batch, seq_len]
- :return: torch.Tensor : [batch, seq_len, embed_dim]
- """
- if self.word_dropout > 0 and self.training:
- mask = torch.ones_like(words).float() * self.word_dropout
- mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
- words = words.masked_fill(mask, self.unk_index)
- words = self.embed(words)
- return self.dropout(words)
-
- @property
- def num_embedding(self) -> int:
- if isinstance(self.embed, nn.Embedding):
- return self.embed.weight.size(0)
- else:
- return self.embed.num_embeddings
-
- def __len__(self):
- return len(self.embed)
-
- @property
- def embed_size(self) -> int:
- return self._embed_size
-
- @property
- def embedding_dim(self) -> int:
- return self._embed_size
-
- @property
- def requires_grad(self):
- r"""
- Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
- :return:
- """
- if not isinstance(self.embed, TokenEmbedding):
- return self.embed.weight.requires_grad
- else:
- return self.embed.requires_grad
-
- @requires_grad.setter
- def requires_grad(self, value):
- if not isinstance(self.embed, TokenEmbedding):
- self.embed.weight.requires_grad = value
- else:
- self.embed.requires_grad = value
-
- @property
- def size(self):
- if isinstance(self.embed, TokenEmbedding):
- return self.embed.size
- else:
- return self.embed.weight.size()
-
-
-class TokenEmbedding(nn.Module):
- r"""
- fastNLP中各种Embedding的基类
-
- """
- def __init__(self, vocab, word_dropout=0.0, dropout=0.0):
- super(TokenEmbedding, self).__init__()
- if vocab.rebuild:
- vocab.build_vocab()
- assert vocab.padding is not None, "Vocabulary must have a padding entry."
- self._word_vocab = vocab
- self._word_pad_index = vocab.padding_idx
- if word_dropout > 0:
- assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word."
- self.word_dropout = word_dropout
- self._word_unk_index = vocab.unknown_idx
- self.dropout_layer = nn.Dropout(dropout)
-
- def drop_word(self, words):
- r"""
- 按照设定随机将words设置为unknown_index。
-
- :param torch.LongTensor words: batch_size x max_len
- :return:
- """
- if self.word_dropout > 0 and self.training:
- mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
- mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
- pad_mask = words.ne(self._word_pad_index)
- mask = mask.__and__(pad_mask)
- words = words.masked_fill(mask, self._word_unk_index)
- return words
-
- def dropout(self, words):
- r"""
- 对embedding后的word表示进行drop。
-
- :param torch.FloatTensor words: batch_size x max_len x embed_size
- :return:
- """
- return self.dropout_layer(words)
-
- @property
- def requires_grad(self):
- r"""
- Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
- :return:
- """
- requires_grads = set([param.requires_grad for param in self.parameters()])
- if len(requires_grads) == 1:
- return requires_grads.pop()
- else:
- return None
-
- @requires_grad.setter
- def requires_grad(self, value):
- for param in self.parameters():
- param.requires_grad = value
-
- def __len__(self):
- return len(self._word_vocab)
-
- @property
- def embed_size(self) -> int:
- return self._embed_size
-
- @property
- def embedding_dim(self) -> int:
- return self._embed_size
-
- @property
- def num_embeddings(self) -> int:
- r"""
- 这个值可能会大于实际的embedding矩阵的大小。
- :return:
- """
- return len(self._word_vocab)
-
- def get_word_vocab(self):
- r"""
- 返回embedding的词典。
-
- :return: Vocabulary
- """
- return self._word_vocab
-
- @property
- def size(self):
- return torch.Size(self.num_embeddings, self._embed_size)
-
- @abstractmethod
- def forward(self, words):
- raise NotImplementedError
diff --git a/fastNLP/embeddings/gpt2_embedding.py b/fastNLP/embeddings/gpt2_embedding.py
deleted file mode 100644
index a9ce3202..00000000
--- a/fastNLP/embeddings/gpt2_embedding.py
+++ /dev/null
@@ -1,656 +0,0 @@
-"""
-.. todo::
- doc
-"""
-
-__all__ = [
- "GPT2Embedding",
- "GPT2WordPieceEncoder"
-]
-
-import warnings
-from functools import partial
-from itertools import chain
-from collections import OrderedDict
-
-import torch
-from torch import nn
-import numpy as np
-
-from .contextual_embedding import ContextualEmbedding
-from ..core import logger
-from ..core.utils import _get_model_device
-from ..core.vocabulary import Vocabulary
-from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR
-from ..modules.tokenizer import GPT2Tokenizer
-from ..modules.encoder.gpt2 import GPT2LMHeadModel, GPT2Model
-
-
-class GPT2Embedding(ContextualEmbedding):
- """
- 使用GPT2对words进行编码的Embedding。
-
- GPT2Embedding可以支持自动下载权重,当前支持的模型:
- en: gpt2
- en-medium: gpt2-medium
-
- Example::
-
- >>> import torch
- >>> from fastNLP import Vocabulary
- >>> from fastNLP.embeddings import BertEmbedding
- >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
- >>> embed = GPT2Embedding(vocab, model_dir_or_name='en-small', requires_grad=False, layers='4,-2,-1')
- >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
- >>> outputs = embed(words)
- >>> outputs.size()
- >>> # torch.Size([1, 5, 3096])
- """
-
- def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '-1',
- pool_method: str = 'first', dropout=0, requires_grad: bool = True,
- auto_truncate: bool = False, language_model: bool = False, **kwargs):
- """
-
- :param ~fastNLP.Vocabulary vocab: 词表
- :param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件(以.txt作为后缀名),
- 权重文件(以.bin作为文件后缀名), 配置文件(以.json作为后缀名)。
- :param str layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,层的序号是
- 从0开始,可以以负数去索引倒数几层。
- :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces
- 中计算得到它对应的表示。支持 ``last`` , ``first`` , ``avg`` , ``max``。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param bool requires_grad: 是否需要gradient以更新Bert的权重。
- :param bool auto_truncate: 当句子words拆分为word pieces长度超过bert最大允许长度(一般为512), 自动截掉拆分后的超过510个
- word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS]
- 来进行分类的任务将auto_truncate置为True。
- :param bool language_model: 是否计算gpt2的lm loss,可以通过get_loss()获取,输入一个batch之后的get_loss调用即为batch的language
- model的loss
- :param **kwargs:
- bool only_use_pretrain_bpe: 仅使用出现在pretrain词表中的bpe,如果该词没法tokenize则使用unk。如果embedding不需要更新
- 建议设置为True。
- int min_freq: 仅在only_use_pretrain_bpe为False有效,大于等于该次数的词会被新加入GPT2的BPE词表中
- bool truncate_embed: 是否仅保留用到的bpe(这样会减内存占用和加快速度)
- """
- super().__init__(vocab, word_dropout=0, dropout=dropout)
-
- if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
- if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'):
- logger.warning("For Chinese GPT, pooled_method should choose from 'first', 'last' in order to achieve"
- " faster speed.")
- warnings.warn("For Chinese GPT, pooled_method should choose from 'first', 'last' in order to achieve"
- " faster speed.")
-
- only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
- truncate_embed = kwargs.get('truncate_embed', True)
- min_freq = kwargs.get('min_freq', 1)
-
- self.lm_loss =language_model
- self.model = _GPT2Model(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
- pool_method=pool_method, auto_truncate=auto_truncate, language_model=language_model,
- only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed,
- min_freq=min_freq)
-
- self.requires_grad = requires_grad
- self._embed_size = len(self.model.layers) * self.model.encoder.config.n_embd
-
- def _delete_model_weights(self):
- del self.model
-
- def forward(self, words):
- """
- 计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要
- 删除这两个token的表示。
-
- :param torch.LongTensor words: [batch_size, max_len]
- :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
- """
- outputs = self._get_sent_reprs(words)
- if outputs is not None:
- return self.dropout(outputs)
- outputs = self.model(words)
- outputs = torch.cat([*outputs], dim=-1)
-
- return self.dropout(outputs)
-
- def drop_word(self, words):
- """
- :param torch.LongTensor words: batch_size x max_len
- :return:
- """
- if self.word_dropout > 0 and self.training:
- with torch.no_grad():
- mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
- mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
- words = words.masked_fill(mask, self._word_unk_index)
- return words
-
- def get_lm_loss(self, release=True):
- """
- 当language_model=True时,可以通过该接口获取当前batch的language model loss的大小
-
- :param bool release: 如果为True,获取了lm_loss后在下一次forward完成之前都无法获取lm_loss了
- :return: torch.FloatTensor([])
- """
- if hasattr(self.model, '_lm_loss_value'):
- lm_loss_value = self.model._lm_loss_value
- if release:
- delattr(self.model, '_lm_loss_value')
- return lm_loss_value
- elif self.lm_loss:
- raise RuntimeError("Make sure you have passed a batch into GPT2Embdding before accessing loss.")
- else:
- raise RuntimeError("Initialize your GPT2Embedding with language_model=True.")
-
-
-class GPT2WordPieceEncoder(nn.Module):
- """
- GPT2模型,使用时先使用本模型对应的Tokenizer对数据进行tokenize
- GPT2WordPieceEncoder可以支持自动下载权重,当前支持的模型:
- en: gpt2
- en-medium: gpt2-medium
-
- """
-
- def __init__(self, model_dir_or_name: str = 'en', layers: str = '-1',
- word_dropout=0, dropout=0, requires_grad: bool = True, language_model:bool=False):
- """
-
- :param str model_dir_or_name: 模型所在目录或者模型的名称。
- :param str,list layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层
- :param float word_dropout: 多大概率将word piece置为<|endoftext|>
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param bool language_model: 是否使用language model
- :param bool requires_grad: 是否需要gradient。
- """
- super().__init__()
-
- self.model = _GPT2WordPieceModel(model_dir_or_name=model_dir_or_name, layers=layers, language_model=language_model)
- self._wordpiece_pad_index = self.model._wordpiece_pad_index
- self._embed_size = len(self.model.layers) * self.model.encoder.config.n_embd
- self.requires_grad = requires_grad
- self.dropout_layer = nn.Dropout(dropout)
- self._wordpiece_endoftext_index = self.model._endoftext_index
- self.word_dropout = word_dropout
- self.language_model = language_model
-
- @property
- def embed_size(self):
- return self._embed_size
-
- @property
- def embedding_dim(self):
- return self._embed_size
-
- @property
- def num_embedding(self):
- return self.model.encoder.config.vocab_size
-
- def index_datasets(self, *datasets, field_name, add_endoftext=False, add_prefix_space=True):
- """
- 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了
- bert的pad value。
-
- :param ~fastNLP.DataSet datasets: DataSet对象
- :param list[str] field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。
- :param bool add_endoftext: 在句子开头加入<|endofline|>。
- :param bool add_prefix_space: 是否在句首增加空格
- :return:
- """
- self.model.index_datasets(*datasets, field_name=field_name, add_endoftext=add_endoftext,
- add_prefix_space=add_prefix_space)
-
- def forward(self, word_pieces, token_type_ids=None):
- """
- 计算words的bert embedding表示。传入的words中应该在开头包含<|endofline|>。
-
- :param word_pieces: batch_size x max_len
- :param token_type_ids: batch_size x max_len,
- :return: torch.FloatTensor.
- """
-
- outputs = self.model(word_pieces)
- outputs = torch.cat([*outputs], dim=-1)
-
- return self.dropout_layer(outputs)
-
- def drop_word(self, words):
- """
-
- :param torch.LongTensor words: batch_size x max_len
- :return:
- """
- if self.word_dropout > 0 and self.training:
- with torch.no_grad():
- mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
- mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
- endoftext_mask = words.ne(self._wordpiece_endoftext_index)
- mask = endoftext_mask.__and__(mask) # pad的位置不为unk
- words = words.masked_fill(mask, self._wordpiece_unk_index)
- return words
-
- def generate_from_str(self, text='', max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
- repetition_penalty=1.0, length_penalty=1.0):
- """
-
- :param str text: 故事的开头
- :param int max_len: 生成多长的句子
- :param bool do_sample: 是否使用采样的方式生成,如果使用采样,相同的参数可能出现不同的句子。
- :param int num_beams: 使用多大的beam size
- :param float temperature: 用以调节采样分布的
- :param int top_k: 只保留此表中top_k个词进行生成。范围1-infinity
- :param float top_p: 保留概率累积为top_p的词汇,范围0-1.
- :param float repetition_penalty: 对重复token的惩罚
- :param float length_penalty: 惩罚过长的句子
- :return: list[str]
- """
- if len(text)==0:
- word_pieces = torch.LongTensor([[self.model.tokenizer.bos_index]])
- start_idx = 1
- else:
- assert isinstance(text, str), "Only string input allowed."
- assert self.language_model, "You must set `language_model=True`."
- word_pieces = self.model.convert_words_to_word_pieces(text, add_prefix_space=True)
- word_pieces = torch.LongTensor([word_pieces])
- start_idx = 0
- device = _get_model_device(self)
- word_pieces = word_pieces.to(device)
- outputs = self.model.encoder.generate(input_ids=word_pieces,
- max_length=max_len,
- do_sample=do_sample,
- num_beams=num_beams,
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- bos_token_id=self.model.tokenizer.bos_index,
- pad_token_id=self.model.tokenizer.eos_index, # 使用<|endoftext|>代替pad
- eos_token_ids=self.model.tokenizer.eos_index,
- length_penalty=length_penalty).squeeze(0)
-
- output_strs = []
- if outputs.dim()==1:
- outputs = outputs[None]
- outputs = outputs[:, start_idx:]
- for i in range(len(outputs)):
- str_ = self.model.tokenizer.convert_tokens_to_string(self.model.tokenizer.convert_ids_to_tokens(outputs[i].tolist()))
- output_strs.append(str_)
-
- return output_strs
-
- def generate(self, word_pieces=None, max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
- repetition_penalty=1.0, length_penalty=1.0):
- """
-
- :param torch.LongTensor,None word_pieces: 如果传入tensor,shape应该为batch_size x start_len; 如果传入None,会随机生成。
- :param int max_len: 生成多长的句子
- :param bool do_sample: 是否使用采样的方式生成,如果使用采样,相同的参数可能出现不同的句子。
- :param int num_beams: 使用多大的beam size
- :param float temperature: 用以调节采样分布的
- :param int top_k: 只保留此表中top_k个词进行生成。范围1-infinity
- :param float top_p: 保留概率累积为top_p的词汇,范围0-1.
- :param float repetition_penalty: 对重复token的惩罚
- :param float length_penalty: 惩罚过长的句子
- :return:
- """
- raise NotImplemented
-
- def get_lm_loss(self, release=True):
- """
- 当language_model=True时,可以通过该接口获取当前batch的language model loss的大小
-
- :param bool release: 如果为True,获取了lm_loss后在下一次forward完成之前都无法获取lm_loss了
- :return: torch.FloatTensor([])
- """
- if hasattr(self.model, '_lm_loss_value'):
- lm_loss_value = self.model._lm_loss_value
- if release:
- delattr(self.model, '_lm_loss_value')
- return lm_loss_value
- elif self.lm_loss:
- raise RuntimeError("Make sure you have passed a batch into GPT2Embdding before accessing loss.")
- else:
- raise RuntimeError("Initialize your GPT2Embedding with language_model=True.")
-
-
-class _GPT2Model(nn.Module):
- def __init__(self, model_dir_or_name, vocab, layers, pool_method='first', auto_truncate=True, language_model=False,
- only_use_pretrain_bpe=False, min_freq=1, truncate_embed=False):
- super().__init__()
-
- self.tokenzier = GPT2Tokenizer.from_pretrained(model_dir_or_name)
- if language_model:
- self.encoder = GPT2LMHeadModel.from_pretrained(model_dir_or_name)
- else:
- self.encoder = GPT2Model.from_pretrained(model_dir_or_name)
-
- self.lm_loss = language_model
- self._max_position_embeddings = self.encoder.config.max_position_embeddings
- # 检查encoder_layer_number是否合理
- encoder_layer_number = self.encoder.config.n_layer
- if isinstance(layers, list):
- self.layers = [int(l) for l in layers]
- elif isinstance(layers, str):
- self.layers = list(map(int, layers.split(',')))
- else:
- raise TypeError("`layers` only supports str or list[int]")
- for layer in self.layers:
- if layer < 0:
- assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a GPT2 model with {encoder_layer_number} layers."
- else:
- assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a GPT2 model with {encoder_layer_number} layers."
-
- assert pool_method in ('avg', 'max', 'first', 'last')
- self.pool_method = pool_method
- self.auto_truncate = auto_truncate
-
- # 将所有vocab中word的wordpiece计算出来, 需要额外考虑和
- logger.info("Start to generate word pieces for word.")
- # 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值
- word_piece_dict = {'<|endoftext|>': 1} # 用到的word_piece以及新增的
- found_count = 0
- new_add_to_bpe_vocab = 0
- unsegment_count = 0
-
- for word, index in vocab:
- if index == vocab.padding_idx: # pad是个特殊的符号
- word = '<|endoftext|>'
- elif index == vocab.unknown_idx:
- word = '<|endoftext|>'
- # _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split() # 这里暂时不考虑中文内容
- word_pieces = []
- word_pieces.extend(self.tokenzier.tokenize(word, add_prefix_space=True))
- if len(word_pieces) == 1:
- if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到
- if index not in (vocab.unknown_idx, vocab.padding_idx) and word_pieces[0] == '<|endoftext|>': # 说明这个词不在原始的word里面
- if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
- word) and not only_use_pretrain_bpe: # 出现次数大于这个次数才新增
- word_piece_dict[word] = 1 # 新增一个值
- new_add_to_bpe_vocab += 1
- unsegment_count += 1
- continue
- for word_piece in word_pieces:
- word_piece_dict[word_piece] = 1
- found_count += 1
-
- if unsegment_count>0:
- if only_use_pretrain_bpe or new_add_to_bpe_vocab==0:
- logger.info(f"{unsegment_count} words are unsegmented.")
- else:
- logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.")
-
- original_embed = self.encoder.get_input_embeddings().weight
- # 特殊词汇要特殊处理
- if not truncate_embed: # 如果不删除的话需要将已有的加上
- word_piece_dict.update(self.tokenzier.encoder)
-
- embed = nn.Embedding(len(word_piece_dict), original_embed.size(1)) # 新的embed
- new_word_piece_vocab = OrderedDict()
-
- for index, token in enumerate(['<|endoftext|>']):
- index = word_piece_dict.pop(token, None)
- if index is not None:
- new_word_piece_vocab[token] = len(new_word_piece_vocab)
- embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.encoder[token]]
-
- for token in word_piece_dict.keys():
- if token not in new_word_piece_vocab:
- new_word_piece_vocab[token] = len(new_word_piece_vocab)
- index = new_word_piece_vocab[token]
- if token in self.tokenzier.encoder:
- embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]]
- else:
- embed.weight.data[index] = original_embed[self.tokenzier.encoder['<|endoftext|>']]
-
- self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
- self.encoder.set_input_embeddings(embed)
- self.encoder.tie_weights()
- self.encoder.config.vocab_size = len(new_word_piece_vocab)
-
- word_to_wordpieces = []
- word_pieces_lengths = []
- for word, index in vocab:
- if index == vocab.padding_idx: # pad是个特殊的符号
- word = '<|endoftext|>'
- elif index == vocab.unknown_idx:
- word = '<|endoftext|>'
- word_pieces = self.tokenzier.tokenize(word)
- word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
- word_to_wordpieces.append(word_pieces)
- word_pieces_lengths.append(len(word_pieces))
- self._word_pad_index = vocab.padding_idx
- self._endoftext_index = self.tokenzier.encoder.get('<|endoftext|>')
- self._wordpiece_pad_index = self.tokenzier.encoder.get('<|endoftext|>') # 需要用于生成word_piece
- self.word_to_wordpieces = np.array(word_to_wordpieces, dtype=object)
- self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
- logger.debug("Successfully generate word pieces.")
-
- def forward(self, words):
- """
-
- :param words: torch.LongTensor, batch_size x max_len
- :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size
- """
- with torch.no_grad():
- batch_size, max_word_len = words.size()
- word_mask = words.ne(self._word_pad_index) # 为1的地方有word
- seq_len = word_mask.sum(dim=-1)
- batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
- 0) # batch_size x max_len
- word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size
- max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # 表示word piece的长度(包括padding)
- if max_word_piece_length > self._max_position_embeddings:
- if self.auto_truncate:
- word_pieces_lengths = word_pieces_lengths.masked_fill(
- word_pieces_lengths > self._max_position_embeddings,
- self._max_position_embeddings)
- else:
- raise RuntimeError(
- "After split words into word pieces, the lengths of word pieces are longer than the "
- f"maximum allowed sequence length:{self._max_position_embeddings} of GPT2. You can set "
- f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
-
- word_pieces = words.new_full((batch_size, min(max_word_piece_length, self._max_position_embeddings)),
- fill_value=self._wordpiece_pad_index)
- word_labels = word_pieces.clone()
- attn_masks = torch.zeros_like(word_pieces)
- # 1. 获取words的word_pieces的id,以及对应的span范围
- word_indexes = words.cpu().numpy()
- for i in range(batch_size):
- word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]]))
- if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings:
- word_pieces_i = word_pieces_i[:self._max_position_embeddings]
- word_pieces[i, :word_pieces_lengths[i]] = torch.LongTensor(word_pieces_i)
- word_labels[i, word_pieces_lengths[i]:].fill_(-100) # 计算lm_loss用的
- attn_masks[i, :word_pieces_lengths[i]].fill_(1)
- # 添加<|endoftext|>, 默认不添加了
- # word_pieces[:, 0].fill_(self._endoftext_index)
- batch_indexes = torch.arange(batch_size).to(words)
- # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算
- # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
- if self.lm_loss:
- gpt2_outputs = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, labels=word_labels,
- output_attentions=False)
- gpt2_outputs, self._lm_loss_value = gpt2_outputs[-1], gpt2_outputs[0] # n_layers x batch_size x max_len x hidden_size
- else:
- gpt2_outputs = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
- output_attentions=False)[-1]
- outputs = gpt2_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
- gpt2_outputs[-1].size(-1))
-
- batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len+1)
- batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len
-
- if self.pool_method == 'first':
- batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
- _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
- elif self.pool_method == 'last':
- batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] - 1
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
- _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
-
- for l_index, l in enumerate(self.layers):
- output_layer = gpt2_outputs[l]
- real_word_piece_length = output_layer.size(1)
- if max_word_piece_length > real_word_piece_length: # 如果实际上是截取出来的
- paddings = output_layer.new_zeros(batch_size,
- max_word_piece_length - real_word_piece_length,
- output_layer.size(2))
- output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
- # 从word_piece collapse到word的表示
- # truncate_output_layer = output_layer # 删除endoftext batch_size x len x hidden_size
- if self.pool_method == 'first':
- tmp = output_layer[_batch_indexes, batch_word_pieces_cum_length]
- tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
- outputs[l_index, :, :batch_word_pieces_cum_length.size(1)] = tmp
- elif self.pool_method == 'last':
- tmp = output_layer[_batch_indexes, batch_word_pieces_cum_length]
- tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
- outputs[l_index, :, :batch_word_pieces_cum_length.size(1)] = tmp
- elif self.pool_method == 'max':
- for i in range(batch_size):
- for j in range(seq_len[i]):
- start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
- outputs[l_index, i, j], _ = torch.max(output_layer[i, start:end], dim=-2)
- else:
- for i in range(batch_size):
- for j in range(seq_len[i]):
- start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
- outputs[l_index, i, j] = torch.mean(output_layer[i, start:end], dim=-2)
-
- # 3. 最终的embedding结果
- return outputs
-
- def get_lm_loss(self):
- """
- 当language_model为True时,通过该接口可以获取最近传入的一个batch的lanuage model loss
-
- :return:
- """
- return self._lm_loss_value
-
-
-class _GPT2WordPieceModel(nn.Module):
- """
- 这个模块用于直接计算word_piece的结果.
-
- """
-
- def __init__(self, model_dir_or_name: str, layers: str = '-1', language_model: bool=False):
- super().__init__()
-
- self.tokenizer = GPT2Tokenizer.from_pretrained(model_dir_or_name)
- if language_model:
- self.encoder = GPT2LMHeadModel.from_pretrained(model_dir_or_name)
- else:
- self.encoder = GPT2Model.from_pretrained(model_dir_or_name)
-
- self.lm_loss = language_model
-
- # 检查encoder_layer_number是否合理
- encoder_layer_number = self.encoder.config.n_layer
-
- if isinstance(layers, list):
- self.layers = [int(l) for l in layers]
- elif isinstance(layers, str):
- self.layers = list(map(int, layers.split(',')))
- else:
- raise TypeError("`layers` only supports str or list[int]")
-
- for layer in self.layers:
- if layer < 0:
- assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a gpt2 model with {encoder_layer_number} layers."
- else:
- assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a gpt2 model with {encoder_layer_number} layers."
-
- self._endoftext_index = self.tokenizer.encoder.get('<|endoftext|>')
- self._wordpiece_pad_index = self.tokenizer.encoder.get('<|endoftext|>') # 原来并没有pad,使用这个值替代一下。这个pad值并不重要,因为是从左到右计算的
- self._max_position_embeddings = self.encoder.config.max_position_embeddings
-
- def index_datasets(self, *datasets, field_name, add_endoftext=False, add_prefix_space=True):
- """
- 使用gpt2的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果开头不是<|endoftext|>, 且将
- word_pieces这一列的pad value设置为了bert的pad value。
-
- :param datasets: DataSet对象
- :param field_name: 基于哪一列index
- :param bool add_prefix_space: 是否添加句首的空格
- :return:
- """
- convert_words_to_word_pieces = partial(self.convert_words_to_word_pieces, add_endoftext=add_endoftext,
- add_prefix_space=add_prefix_space)
- for index, dataset in enumerate(datasets):
- try:
- dataset.apply_field(convert_words_to_word_pieces, field_name=field_name, new_field_name='word_pieces',
- is_input=True)
- dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
- except Exception as e:
- logger.error(f"Exception happens when processing the {index} dataset.")
- raise e
-
- def convert_words_to_word_pieces(self, words, add_endoftext=False, add_prefix_space=True):
- """
-
- :param list[str],str words: 将str数据转换为index
- :param bool add_endoftext: 是否在句首增加endoftext
- :param bool add_prefix_space: 是否添加句首的空格
- :return:
- """
- word_pieces = []
- if isinstance(words, str):
- words = self.tokenizer.tokenize(words, add_prefix_space=add_prefix_space)
- word_piece_ids = self.tokenizer.convert_tokens_to_ids(words)
- word_pieces.extend(word_piece_ids)
- else:
- for word in words:
- tokens = self.tokenizer.tokenize(word, add_prefix_space=add_prefix_space)
- word_piece_ids = self.tokenizer.convert_tokens_to_ids(tokens)
- word_pieces.extend(word_piece_ids)
- if add_endoftext:
- if word_pieces[0] != self._endoftext_index:
- word_pieces.insert(0, self._endoftext_index)
- if len(word_pieces) > self._max_position_embeddings:
- word_pieces[self._max_position_embeddings - 1] = word_pieces[-1]
- word_pieces = word_pieces[:self._max_position_embeddings]
- return word_pieces
-
- def forward(self, word_pieces, token_type_ids=None):
- """
-
- :param word_pieces: torch.LongTensor, batch_size x max_len
- :param token_type_ids: torch.LongTensor, batch_size x max_len
- :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size
- """
- batch_size, max_len = word_pieces.size()
-
- attn_masks = word_pieces.ne(self._wordpiece_pad_index) # 可能会错误导致开头的词被mask掉
- word_pieces = word_pieces.masked_fill(attn_masks.eq(0), self._endoftext_index) # 替换pad的值
- if self.lm_loss:
- labels = word_pieces.clone()
- labels = labels.masked_fill(labels.eq(self._wordpiece_pad_index), -100)
- gpt_outputs = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
- output_attentions=False, labels=labels)
- gpt_outputs, self._lm_loss_value = gpt_outputs[-1], gpt_outputs[0] # n_layers x batch_size x max_len x hidden_size
- else:
- gpt_outputs = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
- output_attentions=False)
- gpt_outputs = gpt_outputs[-1]
- # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
- outputs = gpt_outputs[0].new_zeros((len(self.layers), batch_size, max_len, gpt_outputs[0].size(-1)))
- for l_index, l in enumerate(self.layers):
- outputs[l_index] = gpt_outputs[l] # 删除开头
- return outputs
-
- def get_lm_loss(self):
- """
- 当language_model为True时,通过该接口可以获取最近传入的一个batch的lanuage model loss
-
- :return:
- """
- return self._lm_loss_value
-
diff --git a/fastNLP/embeddings/roberta_embedding.py b/fastNLP/embeddings/roberta_embedding.py
deleted file mode 100644
index 4b7040c0..00000000
--- a/fastNLP/embeddings/roberta_embedding.py
+++ /dev/null
@@ -1,589 +0,0 @@
-r"""
-.. todo::
- doc
-"""
-
-__all__ = [
- "RobertaEmbedding",
- "RobertaWordPieceEncoder"
-]
-
-
-from functools import partial
-import os
-import json
-from itertools import chain
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-from .contextual_embedding import ContextualEmbedding
-from ..core import logger, Vocabulary
-from ..modules.encoder.roberta import RobertaModel
-from ..modules.tokenizer import RobertaTokenizer
-
-
-VOCAB_NAME = 'vocab.txt'
-ROBERTA_EMBED_HYPER = 'roberta_hyper.json'
-ROBERTA_ENCODER_HYPER = 'roberta_hyper.json'
-ROBERTA_EMBED_FOLDER = 'roberta'
-ROBERTA_ENCODER_FOLDER = 'roberta'
-
-
-class RobertaEmbedding(ContextualEmbedding):
- r"""
- 使用RoBERTa对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于
- 预训练的bert模型长度限制为512个token,而因为输入的word是未进行word piece分割的(word piece的分割有RobertaEmbedding在输入word
- 时切分),在分割之后长度可能会超过最大长度限制。
-
- RobertaEmbedding可以支持自动下载权重,当前支持的模型:
- en: roberta-base
- en-large: roberta-large
-
- Example::
-
- >>> import torch
- >>> from fastNLP import Vocabulary
- >>> from fastNLP.embeddings import RobertaEmbedding
- >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
- >>> embed = RobertaEmbedding(vocab, model_dir_or_name='en', requires_grad=False, layers='4,-2,-1')
- >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
- >>> outputs = embed(words)
- >>> outputs.size()
- >>> # torch.Size([1, 5, 2304])
- """
-
- def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '-1',
- pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
- pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False, **kwargs):
- r"""
-
- :param ~fastNLP.Vocabulary vocab: 词表
- :param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件
- (以vocab.json作为后缀名), 权重文件(以.bin作为文件后缀名), 配置文件(以config.json作为后缀名)。
- :param str,list layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,层的序号是
- 从0开始,可以以负数去索引倒数几层。layer=0为embedding层(包括wordpiece embedding, position embedding)
- :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces
- 中计算得到它对应的表示。支持 ``last`` , ``first`` , ``avg`` , ``max``。
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样
- 会使得word embedding的结果比输入的结果长两个token。如果该值为True,则在使用 :class::StackEmbedding 可能会与其它类型的
- embedding长度不匹配。
- :param bool pooled_cls: 返回的是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取做预测,
- 一般该值为True。
- :param bool requires_grad: 是否需要gradient以更新Bert的权重。
- :param bool auto_truncate: 当句子words拆分为word pieces长度超过bert最大允许长度(一般为512), 自动截掉拆分后的超过510个
- word pieces后的内容,并将第512个word piece置为。超过长度的部分的encode结果直接全部置零。一般仅有只使用
- 来进行分类的任务将auto_truncate置为True。
- :param kwargs:
- int min_freq: 小于该次数的词会被unk代替, 默认为1
- """
- super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- if word_dropout > 0:
- assert vocab.unknown is not None, "When word_drop > 0, Vocabulary must contain the unknown token."
-
- self._word_sep_index = -100
- if '' in vocab:
- self._word_sep_index = vocab['']
-
- self._word_cls_index = -100
- if '' in vocab:
- self._word_cls_index = vocab['']
-
- min_freq = kwargs.pop('min_freq', 1)
- self._min_freq = min_freq
-
- self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
- pool_method=pool_method, include_cls_sep=include_cls_sep,
- pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
- **kwargs)
- self.requires_grad = requires_grad
- self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
-
- def _delete_model_weights(self):
- del self.model
-
- def forward(self, words):
- r"""
- 计算words的roberta embedding表示。计算之前会在每句话的开始增加在结束增加, 并根据include_cls_sep判断要不要
- 删除这两个token的表示。
-
- :param torch.LongTensor words: [batch_size, max_len]
- :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
- """
- words = self.drop_word(words)
- outputs = self._get_sent_reprs(words)
- if outputs is not None:
- return self.dropout(outputs)
- outputs = self.model(words)
- outputs = torch.cat([*outputs], dim=-1)
-
- return self.dropout(outputs)
-
- def drop_word(self, words):
- r"""
- 按照设定随机将words设置为unknown_index。
-
- :param torch.LongTensor words: batch_size x max_len
- :return:
- """
- if self.word_dropout > 0 and self.training:
- with torch.no_grad():
- mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
- mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
- pad_mask = words.ne(self._word_pad_index)
- mask = pad_mask.__and__(mask) # pad的位置不为unk
- if self._word_sep_index!=-100:
- not_sep_mask = words.ne(self._word_sep_index)
- mask = mask.__and__(not_sep_mask)
- if self._word_cls_index!=-100:
- not_cls_mask = words.ne(self._word_cls_index)
- mask = mask.__and__(not_cls_mask)
- words = words.masked_fill(mask, self._word_unk_index)
- return words
-
- def save(self, folder):
- """
- 将roberta embedding保存到folder,保存之后包含三个文件vocab.txt, roberta_embed_hyper.txt, roberta_embed/,
-
- :param str folder: 保存地址
- :return:
- """
- os.makedirs(folder, exist_ok=True)
- self.get_word_vocab().save(os.path.join(folder, VOCAB_NAME))
-
- hyper = {}
- hyper['min_freq'] = self._min_freq
- hyper['layers'] = ','.join(map(str, self.model.layers))
- hyper['pool_method'] = self.model.pool_method
- hyper['dropout'] = self.dropout_layer.p
- hyper['word_dropout'] = self.word_dropout
- hyper['include_cls_sep'] = self.model.include_cls_sep
- hyper['pooled_cls'] = self.model.pooled_cls
- hyper['auto_truncate'] = self.model.auto_truncate
- hyper['requires_grad'] = bool(self.requires_grad)
-
- with open(os.path.join(folder, ROBERTA_EMBED_HYPER), 'w', encoding='utf-8') as f:
- json.dump(hyper, f, indent=2)
-
- os.makedirs(os.path.join(folder, ROBERTA_EMBED_FOLDER), exist_ok=True)
- self.model.save(os.path.join(folder, ROBERTA_EMBED_FOLDER))
-
- @classmethod
- def load(cls, folder):
- """
- 从folder中读取数据初始化RobertaEmbedding
-
- :param folder:
- :return:
- """
- for name in [VOCAB_NAME, ROBERTA_EMBED_HYPER, ROBERTA_EMBED_FOLDER]:
- assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."
-
- vocab = Vocabulary.load(os.path.join(folder, VOCAB_NAME))
- with open(os.path.join(folder, ROBERTA_EMBED_HYPER), 'r', encoding='utf-8') as f:
- hyper = json.load(f)
- model_name_or_path = os.path.join(folder, ROBERTA_EMBED_FOLDER)
-
- roberta = cls(vocab=vocab, model_dir_or_name=model_name_or_path, **hyper)
- return roberta
-
-
-class _RobertaWordModel(nn.Module):
- def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
- include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
- **kwargs):
- super().__init__()
-
- if isinstance(layers, list):
- self.layers = [int(l) for l in layers]
- elif isinstance(layers, str):
- if layers.lower() == 'all':
- self.layers = None
- else:
- self.layers = list(map(int, layers.split(',')))
- else:
- raise TypeError("`layers` only supports str or list[int]")
-
- neg_num_output_layer = -16384
- pos_num_output_layer = 0
- if self.layers is None:
- neg_num_output_layer = -1
- else:
- for layer in self.layers:
- if layer < 0:
- neg_num_output_layer = max(layer, neg_num_output_layer)
- else:
- pos_num_output_layer = max(layer, pos_num_output_layer)
-
- self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
- self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
- neg_num_output_layer=neg_num_output_layer,
- pos_num_output_layer=pos_num_output_layer,
- **kwargs)
- # 由于RobertaEmbedding中设置了padding_idx为1, 且使用了非常神奇的position计算方式,所以-2
- self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
- # 检查encoder_layer_number是否合理
- encoder_layer_number = len(self.encoder.encoder.layer)
- if self.layers is None:
- self.layers = [idx for idx in range(encoder_layer_number + 1)]
- logger.info(f'RoBERTa Model will return {len(self.layers)} layers (layer-0 '
- f'is embedding result): {self.layers}')
- assert len(self.layers) > 0, "There is no layer selected!"
- for layer in self.layers:
- if layer < 0:
- assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a roberta model with {encoder_layer_number} layers."
- else:
- assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a roberta model with {encoder_layer_number} layers."
-
- assert pool_method in ('avg', 'max', 'first', 'last')
- self.pool_method = pool_method
- self.include_cls_sep = include_cls_sep
- self.pooled_cls = pooled_cls
- self.auto_truncate = auto_truncate
-
- word_to_wordpieces = []
- word_pieces_lengths = []
- for word, index in vocab:
- if index == vocab.padding_idx: # pad是个特殊的符号
- word = ''
- elif index == vocab.unknown_idx:
- word = ''
- elif vocab.word_count[word] < min_freq:
- word = ''
- word_pieces = self.tokenizer.tokenize(word)
- word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces)
- word_to_wordpieces.append(word_pieces)
- word_pieces_lengths.append(len(word_pieces))
- self._cls_index = self.tokenizer.encoder['']
- self._sep_index = self.tokenizer.encoder['']
- self._word_pad_index = vocab.padding_idx
- self._wordpiece_pad_index = self.tokenizer.encoder[''] # 需要用于生成word_piece
- self.word_to_wordpieces = np.array(word_to_wordpieces, dtype=object)
- self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
- logger.debug("Successfully generate word pieces.")
-
- def forward(self, words):
- r"""
-
- :param words: torch.LongTensor, batch_size x max_len
- :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size
- """
- with torch.no_grad():
- batch_size, max_word_len = words.size()
- word_mask = words.ne(self._word_pad_index) # 为1的地方有word
- seq_len = word_mask.sum(dim=-1)
- batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
- 0) # batch_size x max_len
- word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size
- max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # 表示word piece的长度(包括padding)
- if max_word_piece_length + 2 > self._max_position_embeddings:
- if self.auto_truncate:
- word_pieces_lengths = word_pieces_lengths.masked_fill(
- word_pieces_lengths + 2 > self._max_position_embeddings,
- self._max_position_embeddings - 2)
- else:
- raise RuntimeError(
- "After split words into word pieces, the lengths of word pieces are longer than the "
- f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set "
- f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
-
- # +2是由于需要加入与
- word_pieces = words.new_full((batch_size, min(max_word_piece_length + 2, self._max_position_embeddings)),
- fill_value=self._wordpiece_pad_index)
- attn_masks = torch.zeros_like(word_pieces)
- # 1. 获取words的word_pieces的id,以及对应的span范围
- word_indexes = words.cpu().numpy()
- for i in range(batch_size):
- word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]]))
- if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings - 2:
- word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
- word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
- attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
- # 添加和
- word_pieces[:, 0].fill_(self._cls_index)
- batch_indexes = torch.arange(batch_size).to(words)
- word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
- token_type_ids = torch.zeros_like(word_pieces)
- # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算
- # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
- bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids,
- attention_mask=attn_masks,
- output_all_encoded_layers=True)
- # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
-
- if self.include_cls_sep:
- s_shift = 1
- outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
- bert_outputs[-1].size(-1))
-
- else:
- s_shift = 0
- outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
- bert_outputs[-1].size(-1))
- batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
- batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len
-
- if self.pool_method == 'first':
- batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
- _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
- elif self.pool_method == 'last':
- batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max() + 1] - 1
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
- _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
-
- for l_index, l in enumerate(self.layers):
- output_layer = bert_outputs[l]
- real_word_piece_length = output_layer.size(1) - 2
- if max_word_piece_length > real_word_piece_length: # 如果实际上是截取出来的
- paddings = output_layer.new_zeros(batch_size,
- max_word_piece_length - real_word_piece_length,
- output_layer.size(2))
- output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
- # 从word_piece collapse到word的表示
- truncate_output_layer = output_layer[:, 1:-1] # 删除与 batch_size x len x hidden_size
- if self.pool_method == 'first':
- tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
- tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
- outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp
-
- elif self.pool_method == 'last':
- tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
- tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
- outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp
- elif self.pool_method == 'max':
- for i in range(batch_size):
- for j in range(seq_len[i]):
- start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
- outputs[l_index, i, j + s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2)
- else:
- for i in range(batch_size):
- for j in range(seq_len[i]):
- start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
- outputs[l_index, i, j + s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2)
- if self.include_cls_sep:
- if l in (len(bert_outputs) - 1, -1) and self.pooled_cls:
- outputs[l_index, :, 0] = pooled_cls
- else:
- outputs[l_index, :, 0] = output_layer[:, 0]
- outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, word_pieces_lengths + s_shift]
-
- # 3. 最终的embedding结果
- return outputs
-
- def save(self, folder):
- """
- 给定一个folder保存pytorch_model.bin, config.json, vocab.txt
-
- :param str folder:
- :return:
- """
- self.tokenizer.save_pretrained(folder)
- self.encoder.save_pretrained(folder)
-
-
-class RobertaWordPieceEncoder(nn.Module):
- r"""
- 读取roberta模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。
-
- RobertaWordPieceEncoder可以支持自动下载权重,当前支持的模型:
- en: roberta-base
- en-large: roberta-large
-
- """
- def __init__(self, model_dir_or_name: str = 'en', layers: str = '-1', pooled_cls: bool = False,
- word_dropout=0, dropout=0, requires_grad: bool = True, **kwargs):
- r"""
-
- :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased``
- :param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层。layer=0为embedding层(包括wordpiece embedding,
- position embedding)
- :param bool pooled_cls: 返回的句子开头的是否使用预训练中的BertPool映射一下。如果下游任务取做预测,一般该值为True。
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param bool requires_grad: 是否需要gradient。
- """
- super().__init__()
-
- self.model = _WordPieceRobertaModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
- self._sep_index = self.model._sep_index
- self._cls_index = self.model._cls_index
- self._wordpiece_pad_index = self.model._wordpiece_pad_index
- self._wordpiece_unk_index = self.model._wordpiece_unknown_index
- self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
- self.requires_grad = requires_grad
- self.word_dropout = word_dropout
- self.dropout_layer = nn.Dropout(dropout)
-
- @property
- def embed_size(self):
- return self._embed_size
-
- @property
- def embedding_dim(self):
- return self._embed_size
-
- @property
- def num_embedding(self):
- return self.model.encoder.config.vocab_size
-
- def index_datasets(self, *datasets, field_name, add_cls_sep=True, add_prefix_space=True):
- r"""
- 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了
- bert的pad value。
-
- :param ~fastNLP.DataSet datasets: DataSet对象
- :param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。
- :param bool add_cls_sep: 如果首尾不是与会在首尾额外加入与。
- :param bool add_prefix_spance: 是否在句首添加额外的空格,RoBERTa预训练时该值为True
- :return:
- """
- self.model.index_datasets(*datasets, field_name=field_name, add_cls_sep=add_cls_sep, add_prefix_space=add_prefix_space)
-
- def forward(self, word_pieces, token_type_ids=None):
- r"""
- 计算words的bert embedding表示。传入的words中应该自行包含与>的tag。
-
- :param words: batch_size x max_len
- :param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话. 如果不传入,则自动生成(大部分情况,都不需要输入)。
- :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
- """
- word_pieces = self.drop_word(word_pieces)
- outputs = self.model(word_pieces)
- outputs = torch.cat([*outputs], dim=-1)
-
- return self.dropout_layer(outputs)
-
- def drop_word(self, words):
- r"""
- 按照设定随机将words设置为unknown_index。
-
- :param torch.LongTensor words: batch_size x max_len
- :return:
- """
- if self.word_dropout > 0 and self.training:
- with torch.no_grad():
- not_sep_mask = words.ne(self._sep_index)
- not_cls_mask = words.ne(self._cls_index)
- replaceable_mask = not_sep_mask.__and__(not_cls_mask)
- mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
- mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
- pad_mask = words.ne(self._wordpiece_pad_index)
- mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad的位置不为unk
- words = words.masked_fill(mask, self._wordpiece_unk_index)
- return words
-
- def save(self, folder):
- os.makedirs(folder, exist_ok=True)
-
- hyper = {}
- hyper['layers'] = ','.join(map(str, self.model.layers))
- hyper['dropout'] = self.dropout_layer.p
- hyper['word_dropout'] = self.word_dropout
- hyper['pooled_cls'] = self.model.pooled_cls
- hyper['requires_grad'] = bool(self.requires_grad)
-
- with open(os.path.join(folder, ROBERTA_ENCODER_HYPER), 'w', encoding='utf-8') as f:
- json.dump(hyper, f, indent=2)
-
- os.makedirs(os.path.join(folder, ROBERTA_ENCODER_FOLDER), exist_ok=True)
- self.model.save(os.path.join(folder, ROBERTA_ENCODER_FOLDER))
- logger.debug(f"RobertaWordPieceEncoder has been saved in {folder}")
-
- @classmethod
- def load(cls, folder):
- for name in [ROBERTA_ENCODER_HYPER, ROBERTA_ENCODER_FOLDER]:
- assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."
-
- with open(os.path.join(folder, ROBERTA_ENCODER_HYPER), 'r', encoding='utf-8') as f:
- hyper = json.load(f)
-
- model_dir_or_name = os.path.join(os.path.join(folder, ROBERTA_ENCODER_FOLDER))
-
- bert_encoder = cls(model_dir_or_name=model_dir_or_name, **hyper)
- return bert_encoder
-
-
-class _WordPieceRobertaModel(nn.Module):
- def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False):
- super().__init__()
-
- self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
- self.encoder = RobertaModel.from_pretrained(model_dir_or_name)
- # 检查encoder_layer_number是否合理
- encoder_layer_number = len(self.encoder.encoder.layer)
-
- if isinstance(layers, list):
- self.layers = [int(l) for l in layers]
- elif isinstance(layers, str):
- self.layers = list(map(int, layers.split(',')))
- else:
- raise TypeError("`layers` only supports str or list[int]")
-
- for layer in self.layers:
- if layer < 0:
- assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a RoBERTa model with {encoder_layer_number} layers."
- else:
- assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a RoBERTa model with {encoder_layer_number} layers."
-
- self._cls_index = self.tokenizer.encoder['']
- self._sep_index = self.tokenizer.encoder['']
- self._wordpiece_pad_index = self.tokenizer.encoder[''] # 需要用于生成word_piece
- self._wordpiece_unknown_index = self.tokenizer.encoder['']
- self.pooled_cls = pooled_cls
-
- def index_datasets(self, *datasets, field_name, add_cls_sep=True, add_prefix_space=True):
- r"""
- 使用roberta的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是
- 与会在首尾额外加入与, 且将word_pieces这一列的pad value设置为了bert的pad value。
-
- :param datasets: DataSet对象
- :param field_name: 基于哪一列index, 这一列一般是raw_string
- :param bool add_cls_sep: 是否在句首句尾添加cls和sep的index
- :param bool add_prefix_space: 是否在句子开头添加空格,预训练时RoBERTa该值为True
- :return:
- """
-
- encode_func = partial(self.tokenizer.encode, add_special_tokens=add_cls_sep, add_prefix_space=add_prefix_space)
-
- for index, dataset in enumerate(datasets):
- try:
- dataset.apply_field(encode_func, field_name=field_name, new_field_name='word_pieces',
- is_input=True)
- dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
- except Exception as e:
- logger.error(f"Exception happens when processing the {index} dataset.")
- raise e
-
- def forward(self, word_pieces):
- r"""
-
- :param word_pieces: torch.LongTensor, batch_size x max_len
- :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size
- """
- batch_size, max_len = word_pieces.size()
-
- attn_masks = word_pieces.ne(self._wordpiece_pad_index)
- roberta_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=torch.zeros_like(word_pieces),
- attention_mask=attn_masks,
- output_all_encoded_layers=True)
- # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
- outputs = roberta_outputs[0].new_zeros((len(self.layers), batch_size, max_len, roberta_outputs[0].size(-1)))
- for l_index, l in enumerate(self.layers):
- roberta_output = roberta_outputs[l]
- if l in (len(roberta_output)-1, -1) and self.pooled_cls:
- roberta_output[:, 0] = pooled_cls
- outputs[l_index] = roberta_output
- return outputs
-
- def save(self, folder):
- self.tokenizer.save_pretrained(folder)
- self.encoder.save_pretrained(folder)
\ No newline at end of file
diff --git a/fastNLP/embeddings/stack_embedding.py b/fastNLP/embeddings/stack_embedding.py
deleted file mode 100644
index 7ef4736b..00000000
--- a/fastNLP/embeddings/stack_embedding.py
+++ /dev/null
@@ -1,99 +0,0 @@
-r"""
-.. todo::
- doc
-"""
-
-__all__ = [
- "StackEmbedding",
-]
-
-from typing import List
-
-import torch
-from torch import nn as nn
-
-from .embedding import TokenEmbedding
-from .utils import _check_vocab_has_same_index
-
-
-class StackEmbedding(TokenEmbedding):
- r"""
- 支持将多个embedding集合成一个embedding。
-
- Example::
-
- >>> from fastNLP import Vocabulary
- >>> from fastNLP.embeddings import StaticEmbedding, StackEmbedding
- >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
- >>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True)
- >>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
- >>> embed = StackEmbedding([embed_1, embed_2])
-
- """
-
- def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
- r"""
-
- :param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedidng会在相同的位置
- 被设置为unknown。如果这里设置了dropout,则组成的embedding就不要再设置dropout了。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- """
- vocabs = []
- for embed in embeds:
- if hasattr(embed, 'get_word_vocab'):
- vocabs.append(embed.get_word_vocab())
- _vocab = vocabs[0]
- for vocab in vocabs[1:]:
- if _vocab!=vocab:
- _check_vocab_has_same_index(_vocab, vocab)
-
- super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
- assert isinstance(embeds, list)
- for embed in embeds:
- assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
- self.embeds = nn.ModuleList(embeds)
- self._embed_size = sum([embed.embed_size for embed in self.embeds])
-
- def append(self, embed: TokenEmbedding):
- r"""
- 添加一个embedding到结尾。
- :param embed:
- :return:
- """
- assert isinstance(embed, TokenEmbedding)
- _check_vocab_has_same_index(self.get_word_vocab(), embed.get_word_vocab())
- self._embed_size += embed.embed_size
- self.embeds.append(embed)
- return self
-
- def pop(self):
- r"""
- 弹出最后一个embed
- :return:
- """
- embed = self.embeds.pop()
- self._embed_size -= embed.embed_size
- return embed
-
- @property
- def embed_size(self):
- r"""
- 该Embedding输出的vector的最后一维的维度。
- :return:
- """
- return self._embed_size
-
- def forward(self, words):
- r"""
- 得到多个embedding的结果,并把结果按照顺序concat起来。
-
- :param words: batch_size x max_len
- :return: 返回的shape和当前这个stack embedding中embedding的组成有关
- """
- outputs = []
- words = self.drop_word(words)
- for embed in self.embeds:
- outputs.append(embed(words))
- outputs = self.dropout(torch.cat(outputs, dim=-1))
- return outputs
diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py
deleted file mode 100644
index 09c44d6c..00000000
--- a/fastNLP/embeddings/static_embedding.py
+++ /dev/null
@@ -1,405 +0,0 @@
-r"""
-.. todo::
- doc
-"""
-
-__all__ = [
- "StaticEmbedding"
-]
-import os
-import warnings
-from collections import defaultdict
-from copy import deepcopy
-import json
-from typing import Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-from .embedding import TokenEmbedding
-from ..core import logger
-from ..core.vocabulary import Vocabulary
-from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path
-from ..io.file_utils import _get_file_name_base_on_postfix
-
-
-VOCAB_FILENAME = 'vocab.txt'
-STATIC_HYPER_FILENAME = 'static_hyper.json'
-STATIC_EMBED_FILENAME = 'static.txt'
-
-
-class StaticEmbedding(TokenEmbedding):
- r"""
- StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来,
- 如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。
- 当前支持自动下载的预训练vector有:
-
- .. code::
-
- en: 实际为en-glove-840b-300d(常用)
- en-glove-6b-50d: glove官方的50d向量
- en-glove-6b-100d: glove官方的100d向量
- en-glove-6b-200d: glove官方的200d向量
- en-glove-6b-300d: glove官方的300d向量
- en-glove-42b-300d: glove官方使用42B数据训练版本
- en-glove-840b-300d:
- en-glove-twitter-27b-25d:
- en-glove-twitter-27b-50d:
- en-glove-twitter-27b-100d:
- en-glove-twitter-27b-200d:
- en-word2vec-300d: word2vec官方发布的300d向量
- en-fasttext-crawl: fasttext官方发布的300d英文预训练
- cn-char-fastnlp-100d: fastNLP训练的100d的character embedding
- cn-bi-fastnlp-100d: fastNLP训练的100d的bigram embedding
- cn-tri-fastnlp-100d: fastNLP训练的100d的trigram embedding
- cn-fasttext: fasttext官方发布的300d中文预训练embedding
-
- Example::
-
- >>> from fastNLP import Vocabulary
- >>> from fastNLP.embeddings import StaticEmbedding
- >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
- >>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-50d')
-
- >>> vocab = Vocabulary().add_word_lst(["The", 'the', "THE"])
- >>> embed = StaticEmbedding(vocab, model_dir_or_name="en-glove-50d", lower=True)
- >>> # "the", "The", "THE"它们共用一个vector,且将使用"the"在预训练词表中寻找它们的初始化表示。
-
- >>> vocab = Vocabulary().add_word_lst(["The", "the", "THE"])
- >>> embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
- >>> words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]])
- >>> embed(words)
- >>> tensor([[[ 0.5773, 0.7251, -0.3104, 0.0777, 0.4849],
- [ 0.5773, 0.7251, -0.3104, 0.0777, 0.4849],
- [ 0.5773, 0.7251, -0.3104, 0.0777, 0.4849]]],
-        grad_fn=<EmbeddingBackward>)  # 每种word的输出是一致的。
-
- """
-
- def __init__(self, vocab: Vocabulary, model_dir_or_name: Union[str, None] = 'en', embedding_dim=-1, requires_grad: bool = True,
- init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
- r"""
-
- :param Vocabulary vocab: 词表. StaticEmbedding只会加载包含在词表中的词的词向量,在预训练向量中没找到的使用随机初始化
- :param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个
- 以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。
- 如果输入为None则使用embedding_dim的维度随机初始化一个embedding。
- :param int embedding_dim: 随机初始化的embedding的维度,当该值为大于0的值时,将忽略model_dir_or_name。
- :param bool requires_grad: 是否需要gradient. 默认为True
- :param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法, 传入的方法应该接受一个tensor,并
- inplace地修改其值。
- :param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独
- 为大写的词语开辟一个vector表示,则将lower设置为False。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。
- :param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。
- :param dict kwargs:
- bool only_train_min_freq: 仅对train中的词语使用min_freq筛选;
- bool only_norm_found_vector: 默认为False, 是否仅对在预训练中找到的词语使用normalize;
- bool only_use_pretrain_word: 默认为False, 仅使用出现在pretrain词表中的词,如果该词没有在预训练的词表中出现则为unk。如果embedding不需要更新建议设置为True。
- """
- super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
- if embedding_dim > 0:
- if model_dir_or_name:
- logger.info(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with"
- f" dimension {embedding_dim}. If you want to use pre-trained embedding, "
- f"set `embedding_dim` to 0.")
- model_dir_or_name = None
-
- # 得到cache_path
- if model_dir_or_name is None:
- assert embedding_dim >= 1, "The dimension of embedding should be larger than 1."
- embedding_dim = int(embedding_dim)
- model_path = None
- elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
- model_url = _get_embedding_url('static', model_dir_or_name.lower())
- model_path = cached_path(model_url, name='embedding')
- # 检查是否存在
- elif os.path.isfile(os.path.abspath(os.path.expanduser(model_dir_or_name))):
- model_path = os.path.abspath(os.path.expanduser(model_dir_or_name))
- elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
- model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt')
- else:
- raise ValueError(f"Cannot recognize {model_dir_or_name}.")
-
- kwargs['min_freq'] = min_freq
- kwargs['lower'] = lower
- # 根据min_freq缩小vocab
- truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq)
- if truncate_vocab:
- truncated_vocab = deepcopy(vocab)
- truncated_vocab.min_freq = min_freq
- truncated_vocab.word2idx = None
- if lower: # 如果有lower,将大小写的的freq需要同时考虑到
- lowered_word_count = defaultdict(int)
- for word, count in truncated_vocab.word_count.items():
- lowered_word_count[word.lower()] += count
- for word in truncated_vocab.word_count.keys():
- word_count = truncated_vocab.word_count[word]
- if lowered_word_count[word.lower()] >= min_freq and word_count < min_freq:
- truncated_vocab.add_word_lst([word] * (min_freq - word_count),
- no_create_entry=truncated_vocab._is_word_no_create_entry(word))
-
- # 只限制在train里面的词语使用min_freq筛选
- if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None:
- for word in truncated_vocab.word_count.keys():
- if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word] < min_freq:
- truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]),
- no_create_entry=True)
- truncated_vocab.build_vocab()
- truncated_words_to_words = torch.arange(len(vocab)).long()
- for word, index in vocab:
- truncated_words_to_words[index] = truncated_vocab.to_index(word)
- logger.info(f"{len(vocab) - len(truncated_vocab)} words have frequency less than {min_freq}.")
- vocab = truncated_vocab
-
- self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False)
- self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
- # 读取embedding
- if lower:
- lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
- for word, index in vocab:
- if vocab._is_word_no_create_entry(word):
- lowered_vocab.add_word(word.lower(), no_create_entry=True)
- else:
- lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的
- logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
- f"unique lowered words.")
- if model_path:
- embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
- else:
- embedding = self._randomly_init_embed(len(lowered_vocab), embedding_dim, init_method)
- self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
- if lowered_vocab.unknown:
- unknown_idx = lowered_vocab.unknown_idx
- else:
- unknown_idx = embedding.size(0) - 1 # 否则是最后一个为unknow
- self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
- words_to_words = torch.full((len(vocab),), fill_value=unknown_idx, dtype=torch.long).long()
- for word, index in vocab:
- if word not in lowered_vocab:
- word = word.lower()
- if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
- continue # 如果不需要创建entry,已经默认unknown了
- words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
- self.register_buffer('words_to_words', words_to_words)
- self._word_unk_index = lowered_vocab.unknown_idx # 替换一下unknown的index
- else:
- if model_path:
- embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
- else:
- embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
- self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
- if not self.only_norm_found_vector and normalize:
- embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)
-
- if truncate_vocab:
- for i in range(len(truncated_words_to_words)):
- index_in_truncated_vocab = truncated_words_to_words[i]
- truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab]
- del self.words_to_words
- self.register_buffer('words_to_words', truncated_words_to_words)
- self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
- padding_idx=vocab.padding_idx,
- max_norm=None, norm_type=2, scale_grad_by_freq=False,
- sparse=False, _weight=embedding)
- self._embed_size = self.embedding.weight.size(1)
- self.requires_grad = requires_grad
- self.kwargs = kwargs
-
- @property
- def weight(self):
- return self.embedding.weight
-
- def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None):
- r"""
-
- :param int num_embedding: embedding的entry的数量
- :param int embedding_dim: embedding的维度大小
- :param callable init_embed: 初始化方法
- :return: torch.FloatTensor
- """
- embed = torch.zeros(num_embedding, embedding_dim)
-
- if init_embed is None:
- nn.init.uniform_(embed, -np.sqrt(3 / embedding_dim), np.sqrt(3 / embedding_dim))
- else:
- init_embed(embed)
-
- return embed
-
- def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='', unknown='',
- error='ignore', init_method=None):
- r"""
- 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是
- word2vec(第一行只有两个元素)还是glove格式的数据。
-
- :param str embed_filepath: 预训练的embedding的路径。
- :param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。
- 没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。
- :param dtype: 读出的embedding的类型
- :param str padding: 词表中padding的token
- :param str unknown: 词表中unknown的token
- :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。
- 这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。
- :param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。默认使用torch.nn.init.zeros_
- :return torch.tensor: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。
- """
- assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
- if not os.path.exists(embed_filepath):
- raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
- with open(embed_filepath, 'r', encoding='utf-8') as f:
- line = f.readline().strip()
- parts = line.split()
- start_idx = 0
- if len(parts) == 2:
- dim = int(parts[1])
- start_idx += 1
- else:
- dim = len(parts) - 1
- f.seek(0)
- matrix = {} # index是word在vocab中的index,value是vector或None(如果在pretrain中没有找到该word)
- if vocab.padding:
- matrix[vocab.padding_idx] = torch.zeros(dim)
- if vocab.unknown:
- matrix[vocab.unknown_idx] = torch.zeros(dim)
- found_count = 0
- found_unknown = False
- for idx, line in enumerate(f, start_idx):
- try:
- parts = line.strip().split()
- word = ''.join(parts[:-dim])
- nums = parts[-dim:]
- # 对齐unk与pad
- if word == padding and vocab.padding is not None:
- word = vocab.padding
- elif word == unknown and vocab.unknown is not None:
- word = vocab.unknown
- found_unknown = True
- if word in vocab:
- index = vocab.to_index(word)
- if index in matrix:
- warnings.warn(f"Word has more than one vector in embedding file. Set logger level to "
- f"DEBUG for detail.")
- logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)")
- matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
- if self.only_norm_found_vector:
- matrix[index] = matrix[index] / np.linalg.norm(matrix[index])
- found_count += 1
- except Exception as e:
- if error == 'ignore':
- warnings.warn("Error occurred at the {} line.".format(idx))
- else:
- logger.error("Error occurred at the {} line.".format(idx))
- raise e
- logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
- if not self.only_use_pretrain_word: # 如果只用pretrain中的值就不要为未找到的词创建entry了
- for word, index in vocab:
- if index not in matrix and not vocab._is_word_no_create_entry(word):
- if found_unknown: # 如果有unkonwn,用unknown初始化
- matrix[index] = matrix[vocab.unknown_idx]
- else:
- matrix[index] = None
- # matrix中代表是需要建立entry的词
- vectors = self._randomly_init_embed(len(matrix), dim, init_method)
-
- if vocab.unknown is None: # 创建一个专门的unknown
- unknown_idx = len(matrix)
- vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
- else:
- unknown_idx = vocab.unknown_idx
- self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx, dtype=torch.long).long())
- index = 0
- for word, index_in_vocab in vocab:
- if index_in_vocab in matrix:
- vec = matrix.get(index_in_vocab)
- if vec is not None: # 使用找到的vector, 如果为None说明需要训练
- vectors[index] = vec
- self.words_to_words[index_in_vocab] = index
- index += 1
-
- return vectors
-
- def forward(self, words):
- r"""
- 传入words的index
-
- :param words: torch.LongTensor, [batch_size, max_len]
- :return: torch.FloatTensor, [batch_size, max_len, embed_size]
- """
- if hasattr(self, 'words_to_words'):
- words = self.words_to_words[words]
- words = self.drop_word(words)
- words = self.embedding(words)
- words = self.dropout(words)
- return words
-
- def save(self, folder):
- """
- 将embedding存储到folder下,之后可以通过使用load方法读取
-
- :param str folder: 会在该folder下生成三个文件, vocab.txt, static_embed_hyper.txt, static_embed_hyper.json.
- 其中vocab.txt可以用Vocabulary通过load读取; embedding.txt按照word2vec的方式存储,以空格的方式隔开元素,
- 第一行只有两个元素,剩下的行首先是word然后是各个维度的值; static_embed_hyper.json是StaticEmbedding的超参数
- :return:
- """
- os.makedirs(folder, exist_ok=True)
-
- vocab = self.get_word_vocab()
- vocab_fp = os.path.join(folder, VOCAB_FILENAME)
- vocab.save(vocab_fp)
- kwargs = self.kwargs.copy()
- kwargs['dropout'] = self.dropout_layer.p
- kwargs['word_dropout'] = self.word_dropout
- kwargs['requires_grad'] = self.requires_grad
- kwargs['only_norm_found_vector'] = False
- kwargs['only_use_pretrain_word'] = True
-
- with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'w', encoding='utf-8') as f:
- json.dump(kwargs, f, indent=2)
-
- with open(os.path.join(folder, STATIC_EMBED_FILENAME), 'w', encoding='utf-8') as f:
- f.write('{}\n'.format(' '*30)) # 留白之后再来填写
- word_count = 0
- saved_word = {}
- valid_word_count = 0
- for i in range(len(self.words_to_words)):
- word = vocab.to_word(i)
- if not vocab._is_word_no_create_entry(word):
- word_count += 1
- if kwargs['lower']:
- word = word.lower()
- if word in saved_word:
- continue
- saved_word[word] = 1
- vec_i = self.words_to_words[i]
- if vec_i==vocab.unknown_idx and i!=vocab.unknown_idx:
- continue
- vec = self.embedding.weight.data[vec_i].tolist()
- vec_str = ' '.join(map(str, vec))
- f.write(f'{word} {vec_str}\n')
- valid_word_count += 1
- f.seek(0)
- f.write('{} {}'.format(valid_word_count, self.embedding_dim))
- logger.debug(f"StaticEmbedding has been saved to {folder}.")
-
- @classmethod
- def load(cls, folder):
- """
-
- :param str folder: 该folder下应该有以下三个文件vocab.txt, static_embed.txt, static_hyper.json
- :return:
- """
- for name in [VOCAB_FILENAME, STATIC_EMBED_FILENAME, STATIC_HYPER_FILENAME]:
- assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."
-
- vocab = Vocabulary.load(os.path.join(folder, VOCAB_FILENAME))
- with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'r', encoding='utf-8') as f:
- hyper = json.load(f)
-
- logger.info(f"Load StaticEmbedding from {folder}.")
- embed = cls(vocab=vocab, model_dir_or_name=os.path.join(folder, STATIC_EMBED_FILENAME), **hyper)
- return embed
-
diff --git a/fastNLP/embeddings/torch/__init__.py b/fastNLP/embeddings/torch/__init__.py
new file mode 100644
index 00000000..d105e329
--- /dev/null
+++ b/fastNLP/embeddings/torch/__init__.py
@@ -0,0 +1,15 @@
+"""
+torch 可使用的几种 Embedding 。
+"""
+__all__ = [
+ "CNNCharEmbedding",
+ "LSTMCharEmbedding",
+ "Embedding",
+ "StackEmbedding",
+ "StaticEmbedding"
+]
+
+from .char_embedding import *
+from .embedding import *
+from .stack_embedding import *
+from .static_embedding import StaticEmbedding
\ No newline at end of file
diff --git a/fastNLP/embeddings/torch/char_embedding.py b/fastNLP/embeddings/torch/char_embedding.py
new file mode 100644
index 00000000..110b4189
--- /dev/null
+++ b/fastNLP/embeddings/torch/char_embedding.py
@@ -0,0 +1,291 @@
+r"""
+该文件中主要包含的是 character 的 Embedding ,包括基于 CNN 与 LSTM 的 character Embedding。与其它 Embedding 一样,这里的 Embedding 输入也是
+词的 index 而不需要使用词语中的 char 的 index 来获取表达。
+"""
+
+__all__ = [
+ "CNNCharEmbedding",
+ "LSTMCharEmbedding"
+]
+
+from typing import List
+
+from ...envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import LSTM
+ import torch.nn.utils.rnn as rnn
+
+
+from .embedding import TokenEmbedding
+from .static_embedding import StaticEmbedding
+from .utils import _construct_char_vocab_from_vocab
+from .utils import get_embeddings
+from ...core import logger
+from ...core.vocabulary import Vocabulary
+
+
+class CNNCharEmbedding(TokenEmbedding):
+ r"""
+ 使用 ``CNN`` 生成 ``character embedding``。``CNN`` 的结构为:char_embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout.
+ 不同的 ``kernel`` 大小的 ``fitler`` 结果是拼起来然后通过一层 **全连接层** 然后输出 ``word`` 的表示。
+
+ Example::
+
+ >>> import torch
+ >>> from fastNLP import Vocabulary
+ >>> from fastNLP.embeddings.torch import CNNCharEmbedding
+ >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
+ >>> embed = CNNCharEmbedding(vocab, embed_size=50)
+ >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
+ >>> outputs = embed(words)
+ >>> outputs.size()
+ torch.Size([1, 5,50])
+
+ :param vocab: 词表
+ :param embed_size: 该 :class:`CNNCharEmbedding` 的输出维度大小。
+ :param char_emb_size: character 的 embed 的维度。character 是从 ``vocab`` 中生成的。
+    :param word_dropout: 按照一定概率随机将 word 设置为 ``unk_index`` ,这样可以使得 ``<UNK>`` 这个 token 得到足够的训练,
+ 且会对网络有一定的 regularize 作用。
+ :param dropout: 以多大的概率 drop 分布式表示与 char embedding 的输出。
+ :param filter_nums: filter 的数量。长度需要和 ``kernel_sizes`` 一致。
+ :param kernel_sizes: kernel 的大小。
+ :param pool_method: character 的表示在合成一个表示时所使用的 pool 池化方法,支持 ``['avg', 'max']`` 。
+ :param activation: CNN 之后使用的激活方法,支持 ``['relu', 'sigmoid', 'tanh']`` 或者自定义函数。
+ :param min_char_freq: character 的最少出现次数。
+ :param pre_train_char_embed: 可以有两种方式调用预训练好的 :class:`CNNCharEmbedding` :
+
+ 1. 传入 embedding 文件夹(文件夹下应该只有一个以 **.txt** 作为后缀的文件)或文件路径;
+ 2. 传入 embedding 的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载;
+ 3. 如果输入为 ``None`` 则使用 ``embedding_dim`` 的维度随机初始化一个 embedding;
+ :param requires_grad: 是否更新权重
+ :param include_word_start_end: 是否在每个 word 开始的 character 前和结束的 character 增加特殊标示符号
+ """
+
+ def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0,
+ dropout: float = 0, filter_nums: List[int] = (40, 30, 20), kernel_sizes: List[int] = (5, 3, 1),
+ pool_method: str = 'max', activation='relu', min_char_freq: int = 2, pre_train_char_embed: str = None,
+ requires_grad:bool=True, include_word_start_end:bool=True):
+
+ super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
+
+ for kernel in kernel_sizes:
+ assert kernel % 2 == 1, "Only odd kernel is allowed."
+
+ assert pool_method in ('max', 'avg')
+ self.pool_method = pool_method
+ # activation function
+ if isinstance(activation, str):
+ if activation.lower() == 'relu':
+ self.activation = F.relu
+ elif activation.lower() == 'sigmoid':
+ self.activation = F.sigmoid
+ elif activation.lower() == 'tanh':
+ self.activation = F.tanh
+ elif activation is None:
+ self.activation = lambda x: x
+ elif callable(activation):
+ self.activation = activation
+ else:
+ raise Exception(
+ "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
+
+ logger.info("Start constructing character vocabulary.")
+ # 建立char的词表
+ self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq,
+ include_word_start_end=include_word_start_end)
+ self.char_pad_index = self.char_vocab.padding_idx
+ logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
+ # 对vocab进行index
+ max_word_len = max(map(lambda x: len(x[0]), vocab))
+ if include_word_start_end:
+ max_word_len += 2
+ self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len),
+ fill_value=self.char_pad_index, dtype=torch.long))
+ self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
+ for word, index in vocab:
+ # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的也是同一个embed
+ if include_word_start_end:
+                word = ['<bow>'] + list(word) + ['<eow>']
+ self.words_to_chars_embedding[index, :len(word)] = \
+ torch.LongTensor([self.char_vocab.to_index(c) for c in word])
+ self.word_lengths[index] = len(word)
+ # self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+ if pre_train_char_embed:
+ self.char_embedding = StaticEmbedding(self.char_vocab, model_dir_or_name=pre_train_char_embed)
+ else:
+ self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
+
+ self.convs = nn.ModuleList([nn.Conv1d(
+ self.char_embedding.embedding_dim, filter_nums[i], kernel_size=kernel_sizes[i], bias=True,
+ padding=kernel_sizes[i] // 2)
+ for i in range(len(kernel_sizes))])
+ self._embed_size = embed_size
+ self.fc = nn.Linear(sum(filter_nums), embed_size)
+ self.requires_grad = requires_grad
+
+ def forward(self, words):
+ r"""
+ 输入 ``words`` 的 index 后,生成对应的 ``words`` 的表示。
+
+ :param words: 形状为 ``[batch_size, max_len]``
+ :return: 形状为 ``[batch_size, max_len, embed_size]`` 的结果
+ """
+ words = self.drop_word(words)
+ batch_size, max_len = words.size()
+ chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
+ word_lengths = self.word_lengths[words] # batch_size x max_len
+ max_word_len = word_lengths.max()
+ chars = chars[:, :, :max_word_len]
+ # 为1的地方为mask
+ chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了
+ chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
+ chars = self.dropout(chars)
+ reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
+ reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M
+ conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)
+ for conv in self.convs]
+ conv_chars = torch.cat(conv_chars, dim=-1).contiguous() # B x max_len x max_word_len x sum(filters)
+ conv_chars = self.activation(conv_chars)
+ if self.pool_method == 'max':
+ conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
+ chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters)
+ else:
+ conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
+ chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()
+ chars = self.fc(chars)
+ return self.dropout(chars)
+
+
+class LSTMCharEmbedding(TokenEmbedding):
+ r"""
+ 使用 ``LSTM`` 的方式对 ``character`` 进行 ``encode``。结构为:embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool -> Dropout 。
+
+ Example::
+
+ >>> import torch
+ >>> from fastNLP import Vocabulary
+ >>> from fastNLP.embeddings.torch import LSTMCharEmbedding
+ >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
+ >>> embed = LSTMCharEmbedding(vocab, embed_size=50)
+ >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
+ >>> outputs = embed(words)
+ >>> outputs.size()
+ >>> # torch.Size([1, 5,50])
+
+ :param vocab: 词表
+ :param embed_size: :class:`LSTMCharEmbedding` 的输出维度。
+ :param char_emb_size: character 的 embedding 的维度。
+    :param word_dropout: 按照一定概率随机将 word 设置为 ``unk_index`` ,这样可以使得 ``<UNK>`` 这个 token 得到足够的训练,
+ 且会对网络有一定的 regularize 作用。
+ :param dropout: 以多大的概率 drop 分布式表示与 char embedding 的输出。
+ :param hidden_size: ``LSTM`` 的中间 hidden 的大小,如果为 ``bidirectional`` 的,hidden 会除二。
+ :param pool_method: character 的表示在合成一个表示时所使用的 pool 池化方法,支持 ``['avg', 'max']`` 。
+ :param activation: LSTM 之后使用的激活方法,支持 ``['relu', 'sigmoid', 'tanh']`` 或者自定义函数。
+ :param min_char_freq: character 的最少出现次数。
+ :param bidirectional: 是否使用双向的 LSTM 进行 encode。
+ :param pre_train_char_embed: 可以有两种方式调用预训练好的 :class:`LSTMCharEmbedding` :
+
+ 1. 传入 embedding 文件夹(文件夹下应该只有一个以 **.txt** 作为后缀的文件)或文件路径;
+ 2. 传入 embedding 的名称,第二种情况将自动查看缓存中是否存在该模型,
+ 没有的话将自动下载;
+ 3. 如果输入为 ``None`` 则使用 ``embedding_dim`` 的维度随机初始化一个 embedding;
+ :param requires_grad: 是否更新权重
+ :param include_word_start_end: 是否在每个 word 开始的 character 前和结束的 character 增加特殊标示符号
+ """
+
+ def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0,
+ dropout: float = 0, hidden_size=50, pool_method: str = 'max', activation='relu',
+ min_char_freq: int = 2, bidirectional=True, pre_train_char_embed: str = None,
+ requires_grad:bool=True, include_word_start_end:bool=True):
+
+ super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
+
+ assert hidden_size % 2 == 0, "Only even kernel is allowed."
+
+ assert pool_method in ('max', 'avg')
+ self.pool_method = pool_method
+ # activation function
+ if isinstance(activation, str):
+ if activation.lower() == 'relu':
+ self.activation = F.relu
+ elif activation.lower() == 'sigmoid':
+ self.activation = F.sigmoid
+ elif activation.lower() == 'tanh':
+ self.activation = F.tanh
+ elif activation is None:
+ self.activation = lambda x: x
+ elif callable(activation):
+ self.activation = activation
+ else:
+ raise Exception(
+ "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
+
+ logger.info("Start constructing character vocabulary.")
+ # 建立char的词表
+ self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq,
+ include_word_start_end=include_word_start_end)
+ self.char_pad_index = self.char_vocab.padding_idx
+ logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
+ # 对vocab进行index
+ max_word_len = max(map(lambda x: len(x[0]), vocab))
+ if include_word_start_end:
+ max_word_len += 2
+ self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len),
+ fill_value=self.char_pad_index, dtype=torch.long))
+ self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
+ for word, index in vocab:
+ # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 修改为不区分pad与否
+ if include_word_start_end:
+                word = ['<bow>'] + list(word) + ['<eow>']
+ self.words_to_chars_embedding[index, :len(word)] = \
+ torch.LongTensor([self.char_vocab.to_index(c) for c in word])
+ self.word_lengths[index] = len(word)
+ if pre_train_char_embed:
+ self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
+ else:
+ self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
+
+ self.fc = nn.Linear(hidden_size, embed_size)
+ hidden_size = hidden_size // 2 if bidirectional else hidden_size
+
+ self.lstm = LSTM(self.char_embedding.embedding_dim, hidden_size, bidirectional=bidirectional, batch_first=True)
+ self._embed_size = embed_size
+ self.bidirectional = bidirectional
+ self.requires_grad = requires_grad
+
+ def forward(self, words):
+ r"""
+ 输入 ``words`` 的 index 后,生成对应的 ``words`` 的表示。
+
+ :param words: 形状为 ``[batch_size, max_len]``
+ :return: 形状为 ``[batch_size, max_len, embed_size]`` 的结果
+ """
+ words = self.drop_word(words)
+ batch_size, max_len = words.size()
+ chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
+ word_lengths = self.word_lengths[words] # batch_size x max_len
+ max_word_len = word_lengths.max()
+ chars = chars[:, :, :max_word_len]
+ # 为mask的地方为1
+ chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了
+ chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
+ chars = self.dropout(chars)
+ reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
+ lstm_chars = self.lstm(reshaped_chars, None)[0].reshape(batch_size, max_len, max_word_len, -1)
+ # B x M x M x H
+
+ lstm_chars = self.activation(lstm_chars)
+ if self.pool_method == 'max':
+ lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
+ chars, _ = torch.max(lstm_chars, dim=-2) # batch_size x max_len x H
+ else:
+ lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
+ chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()
+
+ chars = self.fc(chars)
+
+ return self.dropout(chars)
diff --git a/fastNLP/embeddings/torch/embedding.py b/fastNLP/embeddings/torch/embedding.py
new file mode 100644
index 00000000..efcf7894
--- /dev/null
+++ b/fastNLP/embeddings/torch/embedding.py
@@ -0,0 +1,225 @@
+r"""
+该模块中的 :class:`Embedding` 主要用于随机初始化的 embedding (更推荐使用 :class:`fastNLP.embeddings.torch.StaticEmbedding` ),或按照预训练权重初始化 Embedding。
+
+"""
+
+__all__ = [
+ "Embedding",
+]
+
+from abc import abstractmethod
+from typing import Union, Tuple
+from ...envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+ import torch
+ from torch.nn import Module
+ from torch import nn
+else:
+ from ...core.utils.dummy_class import DummyClass as Module
+
+import numpy as np
+
+from .utils import get_embeddings
+
+
class Embedding(Module):
    r"""
    Generic word-embedding wrapper that supports several initialisation schemes.
    ``self.num_embeddings`` gives the vocabulary size and ``self.embedding_dim``
    the vector dimension.

    Example::

        >>> import numpy as np
        >>> from fastNLP.embeddings.torch import Embedding
        >>> init_embed = (2000, 100)
        >>> embed = Embedding(init_embed)  # randomly initialise 2000 words with 100 dims each
        >>> init_embed = np.zeros((2000, 100))
        >>> embed = Embedding(init_embed)  # initialise from a numpy.ndarray

    :param init_embed: what to build the embedding from:

        1. ``tuple(int, int)`` -- ``(vocab_size, embed_dim)``, randomly initialised;
        2. a :class:`Tensor`, :class:`Embedding` or :class:`numpy.ndarray` -- used directly;

    :param word_dropout: probability of replacing a word with ``unk_index`` during
        training; regularises the model and ensures the ``<unk>`` vector gets trained.
        Requires ``unk_index`` when the wrapped embed is not a ``TokenEmbedding``.
    :param dropout: dropout probability applied to the embedding output.
    :param unk_index: index substituted when a word is dropped. fastNLP's
        :class:`fastNLP.Vocabulary` uses 1 as the default ``unk_index``.
    """

    def __init__(self, init_embed:Union[Tuple[int,int],'torch.FloatTensor','nn.Embedding',np.ndarray],
                 word_dropout:float=0, dropout:float=0.0, unk_index:int=None):

        super(Embedding, self).__init__()

        self.embed = get_embeddings(init_embed)

        self.dropout = nn.Dropout(dropout)
        if not isinstance(self.embed, TokenEmbedding):
            # plain nn.Embedding / tensor-backed embed: discover its dimension by duck typing
            if hasattr(self.embed, 'embed_size'):
                self._embed_size = self.embed.embed_size
            elif hasattr(self.embed, 'embedding_dim'):
                self._embed_size = self.embed.embedding_dim
            else:
                self._embed_size = self.embed.weight.size(1)
            if word_dropout > 0 and not isinstance(unk_index, int):
                raise ValueError("When drop word is set, you need to pass in the unk_index.")
        else:
            # a TokenEmbedding knows its own vocab, so unk_index is taken from there
            self._embed_size = self.embed.embed_size
            unk_index = self.embed.get_word_vocab().unknown_idx
        self.unk_index = unk_index
        self.word_dropout = word_dropout

    def forward(self, words: "torch.LongTensor") -> "torch.Tensor":
        r"""
        Embed the given word indices.

        :param words: shape ``[batch, seq_len]``
        :return: tensor of shape ``[batch, seq_len, embed_dim]``
        """
        if self.word_dropout > 0 and self.training:
            mask = torch.ones_like(words).float() * self.word_dropout
            mask = torch.bernoulli(mask).eq(1)  # larger word_dropout -> more positions set to 1
            # NOTE(review): unlike TokenEmbedding.drop_word, padding positions are not
            # excluded here and may also be replaced by unk_index -- confirm intended.
            words = words.masked_fill(mask, self.unk_index)
        words = self.embed(words)
        return self.dropout(words)

    @property
    def num_embedding(self) -> int:
        # vocabulary size of the wrapped embedding
        if isinstance(self.embed, nn.Embedding):
            return self.embed.weight.size(0)
        else:
            return self.embed.num_embeddings

    def __len__(self):
        return len(self.embed)

    @property
    def embed_size(self) -> int:
        return self._embed_size

    @property
    def embedding_dim(self) -> int:
        return self._embed_size

    @property
    def requires_grad(self):
        r"""
        Whether the embedding parameters are trainable:

        - ``True`` -- all parameters trainable
        - ``False`` -- no parameter trainable
        - ``None`` -- mixed
        """
        if not isinstance(self.embed, TokenEmbedding):
            return self.embed.weight.requires_grad
        else:
            return self.embed.requires_grad

    @requires_grad.setter
    def requires_grad(self, value):
        if not isinstance(self.embed, TokenEmbedding):
            self.embed.weight.requires_grad = value
        else:
            self.embed.requires_grad = value

    @property
    def size(self):
        # shape of the underlying weight matrix
        if isinstance(self.embed, TokenEmbedding):
            return self.embed.size
        else:
            return self.embed.weight.size()
+
+
class TokenEmbedding(Module):
    r"""
    Base class for all fastNLP token embeddings.

    Subclasses look up vectors for indices drawn from ``vocab``; they are expected
    to set ``self._embed_size`` and implement :meth:`forward`.

    :param vocab: vocabulary the embedding is built for; must define a padding entry.
    :param word_dropout: probability of replacing a word index with the unknown index
        during training (requires the vocab to define an unknown entry).
    :param dropout: dropout probability applied to the embedded representation.
    """

    def __init__(self, vocab, word_dropout=0.0, dropout=0.0):
        super(TokenEmbedding, self).__init__()
        if vocab.rebuild:
            vocab.build_vocab()
        assert vocab.padding is not None, "Vocabulary must have a padding entry."
        self._word_vocab = vocab
        self._word_pad_index = vocab.padding_idx
        if word_dropout > 0:
            assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word."
        self.word_dropout = word_dropout
        self._word_unk_index = vocab.unknown_idx
        self.dropout_layer = nn.Dropout(dropout)

    def drop_word(self, words):
        r"""
        Randomly replace word indices with the unknown index with probability
        ``self.word_dropout`` (training mode only); padding positions are never replaced.

        :param torch.LongTensor words: batch_size x max_len
        :return: tensor of the same shape with some indices replaced by ``unk``
        """
        if self.word_dropout > 0 and self.training:
            mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
            mask = torch.bernoulli(mask).eq(1)  # larger word_dropout -> more positions set to 1
            pad_mask = words.ne(self._word_pad_index)
            mask = mask & pad_mask  # never drop padding positions
            words = words.masked_fill(mask, self._word_unk_index)
        return words

    def dropout(self, words):
        r"""
        Apply dropout to the embedded word representations.

        :param torch.FloatTensor words: batch_size x max_len x embed_size
        :return: tensor of the same shape
        """
        return self.dropout_layer(words)

    @property
    def requires_grad(self):
        r"""
        Whether the embedding parameters are trainable. ``True``: all trainable;
        ``False``: none trainable; ``None``: mixed.
        """
        requires_grads = set(param.requires_grad for param in self.parameters())
        if len(requires_grads) == 1:
            return requires_grads.pop()
        else:
            return None

    @requires_grad.setter
    def requires_grad(self, value):
        for param in self.parameters():
            param.requires_grad = value

    def __len__(self):
        return len(self._word_vocab)

    @property
    def embed_size(self) -> int:
        return self._embed_size

    @property
    def embedding_dim(self) -> int:
        return self._embed_size

    @property
    def num_embeddings(self) -> int:
        r"""
        Size of the word vocabulary; may exceed the number of rows actually
        stored in the embedding matrix.
        """
        return len(self._word_vocab)

    def get_word_vocab(self):
        r"""
        Return the vocabulary this embedding was built for.

        :return: Vocabulary
        """
        return self._word_vocab

    @property
    def size(self):
        # torch.Size is a tuple subclass and accepts a single iterable;
        # the previous torch.Size(a, b) call raised TypeError.
        return torch.Size((self.num_embeddings, self._embed_size))

    @abstractmethod
    def forward(self, words):
        raise NotImplementedError
diff --git a/fastNLP/embeddings/torch/stack_embedding.py b/fastNLP/embeddings/torch/stack_embedding.py
new file mode 100644
index 00000000..591de7a3
--- /dev/null
+++ b/fastNLP/embeddings/torch/stack_embedding.py
@@ -0,0 +1,102 @@
+r"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "StackEmbedding",
+]
+
+from typing import List
+
+from ...envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+ import torch
+ from torch import nn
+
+from .embedding import TokenEmbedding
+from .utils import _check_vocab_has_same_index
+
+
class StackEmbedding(TokenEmbedding):
    r"""
    Concatenate several embeddings over the same vocabulary into one embedding.

    Example::

        >>> from fastNLP import Vocabulary
        >>> from fastNLP.embeddings.torch import StaticEmbedding, StackEmbedding
        >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
        >>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True)
        >>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
        >>> embed = StackEmbedding([embed_1, embed_2])

    :param embeds: non-empty :class:`list` of
        :class:`~fastNLP.embeddings.torch.embedding.TokenEmbedding`; every
        embedding must share the same word-to-index mapping.
    :param word_dropout: probability of replacing a word with ``unk`` during
        training; the same positions are dropped in every stacked embedding.
        If dropout is set here, do not also set it on the individual embeddings.
    :param dropout: dropout probability on the concatenated output (0.1 zeroes 10% of values).
    """

    def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
        # Validate the input before touching it: previously the isinstance asserts
        # only ran after ``vocabs[0]`` had been accessed, so an empty list or a
        # wrong element type surfaced as IndexError/AttributeError instead of a
        # clear assertion message.
        assert isinstance(embeds, list)
        assert len(embeds) > 0, "`embeds` must contain at least one embedding."
        for embed in embeds:
            assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."

        # collect the vocabularies and make sure they agree on word indices
        vocabs = [embed.get_word_vocab() for embed in embeds if hasattr(embed, 'get_word_vocab')]
        _vocab = vocabs[0]
        for vocab in vocabs[1:]:
            if _vocab != vocab:
                _check_vocab_has_same_index(_vocab, vocab)

        super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
        self.embeds = nn.ModuleList(embeds)
        self._embed_size = sum(embed.embed_size for embed in self.embeds)

    def append(self, embed: TokenEmbedding):
        r"""
        Append an embedding.

        :param embed: embedding to add; must share this stack's word indices.
        :return: self
        """
        assert isinstance(embed, TokenEmbedding)
        _check_vocab_has_same_index(self.get_word_vocab(), embed.get_word_vocab())
        self._embed_size += embed.embed_size
        self.embeds.append(embed)
        return self

    def pop(self):
        r"""
        Remove and return the last embedding.

        :return: the removed embedding
        """
        embed = self.embeds.pop()
        self._embed_size -= embed.embed_size
        return embed

    @property
    def embed_size(self):
        r"""
        Size of the last dimension of the vectors this embedding outputs
        (sum over the stacked embeddings).
        """
        return self._embed_size

    def forward(self, words):
        r"""
        Embed ``words`` with every stacked embedding and concatenate the results
        in order along the last dimension.

        :param words: shape ``[batch_size, max_len]``
        :return: concatenated tensor; last dimension is :attr:`embed_size`
        """
        outputs = []
        words = self.drop_word(words)
        for embed in self.embeds:
            outputs.append(embed(words))
        outputs = self.dropout(torch.cat(outputs, dim=-1))
        return outputs
diff --git a/fastNLP/embeddings/torch/static_embedding.py b/fastNLP/embeddings/torch/static_embedding.py
new file mode 100644
index 00000000..6980c851
--- /dev/null
+++ b/fastNLP/embeddings/torch/static_embedding.py
@@ -0,0 +1,410 @@
+r"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "StaticEmbedding"
+]
+import os
+from collections import defaultdict
+from copy import deepcopy
+import json
+from typing import Callable, Union
+
+import numpy as np
+
+from fastNLP.core.log import logger
+from .embedding import TokenEmbedding
+from ...core import logger
+from ...core.vocabulary import Vocabulary
+from ...io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path
+from ...io.file_utils import _get_file_name_base_on_postfix
+from ...envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+ import torch
+ import torch.nn as nn
+
+
+VOCAB_FILENAME = 'vocab.txt'
+STATIC_HYPER_FILENAME = 'static_hyper.json'
+STATIC_EMBED_FILENAME = 'static.txt'
+
+
class StaticEmbedding(TokenEmbedding):
    r"""
    ``StaticEmbedding`` loads pre-trained word vectors. Given the name or path of a
    pre-trained embedding, rows are extracted for the words present in ``vocab``;
    words missing from the pre-trained file are randomly initialised, except that
    words marked ``no_create_entry`` do not get their own vector and are pointed
    at the ``<unk>`` row instead.

    Pre-trained vectors that can be downloaded automatically:

        - ``en`` -- alias of ``en-glove-840b-300d`` (common choice)
        - ``en-glove-6b-50d`` / ``-100d`` / ``-200d`` / ``-300d`` -- official glove vectors
        - ``en-glove-42b-300d``, ``en-glove-840b-300d``
        - ``en-glove-twitter-27b-25d`` / ``-50d`` / ``-100d`` / ``-200d``
        - ``en-word2vec-300d`` -- official word2vec 300d vectors
        - ``en-fasttext-crawl`` -- official fasttext 300d English vectors
        - ``cn-char-fastnlp-100d`` / ``cn-bi-fastnlp-100d`` / ``cn-tri-fastnlp-100d``
          -- fastNLP-trained 100d char / bigram / trigram embeddings
        - ``cn-fasttext`` -- official fasttext 300d Chinese vectors

    Example::

        >>> from fastNLP import Vocabulary
        >>> from fastNLP.embeddings.torch import StaticEmbedding
        >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
        >>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-50d')

        >>> # with lower=True, "the"/"The"/"THE" share one vector looked up as "the"
        >>> vocab = Vocabulary().add_word_lst(["The", 'the', "THE"])
        >>> embed = StaticEmbedding(vocab, model_dir_or_name="en-glove-50d", lower=True)

    :param vocab: vocabulary; only words contained in it are loaded. Words missing
        from the pre-trained file are randomly initialised.
    :param model_dir_or_name: either a folder containing a single ``.txt`` embedding
        file, a file path, or the name of a downloadable embedding (cached locally).
        ``None`` means random initialisation with ``embedding_dim`` dimensions.
    :param embedding_dim: dimension for random initialisation; when > 0,
        ``model_dir_or_name`` is ignored.
    :param requires_grad: whether the embedding weights receive gradients.
    :param init_method: in-place initialiser (e.g. from :mod:`torch.nn.init`) for
        words missing from the pre-trained file.
    :param lower: lowercase vocab words before matching them against the
        pre-trained vocabulary. Set ``False`` if capitalised words need their own vectors.
    :param dropout: dropout probability on the output representation.
    :param word_dropout: probability of replacing a word with ``<unk>`` during training.
    :param normalize: normalise every vector to unit norm.
    :param min_freq: words whose vocabulary frequency is below this are mapped to ``<unk>``.
    :kwargs:
        * *only_train_min_freq* (*bool*) -- apply the ``min_freq`` filter only to training words
        * *only_norm_found_vector* (*bool*) -- default ``False``; normalise only vectors
          found in the pre-trained file
        * *only_use_pretrain_word* (*bool*) -- default ``False``; words absent from the
          pre-trained file map to ``<unk>``. Recommended when the embedding is frozen.
    """

    def __init__(self, vocab: Vocabulary, model_dir_or_name: Union[str, None] = 'en', embedding_dim=-1, requires_grad: bool = True,
                 init_method: Callable = None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
        super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
        if embedding_dim > 0:
            if model_dir_or_name:
                logger.info(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with"
                            f" dimension {embedding_dim}. If you want to use pre-trained embedding, "
                            f"set `embedding_dim` to 0.")
            model_dir_or_name = None

        # resolve model_dir_or_name to a concrete file path (or None for random init)
        if model_dir_or_name is None:
            assert embedding_dim >= 1, "The dimension of embedding should be larger than 1."
            embedding_dim = int(embedding_dim)
            model_path = None
        elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
            model_url = _get_embedding_url('static', model_dir_or_name.lower())
            model_path = cached_path(model_url, name='embedding')
        elif os.path.isfile(os.path.abspath(os.path.expanduser(model_dir_or_name))):
            model_path = os.path.abspath(os.path.expanduser(model_dir_or_name))
        elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
            model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt')
        else:
            raise ValueError(f"Cannot recognize {model_dir_or_name}.")

        kwargs['min_freq'] = min_freq
        kwargs['lower'] = lower
        # shrink the vocabulary according to min_freq
        truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq)
        if truncate_vocab:
            truncated_vocab = deepcopy(vocab)
            truncated_vocab.min_freq = min_freq
            truncated_vocab.word2idx = None
            if lower:  # with lower=True, counts of different casings must be pooled
                lowered_word_count = defaultdict(int)
                for word, count in truncated_vocab.word_count.items():
                    lowered_word_count[word.lower()] += count
                for word in truncated_vocab.word_count.keys():
                    word_count = truncated_vocab.word_count[word]
                    if lowered_word_count[word.lower()] >= min_freq and word_count < min_freq:
                        truncated_vocab.add_word_lst([word] * (min_freq - word_count),
                                                     no_create_entry=truncated_vocab._is_word_no_create_entry(word))

            # restrict the min_freq filter to words that appear in training data
            if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None:
                for word in truncated_vocab.word_count.keys():
                    if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word] < min_freq:
                        truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]),
                                                     no_create_entry=True)
            truncated_vocab.build_vocab()
            truncated_words_to_words = torch.arange(len(vocab)).long()
            for word, index in vocab:
                truncated_words_to_words[index] = truncated_vocab.to_index(word)
            logger.info(f"{len(vocab) - len(truncated_vocab)} words have frequency less than {min_freq}.")
            vocab = truncated_vocab

        self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False)
        self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
        # load the vectors
        if lower:
            lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
            for word, index in vocab:
                if vocab._is_word_no_create_entry(word):
                    lowered_vocab.add_word(word.lower(), no_create_entry=True)
                else:
                    lowered_vocab.add_word(word.lower())  # entry-creating words go in first
            logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
                        f"unique lowered words.")
            if model_path:
                embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
            else:
                embedding = self._randomly_init_embed(len(lowered_vocab), embedding_dim, init_method)
                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
            if lowered_vocab.unknown:
                unknown_idx = lowered_vocab.unknown_idx
            else:
                unknown_idx = embedding.size(0) - 1  # otherwise the last row serves as unknown
                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
            # map every original-cased word onto the row of its lowered form
            words_to_words = torch.full((len(vocab),), fill_value=unknown_idx, dtype=torch.long).long()
            for word, index in vocab:
                if word not in lowered_vocab:
                    word = word.lower()
                    if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
                        continue  # no entry needed; already mapped to unknown
                words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
            self.register_buffer('words_to_words', words_to_words)
            self._word_unk_index = lowered_vocab.unknown_idx  # switch to the lowered unknown index
        else:
            if model_path:
                embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
            else:
                embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
        if not self.only_norm_found_vector and normalize:
            embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)

        if truncate_vocab:
            # compose: original index -> truncated index -> matrix row
            for i in range(len(truncated_words_to_words)):
                index_in_truncated_vocab = truncated_words_to_words[i]
                truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab]
            del self.words_to_words
            self.register_buffer('words_to_words', truncated_words_to_words)
        self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                      padding_idx=vocab.padding_idx,
                                      max_norm=None, norm_type=2, scale_grad_by_freq=False,
                                      sparse=False, _weight=embedding)
        self._embed_size = self.embedding.weight.size(1)
        self.requires_grad = requires_grad
        self.kwargs = kwargs

    @property
    def weight(self):
        return self.embedding.weight

    def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None):
        r"""
        Create and initialise a fresh embedding matrix.

        :param int num_embedding: number of rows
        :param int embedding_dim: vector dimension
        :param callable init_embed: optional in-place initialiser; defaults to
            uniform in ``[-sqrt(3/dim), sqrt(3/dim)]``
        :return: torch.FloatTensor
        """
        embed = torch.zeros(num_embedding, embedding_dim)

        if init_embed is None:
            nn.init.uniform_(embed, -np.sqrt(3 / embedding_dim), np.sqrt(3 / embedding_dim))
        else:
            init_embed(embed)

        return embed

    def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>',
                         error='ignore', init_method=None):
        r"""
        Extract the vectors of the words in ``vocab`` from the pre-trained file
        ``embed_filepath``. The format is auto-detected: word2vec files have a
        two-element header line, glove files do not.

        :param str embed_filepath: path of the pre-trained embedding file.
        :param vocab: :class:`~fastNLP.Vocabulary`; only its words are extracted.
            Missing words are initialised so the whole matrix stays identically distributed.
        :param dtype: dtype of the loaded vectors.
        :param str padding: padding token used inside the embedding file.
            (The previous default ``''`` was a mangled ``'<pad>'``; restored.)
        :param str unknown: unknown token used inside the embedding file
            (restored from mangled ``''`` to ``'<unk>'``).
        :param str error: ``'ignore'`` skips malformed lines; anything else re-raises.
        :param init_method: in-place initialiser for words missing from the file;
            defaults to the random init of :meth:`_randomly_init_embed`.
        :return torch.Tensor: shape ``[len(vocab), dimension]``, dimension taken from the file.
        """
        assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
        if not os.path.exists(embed_filepath):
            raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
        with open(embed_filepath, 'r', encoding='utf-8') as f:
            line = f.readline().strip()
            parts = line.split()
            start_idx = 0
            if len(parts) == 2:  # word2vec-style header: "<vocab_size> <dim>"
                dim = int(parts[1])
                start_idx += 1
            else:  # glove style: every line is "word v1 v2 ..."
                dim = len(parts) - 1
                f.seek(0)
            # vocab index -> vector, or None when the word must be trained from scratch
            matrix = {}
            if vocab.padding:
                matrix[vocab.padding_idx] = torch.zeros(dim)
            if vocab.unknown:
                matrix[vocab.unknown_idx] = torch.zeros(dim)
            found_count = 0
            found_unknown = False
            for idx, line in enumerate(f, start_idx):
                try:
                    parts = line.strip().split()
                    word = ''.join(parts[:-dim])
                    nums = parts[-dim:]
                    # map the file's pad/unk tokens onto the vocab's entries
                    if word == padding and vocab.padding is not None:
                        word = vocab.padding
                    elif word == unknown and vocab.unknown is not None:
                        word = vocab.unknown
                        found_unknown = True
                    if word in vocab:
                        index = vocab.to_index(word)
                        if index in matrix:
                            logger.warning(f"Word has more than one vector in embedding file. Set logger level to "
                                           f"DEBUG for detail.")
                            logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)")
                        # np.asarray replaces the deprecated np.fromstring(..., sep=' ')
                        vec = np.asarray(nums, dtype=dtype)
                        if vec.size != dim:
                            raise ValueError(f"Expected {dim} values but got {vec.size}.")
                        matrix[index] = torch.from_numpy(vec)
                        if self.only_norm_found_vector:
                            matrix[index] = matrix[index] / np.linalg.norm(matrix[index])
                        found_count += 1
                except Exception as e:
                    if error == 'ignore':
                        logger.warning("Error occurred at the {} line.".format(idx))
                    else:
                        logger.error("Error occurred at the {} line.".format(idx))
                        raise e
            logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
            if not self.only_use_pretrain_word:  # otherwise don't create entries for missing words
                for word, index in vocab:
                    if index not in matrix and not vocab._is_word_no_create_entry(word):
                        if found_unknown:  # initialise from the file's unknown vector if present
                            matrix[index] = matrix[vocab.unknown_idx]
                        else:
                            matrix[index] = None
            # entries in matrix are exactly the words that receive their own row
            vectors = self._randomly_init_embed(len(matrix), dim, init_method)

            if vocab.unknown is None:  # create a dedicated unknown row
                unknown_idx = len(matrix)
                vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
            else:
                unknown_idx = vocab.unknown_idx
            self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx, dtype=torch.long).long())
            index = 0
            for word, index_in_vocab in vocab:
                if index_in_vocab in matrix:
                    vec = matrix.get(index_in_vocab)
                    if vec is not None:  # found vector; None means keep the random init (trainable)
                        vectors[index] = vec
                    self.words_to_words[index_in_vocab] = index
                    index += 1

            return vectors

    def forward(self, words: "torch.LongTensor") -> "torch.FloatTensor":
        r"""
        Embed the given word indices.

        :param words: shape ``[batch, seq_len]``
        :return: tensor of shape ``[batch, seq_len, embed_dim]``
        """
        if hasattr(self, 'words_to_words'):
            # remap vocab indices onto the rows of the (possibly deduplicated) matrix
            words = self.words_to_words[words]
        words = self.drop_word(words)
        words = self.embedding(words)
        words = self.dropout(words)
        return words

    def save(self, folder: str):
        """
        Save the embedding into ``folder`` so it can be restored via :meth:`load`.

        :param folder: three files are created inside:

            - ``vocab.txt``, readable via :meth:`fastNLP.core.Vocabulary.load`;
            - ``static.txt`` in word2vec format (header line with count and dim,
              then one ``word v1 v2 ...`` line per word);
            - ``static_hyper.json`` with the :class:`StaticEmbedding` hyper-parameters.
        """
        os.makedirs(folder, exist_ok=True)

        vocab = self.get_word_vocab()
        vocab_fp = os.path.join(folder, VOCAB_FILENAME)
        vocab.save(vocab_fp)
        kwargs = self.kwargs.copy()
        kwargs['dropout'] = self.dropout_layer.p
        kwargs['word_dropout'] = self.word_dropout
        kwargs['requires_grad'] = self.requires_grad
        kwargs['only_norm_found_vector'] = False
        kwargs['only_use_pretrain_word'] = True

        with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'w', encoding='utf-8') as f:
            json.dump(kwargs, f, indent=2)

        with open(os.path.join(folder, STATIC_EMBED_FILENAME), 'w', encoding='utf-8') as f:
            f.write('{}\n'.format(' '*30))  # placeholder header, overwritten below
            word_count = 0
            saved_word = {}
            valid_word_count = 0
            for i in range(len(self.words_to_words)):
                word = vocab.to_word(i)
                if not vocab._is_word_no_create_entry(word):
                    word_count += 1
                    if kwargs['lower']:
                        word = word.lower()
                    if word in saved_word:
                        continue
                    saved_word[word] = 1
                    vec_i = self.words_to_words[i]
                    if vec_i == vocab.unknown_idx and i != vocab.unknown_idx:
                        continue
                    vec = self.embedding.weight.data[vec_i].tolist()
                    vec_str = ' '.join(map(str, vec))
                    f.write(f'{word} {vec_str}\n')
                    valid_word_count += 1
            # rewind and fill in the real word count and dimension
            f.seek(0)
            f.write('{} {}'.format(valid_word_count, self.embedding_dim))
        logger.debug(f"StaticEmbedding has been saved to {folder}.")

    @classmethod
    def load(cls, folder: str):
        """
        Restore a saved :class:`StaticEmbedding`.

        :param folder: must contain ``vocab.txt``, ``static.txt`` and ``static_hyper.json``
            as written by :meth:`save`.
        :return: the loaded embedding
        """
        for name in [VOCAB_FILENAME, STATIC_EMBED_FILENAME, STATIC_HYPER_FILENAME]:
            assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."

        vocab = Vocabulary.load(os.path.join(folder, VOCAB_FILENAME))
        with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'r', encoding='utf-8') as f:
            hyper = json.load(f)

        logger.info(f"Load StaticEmbedding from {folder}.")
        embed = cls(vocab=vocab, model_dir_or_name=os.path.join(folder, STATIC_EMBED_FILENAME), **hyper)
        return embed
+
diff --git a/fastNLP/embeddings/torch/utils.py b/fastNLP/embeddings/torch/utils.py
new file mode 100644
index 00000000..28521980
--- /dev/null
+++ b/fastNLP/embeddings/torch/utils.py
@@ -0,0 +1,108 @@
+r"""
+.. todo::
+ doc
+"""
+import numpy as np
+from ...envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+ import torch
+ from torch import nn as nn
+
+from ...core.vocabulary import Vocabulary
+
+__all__ = [
+ 'get_embeddings',
+ 'get_sinusoid_encoding_table'
+]
+
+
def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, include_word_start_end=True):
    r"""
    Build a character vocabulary from a word vocabulary.

    :param vocab: source word vocabulary; words marked ``no_create_entry`` are skipped.
    :param min_freq: minimum character frequency for the new vocabulary.
    :param include_word_start_end: whether to add the special ``<bow>`` / ``<eow>``
        word-boundary tokens.
    :return: the character :class:`Vocabulary`
    """
    char_vocab = Vocabulary(min_freq=min_freq)
    for word, index in vocab:
        if not vocab._is_word_no_create_entry(word):
            char_vocab.add_word_lst(list(word))
    if include_word_start_end:
        # The previous line read add_word_lst(['', '']): the angle-bracket tokens
        # were stripped in transit, which would have registered two empty-string
        # "characters". Restored to fastNLP's word-boundary tokens.
        char_vocab.add_word_lst(['<bow>', '<eow>'])
    return char_vocab
+
+
def get_embeddings(init_embed, padding_idx=None):
    r"""
    Build and return an embedding module from ``init_embed``.

    :param init_embed: accepted forms:

        - ``tuple(num_embeddings, embedding_dim)``: size and dimension; a fresh
          :class:`torch.nn.Embedding` is created and uniformly initialised;
        - :class:`torch.nn.Embedding` or a fastNLP ``Embedding``: returned as-is;
        - :class:`numpy.ndarray`: used as the initial weight matrix;
        - :class:`torch.Tensor`: used as the initial weight matrix;

    :param padding_idx: only honoured when a :class:`tuple` is passed
    :return: an embedding module
    """
    if isinstance(init_embed, tuple):
        num_embeddings, embedding_dim = init_embed
        embed = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim, padding_idx=padding_idx)
        # uniform in [-sqrt(3/dim), sqrt(3/dim)]
        bound = np.sqrt(3 / embed.weight.data.size(1))
        nn.init.uniform_(embed.weight.data, a=-bound, b=bound)
        return embed
    if isinstance(init_embed, nn.Module):
        return init_embed
    if isinstance(init_embed, np.ndarray):
        # fold the ndarray case into the tensor branch below
        init_embed = torch.tensor(init_embed, dtype=torch.float32)
    if isinstance(init_embed, torch.Tensor):
        return nn.Embedding.from_pretrained(init_embed, freeze=False)
    raise TypeError(
        'invalid init_embed type: {}'.format((type(init_embed))))
+
+
def get_sinusoid_encoding_table(n_position: int, d_hid: int, padding_idx=None) -> "torch.FloatTensor":
    """
    Sinusoidal position embedding table: even dimensions ``(0,2,4,...)`` carry sin,
    odd dimensions ``(1,3,5,...)`` carry cos of the scaled position.

    :param int n_position: number of positions
    :param int d_hid: embedding dimension (must be even)
    :param padding_idx: optional row to zero out
    :return: tensor of shape ``[n_position, d_hid]``
    """
    # angle[pos, j] = pos / 10000^(2*(j//2)/d_hid), computed via broadcasting
    positions = np.arange(n_position).reshape(-1, 1)
    divisors = np.power(10000, 2 * (np.arange(d_hid) // 2) / d_hid)
    sinusoid_table = positions / divisors

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for the padding position
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)
+
+
+def _check_vocab_has_same_index(vocab, other_vocab):
+ """
+ 检查两个vocabulary是否含有相同的word idx
+
+ :param Vocabulary vocab:
+ :param Vocabulary other_vocab:
+ :return:
+ """
+ if other_vocab != vocab:
+ for word, word_ix in vocab:
+ other_word_idx = other_vocab.to_index(word)
+ assert other_word_idx == word_ix, f"Word {word} has different index in vocabs, {word_ix} Vs. {other_word_idx}."
\ No newline at end of file
diff --git a/fastNLP/embeddings/transformers_embedding.py b/fastNLP/embeddings/transformers_embedding.py
deleted file mode 100644
index 4b15ea37..00000000
--- a/fastNLP/embeddings/transformers_embedding.py
+++ /dev/null
@@ -1,502 +0,0 @@
-r"""
-将transformers包中的模型封装成fastNLP中的embedding对象
-
-"""
-import os
-from itertools import chain
-from functools import partial
-
-from torch import nn
-import numpy as np
-import torch
-
-from .contextual_embedding import ContextualEmbedding
-from ..core import logger
-from ..core.vocabulary import Vocabulary
-
-
-__all__ = ['TransformersEmbedding', 'TransformersWordPieceEncoder']
-
-
-class TransformersEmbedding(ContextualEmbedding):
- r"""
- 使用transformers中的模型对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于
- 预训练的bert模型长度限制为512个token,而因为输入的word是未进行word piece分割的(word piece的分割由TransformersEmbedding在输入word
- 时切分),在分割之后长度可能会超过最大长度限制。
-
- Example::
-
- >>> import torch
- >>> from fastNLP import Vocabulary
- >>> from fastNLP.embeddings import TransformersEmbedding
- >>> from transformers import ElectraModel, ElectraTokenizer
- >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
- >>> model = ElectraModel.from_pretrained("google/electra-small-generator")
- >>> tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-generator")
- >>> embed = TransformersEmbedding(vocab, model_dir_or_name='en', requires_grad=False, layers='4,-2,-1')
- >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
- >>> outputs = embed(words)
- >>> outputs.size()
- >>> # torch.Size([1, 5, 2304])
-
- """
- def __init__(self, vocab, model, tokenizer, layers='-1',
- pool_method: str = 'first', word_dropout=0, dropout=0, requires_grad=True,
- include_cls_sep: bool = False, auto_truncate=True, **kwargs):
- r"""
-
- :param ~fastNLP.Vocabulary vocab: 词表
- :model model: transformers包中的PreTrainedModel对象
- :param tokenizer: transformers包中的PreTrainedTokenizer对象
- :param str,list layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,层的序号是
- 从0开始,可以以负数去索引倒数几层。layer=0为embedding层(包括wordpiece embedding, position embedding)
- :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces
- 中计算得到它对应的表示。支持 ``last`` , ``first`` , ``avg`` , ``max``。
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样
- 会使得word embedding的结果比输入的结果长两个token。如果该值为True,则在使用 :class::StackEmbedding 可能会与其它类型的
- embedding长度不匹配。
- :param bool pooled_cls: 返回的是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取做预测,
- 一般该值为True。
- :param bool requires_grad: 是否需要gradient以更新Bert的权重。
- :param bool auto_truncate: 当句子words拆分为word pieces长度超过bert最大允许长度(一般为512), 自动截掉拆分后的超过510个
- word pieces后的内容,并将第512个word piece置为。超过长度的部分的encode结果直接全部置零。一般仅有只使用
- 来进行分类的任务将auto_truncate置为True。
- :param kwargs:
- int min_freq: 小于该次数的词会被unk代替, 默认为1
- dict tokenizer_kwargs: 传递给tokenizer在调用tokenize()方法时所额外使用的参数,例如RoBERTaTokenizer需要传入
- {'add_prefix_space':True}
- """
- super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- if word_dropout > 0:
- assert vocab.unknown is not None, "When word_drop > 0, Vocabulary must contain the unknown token."
-
- self._word_sep_index = -100
- if tokenizer.sep_token in vocab:
- self._word_sep_index = vocab[tokenizer.sep_token]
-
- self._word_cls_index = -100
- if tokenizer.cls_token in vocab:
- self._word_cls_index = vocab[tokenizer.cls_token]
-
- min_freq = kwargs.get('min_freq', 1)
- self._min_freq = min_freq
-
- tokenizer_kwargs = kwargs.get('tokenizer_kwargs', {})
- self.model = _TransformersWordModel(tokenizer=tokenizer, model=model, vocab=vocab, layers=layers,
- pool_method=pool_method, include_cls_sep=include_cls_sep,
- auto_truncate=auto_truncate, min_freq=min_freq, tokenizer_kwargs=tokenizer_kwargs)
-
- self.requires_grad = requires_grad
- self._embed_size = len(self.model.layers) * model.config.hidden_size
-
- def forward(self, words):
- r"""
- 计算words的roberta embedding表示。计算之前会在每句话的开始增加在结束增加, 并根据include_cls_sep判断要不要
- 删除这两个token的表示。
-
- :param torch.LongTensor words: [batch_size, max_len]
- :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
- """
- words = self.drop_word(words)
- outputs = self._get_sent_reprs(words)
- if outputs is not None:
- return self.dropout(outputs)
- outputs = self.model(words)
- outputs = torch.cat([*outputs], dim=-1)
-
- return self.dropout(outputs)
-
- def drop_word(self, words):
- r"""
- 按照设定随机将words设置为unknown_index。
-
- :param torch.LongTensor words: batch_size x max_len
- :return:
- """
- if self.word_dropout > 0 and self.training:
- with torch.no_grad():
- mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
- mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
- pad_mask = words.ne(self._word_pad_index)
- mask = pad_mask.__and__(mask) # pad的位置不为unk
- if self._word_sep_index!=-100:
- not_sep_mask = words.ne(self._word_sep_index)
- mask = mask.__and__(not_sep_mask)
- if self._word_cls_index!=-100:
- not_cls_mask = words.ne(self._word_cls_index)
- mask = mask.__and__(not_cls_mask)
- words = words.masked_fill(mask, self._word_unk_index)
- return words
-
- def save(self, folder):
- """
- 保存tokenizer和model到folder文件夹。model保存在`folder/{model_name}`, tokenizer在`folder/{tokenizer_name}`下
- :param str folder: 保存地址
- :return:
- """
- os.makedirs(folder, exist_ok=True)
- self.model.save(folder)
-
-
-class TransformersWordPieceEncoder(nn.Module):
- r"""
- 读取roberta模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。
-
- RobertaWordPieceEncoder可以支持自动下载权重,当前支持的模型:
- en: roberta-base
- en-large: roberta-large
-
- """
- def __init__(self, model, tokenizer, layers: str = '-1',
- word_dropout=0, dropout=0, requires_grad: bool = True, **kwargs):
- r"""
-
- :param model: transformers的model
- :param tokenizer: transformer的tokenizer
- :param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层。layer=0为embedding层(包括wordpiece embedding,
- position embedding)
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param bool requires_grad: 是否需要gradient。
- """
- super().__init__()
-
- self.model = _WordPieceTransformersModel(model=model, tokenizer=tokenizer, layers=layers)
- self._sep_index = self.model._sep_index
- self._cls_index = self.model._cls_index
- self._wordpiece_pad_index = self.model._wordpiece_pad_index
- self._wordpiece_unk_index = self.model._wordpiece_unknown_index
- self._embed_size = len(self.model.layers) * self.model.config.hidden_size
- self.requires_grad = requires_grad
- self.word_dropout = word_dropout
- self.dropout_layer = nn.Dropout(dropout)
-
- @property
- def embed_size(self):
- return self._embed_size
-
- @property
- def embedding_dim(self):
- return self._embed_size
-
- @property
- def num_embedding(self):
- return self.model.encoder.config.vocab_size
-
- def index_datasets(self, *datasets, field_name, **kwargs):
- r"""
- 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了
- bert的pad value。
-
- :param ~fastNLP.DataSet datasets: DataSet对象
- :param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是raw_string的形式。
- :param kwargs: 传递给tokenizer的参数
- :return:
- """
- self.model.index_datasets(*datasets, field_name=field_name, **kwargs)
-
- def forward(self, word_pieces, token_type_ids=None):
- r"""
- 计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。
-
- :param words: batch_size x max_len
- :param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话. 如果不传入,则自动生成(大部分情况,都不需要输入),
- 第一个[SEP]及之前为0, 第二个[SEP]及到第一个[SEP]之间为1; 第三个[SEP]及到第二个[SEP]之间为0,依次往后推。
- :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
- """
- word_pieces = self.drop_word(word_pieces)
- outputs = self.model(word_pieces)
- outputs = torch.cat([*outputs], dim=-1)
-
- return self.dropout_layer(outputs)
-
- def drop_word(self, words):
- r"""
- 按照设定随机将words设置为unknown_index。
-
- :param torch.LongTensor words: batch_size x max_len
- :return:
- """
- if self.word_dropout > 0 and self.training:
- with torch.no_grad():
- not_sep_mask = words.ne(self._sep_index)
- not_cls_mask = words.ne(self._cls_index)
- replaceable_mask = not_sep_mask.__and__(not_cls_mask)
- mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
- mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
- pad_mask = words.ne(self._wordpiece_pad_index)
- mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad的位置不为unk
- words = words.masked_fill(mask, self._wordpiece_unk_index)
- return words
-
- def save(self, folder):
- os.makedirs(folder, exist_ok=True)
- self.model.save(os.path.join(folder, folder))
- logger.debug(f"TransformersWordPieceEncoder has been saved in {folder}")
-
-
-class _TransformersWordModel(nn.Module):
- def __init__(self, tokenizer, model, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
- include_cls_sep: bool = False, auto_truncate: bool = False, min_freq=2, tokenizer_kwargs={}):
- super().__init__()
-
- self.tokenizer = tokenizer
- self.encoder = model
- self.config = model.config
- self.only_last_layer = True
- if not (isinstance(layers, str) and (layers=='-1' or int(layers)==self.encoder.config.num_hidden_layers)):
- assert self.encoder.config.output_hidden_states == True, \
- f"You have to output all hidden states if you want to" \
- f" access the middle output of `{model.__class__.__name__}` "
- self.only_last_layer = False
-
- self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
- # 检查encoder_layer_number是否合理
- encoder_layer_number = len(self.encoder.encoder.layer)
- self.encoder_layer_number = encoder_layer_number
- if isinstance(layers, list):
- self.layers = [int(l) for l in layers]
- elif isinstance(layers, str):
- self.layers = list(map(int, layers.split(',')))
- else:
- raise TypeError("`layers` only supports str or list[int]")
-
- for layer in self.layers:
- if layer < 0:
- assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a {model.__class__.__name__} model with {encoder_layer_number} layers."
- else:
- assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a {model.__class__.__name__} model with {encoder_layer_number} layers."
-
- assert pool_method in ('avg', 'max', 'first', 'last')
- self.pool_method = pool_method
- self.include_cls_sep = include_cls_sep
- self.auto_truncate = auto_truncate
-
- word_to_wordpieces = []
- word_pieces_lengths = []
- for word, index in vocab:
- if index == vocab.padding_idx: # pad是个特殊的符号
- word = tokenizer.pad_token
- elif index == vocab.unknown_idx:
- word = tokenizer.unk_token
- elif vocab.word_count[word] self._max_position_embeddings:
- if self.auto_truncate:
- word_pieces_lengths = word_pieces_lengths.masked_fill(
- word_pieces_lengths + 2 > self._max_position_embeddings, self._max_position_embeddings - 2)
- else:
- raise RuntimeError(
- "After split words into word pieces, the lengths of word pieces are longer than the "
- f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set "
- f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
-
- # +2是由于需要加入与
- word_pieces = words.new_full((batch_size, min(max_word_piece_length + 2, self._max_position_embeddings)),
- fill_value=self._wordpiece_pad_index)
- attn_masks = torch.zeros_like(word_pieces)
- # 1. 获取words的word_pieces的id,以及对应的span范围
- word_indexes = words.cpu().numpy()
- for i in range(batch_size):
- word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]]))
- if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings - 2:
- word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
- word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
- attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
- word_pieces[:, 0].fill_(self._cls_index)
- batch_indexes = torch.arange(batch_size).to(words)
- word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
- token_type_ids = torch.zeros_like(word_pieces)
- # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算
- # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
- all_outputs = self.encoder(input_ids=word_pieces, token_type_ids=token_type_ids,
- attention_mask=attn_masks)
- if not self.only_last_layer:
- for _ in all_outputs:
- if isinstance(_, (tuple, list)) and len(_)==self.encoder_layer_number:
- bert_outputs = _
- break
- else:
- bert_outputs = all_outputs[:1]
- # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
-
- if self.include_cls_sep:
- s_shift = 1
- outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
- bert_outputs[-1].size(-1))
-
- else:
- s_shift = 0
- outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
- bert_outputs[-1].size(-1))
- batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
- batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len
-
- if self.pool_method == 'first':
- batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
- _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
- elif self.pool_method == 'last':
- batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max() + 1] - 1
- batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(max_word_piece_length), 0)
- _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
-
- for l_index, l in enumerate(self.layers):
- output_layer = bert_outputs[l]
- real_word_piece_length = output_layer.size(1) - 2
- if max_word_piece_length > real_word_piece_length: # 如果实际上是截取出来的
- paddings = output_layer.new_zeros(batch_size,
- max_word_piece_length - real_word_piece_length,
- output_layer.size(2))
- output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
- # 从word_piece collapse到word的表示
- truncate_output_layer = output_layer[:, 1:-1] # 删除与 batch_size x len x hidden_size
- if self.pool_method == 'first':
- tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
- tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
- outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp
-
- elif self.pool_method == 'last':
- tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
- tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
- outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp
- elif self.pool_method == 'max':
- for i in range(batch_size):
- for j in range(seq_len[i]):
- start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
- outputs[l_index, i, j + s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2)
- else:
- for i in range(batch_size):
- for j in range(seq_len[i]):
- start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
- outputs[l_index, i, j + s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2)
- if self.include_cls_sep:
- outputs[l_index, :, 0] = output_layer[:, 0]
- outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, word_pieces_lengths + s_shift]
-
- # 3. 最终的embedding结果
- return outputs
-
- def save(self, folder):
- self.tokenzier.save_pretrained(folder)
- self.encoder.save_pretrained(folder)
-
-
-class _WordPieceTransformersModel(nn.Module):
- def __init__(self, model, tokenizer, layers: str = '-1'):
- super().__init__()
-
- self.tokenizer = tokenizer
- self.encoder = model
- self.config = self.encoder.config
- # 检查encoder_layer_number是否合理
- encoder_layer_number = len(self.encoder.encoder.layer)
- self.only_last_layer = True
- if not (isinstance(layers, str) and (layers=='-1' or int(layers)==self.encoder.config.num_hidden_layers)):
- assert self.encoder.config.output_hidden_states == True, \
- f"You have to output all hidden states if you want to" \
- f" access the middle output of `{model.__class__.__name__}` "
- self.only_last_layer = False
-
- if isinstance(layers, list):
- self.layers = [int(l) for l in layers]
- elif isinstance(layers, str):
- self.layers = list(map(int, layers.split(',')))
- else:
- raise TypeError("`layers` only supports str or list[int]")
-
- for layer in self.layers:
- if layer < 0:
- assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a RoBERTa model with {encoder_layer_number} layers."
- else:
- assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a RoBERTa model with {encoder_layer_number} layers."
-
- self._cls_index = self.tokenizer.cls_token_id
- self._sep_index = self.tokenizer.sep_token_id
- self._wordpiece_pad_index = self.tokenizer.pad_token_id # 需要用于生成word_piece
- self._wordpiece_unknown_index = self.tokenizer.unk_token_id
-
- def index_datasets(self, *datasets, field_name, **kwargs):
- r"""
- 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是
- [CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。
-
- :param datasets: DataSet对象
- :param field_name: 基于哪一列index
- :param kwargs: 传递给tokenizer的参数
- :return:
- """
- kwargs['add_special_tokens'] = kwargs.get('add_special_tokens', True)
- kwargs['add_prefix_space'] = kwargs.get('add_special_tokens', True)
-
- encode_func = partial(self.tokenizer.encode, **kwargs)
-
- for index, dataset in enumerate(datasets):
- try:
- dataset.apply_field(encode_func, field_name=field_name, new_field_name='word_pieces',
- is_input=True)
- dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
- except Exception as e:
- logger.error(f"Exception happens when processing the {index} dataset.")
- raise e
-
- def forward(self, word_pieces):
- r"""
-
- :param word_pieces: torch.LongTensor, batch_size x max_len
- :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size
- """
- batch_size, max_len = word_pieces.size()
-
- attn_masks = word_pieces.ne(self._wordpiece_pad_index)
- all_outputs = self.encoder(word_pieces, token_type_ids=torch.zeros_like(word_pieces),
- attention_mask=attn_masks)
- if not self.only_last_layer:
- for _ in all_outputs:
- if isinstance(_, (tuple, list)) and len(_)==self.encoder_layer_number:
- roberta_outputs = _
- break
- else:
- roberta_outputs = all_outputs[:1]
- # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
- outputs = roberta_outputs[0].new_zeros((len(self.layers), batch_size, max_len, roberta_outputs[0].size(-1)))
- for l_index, l in enumerate(self.layers):
- roberta_output = roberta_outputs[l]
- outputs[l_index] = roberta_output
- return outputs
-
- def save(self, folder):
- self.tokenizer.save_pretrained(folder)
- self.encoder.save_pretrained(folder)
diff --git a/fastNLP/embeddings/utils.py b/fastNLP/embeddings/utils.py
deleted file mode 100644
index 9a18bfe3..00000000
--- a/fastNLP/embeddings/utils.py
+++ /dev/null
@@ -1,104 +0,0 @@
-r"""
-.. todo::
- doc
-"""
-import numpy as np
-import torch
-from torch import nn as nn
-
-from ..core.vocabulary import Vocabulary
-
-__all__ = [
- 'get_embeddings',
- 'get_sinusoid_encoding_table'
-]
-
-
-def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, include_word_start_end=True):
- r"""
- 给定一个word的vocabulary生成character的vocabulary.
-
- :param vocab: 从vocab
- :param min_freq:
- :param include_word_start_end: 是否需要包含特殊的和
- :return:
- """
- char_vocab = Vocabulary(min_freq=min_freq)
- for word, index in vocab:
- if not vocab._is_word_no_create_entry(word):
- char_vocab.add_word_lst(list(word))
- if include_word_start_end:
- char_vocab.add_word_lst(['