@@ -2,7 +2,7 @@ pipeline {
     agent {
         docker {
             image 'ubuntu_tester'
-            args '-u root:root -v ${HOME}/html/docs:/docs -v ${HOME}/html/_ci:/ci'
+            args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci'
         }
     }
     environment {
@@ -27,7 +27,6 @@ pipeline {
         }
         stage('Package Testing') {
             steps {
-                sh 'python -m spacy download en'
                 sh 'pip install fitlog'
                 sh 'pytest ./tests --html=test_results.html --self-contained-html'
             }
@@ -13,7 +13,7 @@ install:
   - pip install pytest-cov
 # command to run tests
 script:
-  - python -m spacy download en
+  # - python -m spacy download en
   - pytest --cov=fastNLP tests/
 after_success:
@@ -46,7 +46,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -66,22 +66,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "In total 3 datasets:\n",
-      "\tdev has 1200 instances.\n",
-      "\ttrain has 9600 instances.\n",
-      "\ttest has 1200 instances.\n",
-      "In total 0 vocabs:\n",
-      "\n"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "print(data_bundle)"
   ]
@@ -95,20 +82,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n",
-      "'target': 1 type=str},\n",
-      "{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n",
-      "'target': 1 type=str})\n"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "print(data_bundle.get_dataset('train')[:2])  # look at the first two samples of the train set"
   ]
@@ -125,7 +101,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -137,24 +113,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "In total 3 datasets:\n",
-      "\tdev has 1200 instances.\n",
-      "\ttrain has 9600 instances.\n",
-      "\ttest has 1200 instances.\n",
-      "In total 2 vocabs:\n",
-      "\tchars has 4409 entries.\n",
-      "\ttarget has 2 entries.\n",
-      "\n"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "print(data_bundle)  # print the data_bundle to see how it changed"
   ]
@@ -168,24 +129,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n",
-      "'target': 1 type=int,\n",
-      "'chars': [338, 464, 1400, 784, 468, 739, 3, 289, 151, 21, 5, 88, 143, 2, 9, 81, 134, 2573, 766, 233, 196, 23, 536, 342, 297, 2, 405, 698, 132, 281, 74, 744, 1048, 74, 420, 387, 74, 412, 433, 74, 2021, 180, 8, 219, 1929, 213, 4, 34, 31, 96, 363, 8, 230, 2, 66, 18, 229, 331, 768, 4, 11, 1094, 479, 17, 35, 593, 3, 1126, 967, 2, 151, 245, 12, 44, 2, 6, 52, 260, 263, 635, 5, 152, 162, 4, 11, 336, 3, 154, 132, 5, 236, 443, 3, 2, 18, 229, 761, 700, 4, 11, 48, 59, 653, 2, 8, 230] type=list,\n",
-      "'seq_len': 106 type=int},\n",
-      "{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n",
-      "'target': 1 type=int,\n",
-      "'chars': [50, 133, 20, 135, 945, 520, 343, 24, 3, 301, 176, 350, 86, 785, 2, 456, 24, 461, 163, 443, 128, 109, 6, 47, 7, 2, 916, 152, 162, 524, 296, 44, 301, 176, 2, 1384, 524, 296, 259, 88, 143, 2, 92, 67, 26, 12, 277, 269, 2, 188, 223, 26, 228, 83, 6, 63] type=list,\n",
-      "'seq_len': 56 type=int})\n"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "print(data_bundle.get_dataset('train')[:2])"
   ]
@@ -199,17 +145,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Vocabulary(['选', '择', '珠', '江', '花']...)\n"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "char_vocab = data_bundle.get_vocab('chars')\n",
    "print(char_vocab)"
@@ -224,18 +162,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "'选'的index是338\n",
-      "index:338对应的汉字是选\n"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "index = char_vocab.to_index('选')\n",
    "print(\"the index of '选' is {}\".format(index))  # this matches the first chars index of the first instance printed above\n",
@@ -252,17 +181,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Found 4321 out of 4409 words in the pre-training embedding.\n"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "from fastNLP.embeddings import StaticEmbedding\n",
    "\n",
@@ -279,7 +200,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -323,288 +244,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "input fields after batch(if batch size is 2):\n",
-      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
-      "\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n",
-      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
-      "target fields after batch(if batch size is 2):\n",
-      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
-      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
-      "\n",
-      "Evaluate data in 0.01 seconds!\n",
-      "training epochs started 2019-09-03-23-57-10\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3000), HTML(value='')), layout=Layout(display…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.43 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 1/10. Step:300/3000: \n",
-      "\r",
-      "AccuracyMetric: acc=0.81\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.44 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 2/10. Step:600/3000: \n",
-      "\r",
-      "AccuracyMetric: acc=0.8675\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.44 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 3/10. Step:900/3000: \n",
-      "\r",
-      "AccuracyMetric: acc=0.878333\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.43 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 4/10. Step:1200/3000: \n",
-      "\r",
-      "AccuracyMetric: acc=0.873333\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.44 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 5/10. Step:1500/3000: \n",
-      "\r",
-      "AccuracyMetric: acc=0.878333\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.42 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 6/10. Step:1800/3000: \n",
-      "\r",
-      "AccuracyMetric: acc=0.895833\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.44 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 7/10. Step:2100/3000: \n",
-      "\r",
-      "AccuracyMetric: acc=0.8975\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.43 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 8/10. Step:2400/3000: \n",
-      "\r",
-      "AccuracyMetric: acc=0.894167\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.48 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 9/10. Step:2700/3000: \n",
-      "\r",
-      "AccuracyMetric: acc=0.8875\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.43 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 10/10. Step:3000/3000: \n",
-      "\r",
-      "AccuracyMetric: acc=0.895833\n",
-      "\n",
-      "\r\n",
-      "In Epoch:7/Step:2100, got best dev performance:\n",
-      "AccuracyMetric: acc=0.8975\n",
-      "Reloaded the best model.\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 0.34 seconds!\n",
-      "[tester] \n",
-      "AccuracyMetric: acc=0.8975\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "{'AccuracyMetric': {'acc': 0.8975}}"
-      ]
-     },
-     "execution_count": 10,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "from fastNLP import Trainer\n",
    "from fastNLP import CrossEntropyLoss\n",
@@ -637,139 +279,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "loading vocabulary file /home/yh/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n",
-      "Load pre-trained BERT parameters from file /home/yh/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n",
-      "Start to generating word pieces for word.\n",
-      "Found(Or segment into word pieces) 4286 words out of 4409.\n",
-      "input fields after batch(if batch size is 2):\n",
-      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
-      "\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n",
-      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
-      "target fields after batch(if batch size is 2):\n",
-      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
-      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
-      "\n",
-      "Evaluate data in 0.05 seconds!\n",
-      "training epochs started 2019-09-04-00-02-37\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3600), HTML(value='')), layout=Layout(display…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 15.89 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 1/3. Step:1200/3600: \n",
-      "\r",
-      "AccuracyMetric: acc=0.9\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 15.92 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 2/3. Step:2400/3600: \n",
-      "\r",
-      "AccuracyMetric: acc=0.904167\n",
-      "\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 15.91 seconds!\n",
-      "\r",
-      "Evaluation on dev at Epoch 3/3. Step:3600/3600: \n",
-      "\r",
-      "AccuracyMetric: acc=0.918333\n",
-      "\n",
-      "\r\n",
-      "In Epoch:3/Step:3600, got best dev performance:\n",
-      "AccuracyMetric: acc=0.918333\n",
-      "Reloaded the best model.\n",
-      "Performance on test is:\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r",
-      "Evaluate data in 29.24 seconds!\n",
-      "[tester] \n",
-      "AccuracyMetric: acc=0.919167\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "{'AccuracyMetric': {'acc': 0.919167}}"
-      ]
-     },
-     "execution_count": 12,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
    "# switching the Embedding is all that is needed\n",
    "from fastNLP.embeddings import BertEmbedding\n",
@@ -802,6 +314,224 @@
    "tester.test()"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "### Word-based text classification"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "Since written Chinese has no explicit boundaries between words, a tokenizer is usually needed to segment each sentence into words first.\n",
+   "The following example shows how to do text classification without relying on fastNLP's built-in data loading and preprocessing code."
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "### (1) Loading the data"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "We keep using the same data as before, but this time without fastNLP's built-in reading code."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from fastNLP.io import ChnSentiCorpLoader\n",
+   "\n",
+   "loader = ChnSentiCorpLoader()  # instantiate a loader for the Chinese sentiment classification corpus\n",
+   "data_dir = loader.download()  # downloads the data to the default cache directory and returns that path"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "We first define a read_file_to_dataset function that, given a file path, reads the content and returns a DataSet. Then we put all the DataSets into a DataBundle object to simplify the subsequent preprocessing."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "import os\n",
+   "from fastNLP import DataSet, Instance\n",
+   "from fastNLP.io import DataBundle\n",
+   "\n",
+   "\n",
+   "def read_file_to_dataset(fp):\n",
+   "    ds = DataSet()\n",
+   "    with open(fp, 'r') as f:\n",
+   "        f.readline()  # the first line is the header; skip it\n",
+   "        for line in f:\n",
+   "            line = line.strip()\n",
+   "            target, chars = line.split('\\t')\n",
+   "            ins = Instance(target=target, raw_chars=chars)\n",
+   "            ds.append(ins)\n",
+   "    return ds\n",
+   "\n",
+   "data_bundle = DataBundle()\n",
+   "for name in ['train.tsv', 'dev.tsv', 'test.tsv']:\n",
+   "    fp = os.path.join(data_dir, name)\n",
+   "    ds = read_file_to_dataset(fp)\n",
+   "    data_bundle.set_dataset(name=name.split('.')[0], dataset=ds)\n",
+   "\n",
+   "print(data_bundle)  # take a look at the datasets\n",
+   "# In total 3 datasets:\n",
+   "#    train has 9600 instances.\n",
+   "#    dev has 1200 instances.\n",
+   "#    test has 1200 instances."
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "### (2) Preprocessing"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "Here we first segment the sentences with [fastHan](http://gitee.com/fastnlp/fastHan), then build vocabularies and convert the words to indices."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from fastHan import FastHan\n",
+   "from fastNLP import Vocabulary\n",
+   "\n",
+   "model = FastHan()\n",
+   "# model.set_device('cuda')  # uncomment this line to run the tokenizer on a GPU for extra speed\n",
+   "\n",
+   "# define the word-segmentation operation\n",
+   "def word_seg(ins):\n",
+   "    raw_chars = ins['raw_chars']\n",
+   "    # some sentences are rather long, so keep only the first 128 characters\n",
+   "    raw_words = model(raw_chars[:128], target='CWS')[0]\n",
+   "    return raw_words\n",
+   "\n",
+   "for name, ds in data_bundle.iter_datasets():\n",
+   "    # apply() runs word_seg on every instance and stores the return value in the raw_words field\n",
+   "    ds.apply(word_seg, new_field_name='raw_words')\n",
+   "    # besides apply(), fastNLP also supports apply_field(), apply_more() (creates several fields at once), etc.\n",
+   "    # we also add a seq_len field\n",
+   "    ds.add_seq_len('raw_words')\n",
+   "\n",
+   "vocab = Vocabulary()\n",
+   "\n",
+   "# build the vocabulary from the raw_words field; non-training datasets should go into no_create_entry_dataset\n",
+   "# a vocabulary can also be built with add_word(), add_word_lst(), etc.; see http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html\n",
+   "vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_words', \n",
+   "                   no_create_entry_dataset=[data_bundle.get_dataset('dev'), \n",
+   "                                            data_bundle.get_dataset('test')]) \n",
+   "\n",
+   "# use the built Vocabulary to index the raw_words field and store the index sequences in the words field\n",
+   "vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'), \n",
+   "                    data_bundle.get_dataset('test'), field_name='raw_words', new_field_name='words')\n",
+   "\n",
+   "# build the target vocabulary; it usually needs neither padding nor unknown\n",
+   "target_vocab = Vocabulary(padding=None, unknown=None) \n",
+   "# the target vocabulary can usually be built from the training set alone\n",
+   "target_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='target') \n",
+   "# without new_field_name, the original field is overwritten by default\n",
+   "target_vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'), \n",
+   "                           data_bundle.get_dataset('test'), field_name='target')\n",
+   "\n",
+   "# store the vocabularies in the data_bundle for later use\n",
+   "data_bundle.set_vocab(field_name='words', vocab=vocab)\n",
+   "data_bundle.set_vocab(field_name='target', vocab=target_vocab)\n",
+   "\n",
+   "# set words and seq_len as input and target as target so they are fetched and auto-padded in the training loop; see\n",
+   "# http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html\n",
+   "data_bundle.set_target('target')\n",
+   "data_bundle.set_input('words', 'seq_len')  # DataSet offers these two interfaces as well\n",
+   "# if a field should be input or target but must not be auto-padded, or needs a custom padding scheme, see\n",
+   "# http://www.fastnlp.top/docs/fastNLP/fastNLP.core.dataset.html\n",
+   "\n",
+   "print(data_bundle.get_dataset('train')[:2])  # inspect the current content of the dataset\n",
+   "\n",
+   "# the BiLSTMMaxPoolCls model defined earlier is reused below, so rename words to chars (its forward() expects a chars argument)\n",
+   "data_bundle.rename_field('words', 'chars')"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "### (3) Choosing pretrained word embeddings"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "Here we use Tencent's pretrained Chinese word embeddings, which can be downloaded and unpacked from [Tencent AI Lab Embedding](https://ai.tencent.com/ailab/nlp/en/embedding.html). We cannot use BERT directly here, because BERT was pretrained on Chinese characters rather than words."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from fastNLP.embeddings import StaticEmbedding\n",
+   "\n",
+   "word2vec_embed = StaticEmbedding(data_bundle.get_vocab('chars'),  # the words field (and its vocab) was renamed to chars above\n",
+   "                                 model_dir_or_name='/path/to/Tencent_AILab_ChineseEmbedding.txt')"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from fastNLP import Trainer\n",
+   "from fastNLP import CrossEntropyLoss\n",
+   "from torch.optim import Adam\n",
+   "from fastNLP import AccuracyMetric\n",
+   "\n",
+   "# initialize the model\n",
+   "model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))\n",
+   "\n",
+   "# training setup\n",
+   "loss = CrossEntropyLoss()\n",
+   "optimizer = Adam(model.parameters(), lr=0.001)\n",
+   "metric = AccuracyMetric()\n",
+   "device = 0 if torch.cuda.is_available() else 'cpu'  # run on a GPU if available; training is much faster\n",
+   "\n",
+   "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
+   "                  optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
+   "                  metrics=metric, device=device)\n",
+   "trainer.train()  # start training; the model that performed best on dev is loaded afterwards by default\n",
+   "\n",
+   "# evaluate the model on the test set\n",
+   "from fastNLP import Tester\n",
+   "print(\"Performance on test is:\")\n",
+   "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
+   "tester.test()"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -826,7 +556,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.6.7"
+  "version": "3.6.8"
  }
 },
 "nbformat": 4,
@@ -86,7 +86,7 @@ Vocabulary in fastNLP
     # put the dev or test set into the no_create_entry_dataset parameter when building the vocabulary.
     vocab.from_dataset(tr_data, field_name='chars', no_create_entry_dataset=[dev_data])
 
-For `no_create_entry` in :class:`~fastNLP.Vocabulary`, it is recommended to set the parameter to True when adding words that come from the test and dev sets, or to pass the dev and test sets
+For `no_create_entry` in :class:`~fastNLP.Vocabulary`: if you do not care about the underlying mechanism, simply follow this advice: set the parameter to True when adding words that come from outside the training set, or pass the non-training data
 into the `no_create_entry_dataset` parameter. The reason: when the model uses a pretrained embedding (glove, word2vec, elmo or bert) that will be fine-tuned, building the vocabulary
 from the training data alone means that words appearing only in test and dev cannot fully exploit the information in the pretrained embedding (they
 would be treated as unk), so taking test and dev into account when building the vocabulary yields better final results.
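A minimal sketch of this recommendation, reusing the tr_data/dev_data names from the snippet above and the 'cn-fastnlp-100d' embedding mentioned later in this tutorial; words seen only in dev receive an index but keep their pretrained vector instead of being fine-tuned as ordinary vocabulary entries:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary()
    # dev words are registered as "no-entry" words: they get an index,
    # but are not treated as trainable training vocabulary of their own
    vocab.from_dataset(tr_data, field_name='chars', no_create_entry_dataset=[dev_data])
    embed = StaticEmbedding(vocab, model_dir_or_name='cn-fastnlp-100d')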
@@ -11,7 +11,7 @@
 1, 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!
 
-The leading 1 is the label of this review, indicating positive sentiment. The data we will use can be downloaded via `this link <http://212.129.155.247/dataset/chn_senti_corp.zip>`_
+The leading 1 is the label of this review, indicating positive sentiment. The data we will use can be downloaded via `this link <http://download.fastnlp.top/dataset/chn_senti_corp.zip>`_
 and unpacked; it can of course also be downloaded automatically by fastNLP.
 
 The content of the data is shown in the figure below. Next, we will use fastNLP to train a classification network on this data.
@@ -163,8 +163,7 @@ Vocabulary is a class recording the mapping between words and indices, e.g.
 (3) Choosing pretrained word embeddings
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Since pretrained models such as Word2vec, Glove, Elmo
-and Bert can improve model performance, choosing suitable pretrained embeddings before training a concrete task is very important.
+Since pretrained models such as Word2vec, Glove, Elmo and Bert can improve model performance, choosing suitable pretrained embeddings before training a concrete task is very important.
 fastNLP provides several Embedding classes that make loading these pretrained models much more convenient.
 We first give an example using Chinese character embeddings pretrained with word2vec, followed later by a text classification example using Bert.
 The pretrained embedding used here is 'cn-fastnlp-100d'; fastNLP will automatically download it to the local cache,
@@ -291,7 +290,7 @@ fastNLP provides the Trainer object to organize the training process, including loss computation
-Text classification with Bert
+PS: Text classification with Bert
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. code-block:: python
@@ -368,6 +367,170 @@ fastNLP provides the Trainer object to organize the training process, including loss computation
 {'AccuracyMetric': {'acc': 0.919167}}
+PS: Word-based text classification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Since written Chinese has no explicit boundaries between words, a tokenizer is usually needed to segment each sentence into words first.
+The following example shows how to do text classification without relying on fastNLP's built-in data loading and preprocessing code.
+
+(1) Loading the data
+~~~~~~~~~~~~~~~~~~~~
+
+We keep using the same data as before, but this time without fastNLP's built-in reading code.
+
+.. code-block:: python
+
+    from fastNLP.io import ChnSentiCorpLoader
+
+    loader = ChnSentiCorpLoader()  # instantiate a loader for the Chinese sentiment classification corpus
+    data_dir = loader.download()  # downloads the data to the default cache directory and returns that path
+
+The returned data_dir should contain files similar to the following
+
+.. code-block:: text
+
+    - chn_senti_corp
+        - train.tsv
+        - dev.tsv
+        - test.tsv
+
+Opening any of these files shows the same format
+
+.. code-block:: text
+
+    target  raw_chars
+    1       这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般
+    0       怀着十分激动的心情放映...
+
+We first define a read_file_to_dataset function that, given a file path, reads the content and returns a DataSet. Then we put all the DataSets into a DataBundle object to simplify the subsequent preprocessing.
+
+.. code-block:: python
+
+    import os
+    from fastNLP import DataSet, Instance
+    from fastNLP.io import DataBundle
+
+
+    def read_file_to_dataset(fp):
+        ds = DataSet()
+        with open(fp, 'r') as f:
+            f.readline()  # the first line is the header; skip it
+            for line in f:
+                line = line.strip()
+                target, chars = line.split('\t')
+                ins = Instance(target=target, raw_chars=chars)
+                ds.append(ins)
+        return ds
+
+    data_bundle = DataBundle()
+    for name in ['train.tsv', 'dev.tsv', 'test.tsv']:
+        fp = os.path.join(data_dir, name)
+        ds = read_file_to_dataset(fp)
+        data_bundle.set_dataset(name=name.split('.')[0], dataset=ds)
+
+    print(data_bundle)  # take a look at the datasets
+    # In total 3 datasets:
+    #    train has 9600 instances.
+    #    dev has 1200 instances.
+    #    test has 1200 instances.
+
+(2) Preprocessing
+~~~~~~~~~~~~~~~~~
+
+Here we first segment the sentences with fastHan_, then build vocabularies and convert the words to indices.
+
+.. _fastHan: https://gitee.com/fastnlp/fastHan
+
+.. code-block:: python
+
+    from fastHan import FastHan
+    from fastNLP import Vocabulary
+
+    model = FastHan()
+    # model.set_device('cuda')  # uncomment this line to run the tokenizer on a GPU for extra speed
+
+    # define the word-segmentation operation
+    def word_seg(ins):
+        raw_chars = ins['raw_chars']
+        # some sentences are rather long, so keep only the first 128 characters
+        raw_words = model(raw_chars[:128], target='CWS')[0]
+        return raw_words
+
+    for name, ds in data_bundle.iter_datasets():
+        # apply() runs word_seg on every instance and stores the return value in the raw_words field
+        ds.apply(word_seg, new_field_name='raw_words')
+        # besides apply(), fastNLP also supports apply_field(), apply_more() (creates several fields at once), etc.
+        # we also add a seq_len field
+        ds.add_seq_len('raw_words')
+
+    vocab = Vocabulary()
+
+    # build the vocabulary from the raw_words field; non-training datasets should go into no_create_entry_dataset
+    # a vocabulary can also be built with add_word(), add_word_lst(), etc.; see http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html
+    vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_words',
+                       no_create_entry_dataset=[data_bundle.get_dataset('dev'),
+                                                data_bundle.get_dataset('test')])
+
+    # use the built Vocabulary to index the raw_words field and store the index sequences in the words field
+    vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'),
+                        data_bundle.get_dataset('test'), field_name='raw_words', new_field_name='words')
+
+    # build the target vocabulary; it usually needs neither padding nor unknown
+    target_vocab = Vocabulary(padding=None, unknown=None)
+    # the target vocabulary can usually be built from the training set alone
+    target_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='target')
+    # without new_field_name, the original field is overwritten by default
+    target_vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'),
+                               data_bundle.get_dataset('test'), field_name='target')
+
+    # store the vocabularies in the data_bundle for later use
+    data_bundle.set_vocab(field_name='words', vocab=vocab)
+    data_bundle.set_vocab(field_name='target', vocab=target_vocab)
+
+    # set words and seq_len as input and target as target so they are fetched and auto-padded in the training loop; see
+    # http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html
+    data_bundle.set_target('target')
+    data_bundle.set_input('words', 'seq_len')  # DataSet offers these two interfaces as well
+    # if a field should be input or target but must not be auto-padded, or needs a custom padding scheme, see
+    # http://www.fastnlp.top/docs/fastNLP/fastNLP.core.dataset.html
+
+    print(data_bundle.get_dataset('train')[:2])  # inspect the current content of the dataset
+    # +--------+-----------------------+-----------------------+----------------------+
+    # | target | raw_chars             | raw_words             | words                |
+    # +--------+-----------------------+-----------------------+----------------------+
+    # | 0      | 选择珠江花园的原因... | ['选择', '珠江', ...  | [2, 3, 4, 5, 6, 7... |
+    # | 0      | 15.4寸笔记本的键盘... | ['15.4', '寸', '笔... | [71, 72, 73, 74, ... |
+    # +--------+-----------------------+-----------------------+----------------------+
+
+    # the BiLSTMMaxPoolCls model defined earlier is reused, so rename the words field to chars
+    data_bundle.rename_field('words', 'chars')
+
+We can print the vocab to inspect the current vocabulary content
+
+.. code-block:: python
+
+    print(data_bundle.get_vocab('chars'))
+    # Vocabulary([选择, 珠江, 花园, 的, 原因]...)
+
+(3) Choosing pretrained word embeddings
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Here we use Tencent's pretrained Chinese word embeddings, which can be downloaded and unpacked from 腾讯词向量_. We cannot use BERT directly here, because BERT was pretrained on Chinese characters rather than words.
+
+.. _腾讯词向量: https://ai.tencent.com/ailab/nlp/en/embedding.html
+
+Below we load the word vectors with :mod:`fastNLP.embeddings`; fastNLP extracts the vectors of the words contained in the vocabulary and randomly initializes the vectors of words not found in the file.
+
+.. code-block:: python
+
+    from fastNLP.embeddings import StaticEmbedding
+
+    word2vec_embed = StaticEmbedding(data_bundle.get_vocab('chars'), model_dir_or_name='/path/to/Tencent_AILab_ChineseEmbedding.txt')
+
+The subsequent model definition and training procedure are identical to the above, so they are not repeated here.
 ----------------------------------
 Code download
@@ -376,3 +539,4 @@ fastNLP provides the Trainer object to organize the training process, including loss computation
 .. raw:: html
 
     <a href="../_static/notebooks/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB.ipynb" download="文本分类.ipynb">Click to download the IPython Notebook file</a><hr>
@@ -62,6 +62,7 @@ __all__ = [
     "CrossEntropyLoss",
     "L1Loss",
     "BCELoss",
+    "BCEWithLogits",
     "NLLLoss",
     "LossInForward",
     "CMRC2018Loss",
@@ -98,7 +99,7 @@ from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
 from .instance import Instance
 from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
-    LossInForward, CMRC2018Loss, LossBase, MSELoss
+    LossInForward, CMRC2018Loss, LossBase, MSELoss, BCEWithLogits
 from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\
     ConfusionMatrixMetric
 from .optimizer import Optimizer, SGD, Adam, AdamW
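A hedged usage sketch of the newly exported loss; the keyword mapping below is an assumption, mirroring the existing BCELoss wrapper rather than a documented signature:

    from fastNLP import BCEWithLogits

    # assumption: like BCELoss, the constructor maps the model's output field and the
    # dataset's target field by name; as a logits-based loss it should expect raw,
    # un-sigmoided model outputs
    loss = BCEWithLogits(pred='pred', target='target')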
@@ -86,7 +86,6 @@ except:
 from .dataset import DataSet
 from .tester import Tester
 from ._logger import logger
-from .utils import _check_fp16
 from ._parallel_utils import _model_contains_inner_module
 
 try:
@@ -94,11 +93,6 @@ try:
 except:
     pass
 
-try:
-    from apex import amp
-except:
-    amp = None
-
 class Callback(object):
     r"""
@@ -123,6 +117,20 @@ class Callback(object):
         This attribute is available as self.trainer; normally it is not needed.
         """
         return self._trainer
 
+    @property
+    def grad_scaler(self):
+        r"""
+        the gradient scaler used for float16 training
+        """
+        return self._trainer.grad_scaler
+
+    @property
+    def auto_cast(self):
+        r"""
+        the autocast context used for float16 training
+        """
+        return self._trainer.auto_cast
+
     @property
     def step(self):
@@ -206,7 +214,7 @@ class Callback(object):
     def on_batch_begin(self, batch_x, batch_y, indices):
         r"""
         Called once for every collected batch. Deleting entries from or adding entries to batch_x or batch_y here does affect what the Trainer sees, so this step
-        can be used for operations such as negative sampling.
+        can be used for operations such as negative sampling. The tensors in batch_x and batch_y have already been placed on the device the model lives on.
 
         :param dict batch_x: batch of the fields set as input in the DataSet.
         :param dict batch_y: batch of the fields set as target in the DataSet.
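A minimal sketch of a callback relying on the guarantee documented above; the class name is hypothetical, not part of fastNLP:

    from fastNLP import Callback

    class InspectBatchCallback(Callback):  # hypothetical example callback
        def on_batch_begin(self, batch_x, batch_y, indices):
            # tensors arrive already on the model's device, so in-place edits
            # such as negative sampling need no extra .to(device) calls
            for name, tensor in batch_x.items():
                print(name, tensor.device, tensor.shape)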
@@ -472,14 +480,12 @@ class GradientClipCallback(Callback):
 
     def on_backward_end(self):
         if self.step%self.update_every==0:
-            if self.parameters is None:
-                if getattr(self.trainer, 'fp16', ''):
-                    _check_fp16()
-                    self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
-                else:
-                    self.clip_fun(self.model.parameters(), self.clip_value)
-            else:
+            if self.trainer.fp16:
+                self.grad_scaler.unscale_(self.optimizer)
+            if self.parameters is not None:
                 self.clip_fun(self.parameters, self.clip_value)
+            else:
+                self.clip_fun(self.model.parameters(), self.clip_value)
 
 
 class EarlyStopCallback(Callback):
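The rewritten branch follows the standard native-amp ordering: unscale first, then clip, then step. A self-contained sketch of that ordering in plain torch.cuda.amp (the tiny model and data are placeholders, and CUDA is assumed available):

    import torch
    from torch import nn

    model = nn.Linear(4, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(8, 4, device='cuda')
    with torch.cuda.amp.autocast():
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                               # bring grads back to fp32 scale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)  # clip true gradient values
    scaler.step(optimizer)                                   # the step is skipped automatically on overflow
    scaler.update()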
@@ -569,10 +575,10 @@ class FitlogCallback(Callback):
         if len(self.datasets) > 0:
             for key, data in self.datasets.items():
                 tester = Tester(data=data, model=self.model,
-                                batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
+                                batch_size=self.trainer.kwargs.get('dev_batch_size', self.trainer.batch_size),
                                 metrics=self.trainer.metrics,
                                 verbose=0,
-                                use_tqdm=self.trainer.test_use_tqdm,
+                                use_tqdm=self.trainer.kwargs.get('test_use_tqdm', self.trainer.use_tqdm),
                                 sampler=self.trainer.kwargs.get('test_sampler', None))
                 self.testers[key] = tester
         fitlog.add_progress(total_steps=self.n_steps)
@@ -948,6 +954,8 @@ class CheckPointCallback(Callback):
                 model = model.module
             model.load_state_dict(states['model'])
             self.optimizer.load_state_dict(states['optimizer'])
+            if 'grad_scaler' in states:
+                self.grad_scaler.load_state_dict(states['grad_scaler'])
             self.trainer.epoch = states['epoch'] + 1  # the checkpoint is saved at the end of an epoch, so resume from the next one
             self.trainer.step = states['step']
             if 'best_dev_epoch' in states:
@@ -970,6 +978,7 @@
                 model = model.module
             states['model'] = {name:param.cpu() for name, param in model.state_dict().items()}
             states['optimizer'] = self.optimizer.state_dict()
+            states['grad_scaler'] = self.grad_scaler.state_dict()
             states['epoch'] = self.epoch
             states['step'] = self.step
             if self.trainer.best_dev_epoch is not None:
@@ -1169,11 +1178,12 @@ class EchoCallback(Callback):
 
 class _TesterCallback(Callback):
-    def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
+    def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None, sampler=None,
+                 use_tqdm=True):
         super(_TesterCallback, self).__init__()
         self.tester = Tester(data, model,
                              metrics=metrics, batch_size=batch_size,
-                             num_workers=num_workers, verbose=0)
+                             num_workers=num_workers, verbose=0, sampler=sampler, use_tqdm=use_tqdm)
         if metric_key is not None:
             self.metric_key, self.increase_better = self._parse_metric_key(metric_key)
         else:
@@ -371,6 +371,10 @@ from .field import SetInputOrTargetException
 from .instance import Instance
 from .utils import pretty_table_printer
 from .collate_fn import Collater
+try:
+    from tqdm.auto import tqdm
+except:
+    from .utils import _pseudo_tqdm as tqdm
 
 
 class ApplyResultException(Exception):
@@ -531,11 +535,11 @@ class DataSet(object):
             | pad_value   | 0     |       |
             +-------------+-------+-------+
 
-        :param field_names: name of the field in the DataSet
-        :param is_input: whether the field is an input
-        :param is_target: whether the field is a target
-        :param ignore_type: whether to ignore the field's type; usually only meaningful when the field is at least input or target
-        :param pad_value: the value this field is padded with; only meaningful when the field is input or target
+        str field_names: name of the field in the DataSet
+        bool is_input: whether the field is an input
+        bool is_target: whether the field is a target
+        bool ignore_type: whether to ignore the field's type; usually only meaningful when the field is at least input or target
+        int pad_value: the value this field is padded with; only meaningful when the field is input or target
 
         :return:
         """
         if len(self.field_arrays)>0:
@@ -860,6 +864,11 @@ class DataSet(object):
             2. is_target: bool, if True, set the field named `new_field_name` as target
 
             3. ignore_type: bool, if True, set ignore_type of the field named `new_field_name` to true and ignore its type
 
+            4. use_tqdm: bool, whether to show a tqdm progress bar during preprocessing
+
+            5. tqdm_desc: str, when use_tqdm is True, the description shown on the tqdm bar
+
         :return List[Any]: the elements are func's return values, so the list length equals the length of the DataSet
         """
         assert len(self) != 0, "Null DataSet cannot use apply_field()."
@@ -887,6 +896,10 @@ class DataSet(object):
             3. ignore_type: bool, if True, set ignore_type of the modified field to true and ignore its type
 
+            4. use_tqdm: bool, whether to show a tqdm progress bar during preprocessing
+
+            5. tqdm_desc: str, when use_tqdm is True, the description shown on the tqdm bar
+
         :return Dict[str:Field]: returns a dict
         """
         assert len(self) != 0, "Null DataSet cannot use apply_field()."
@@ -920,7 +933,8 @@ class DataSet(object):
             if 'ignore_type' not in extra_param:
                 extra_param['ignore_type'] = old_field.ignore_type
             self.add_field(field_name=new_field_name, fields=results, is_input=extra_param["is_input"],
-                           is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type'])
+                           is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type'],
+                           padder=self.get_field(new_field_name).padder)
         else:
             self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None),
                            is_target=extra_param.get("is_target", None),
@@ -949,6 +963,10 @@ class DataSet(object):
             3. ignore_type: bool, if True, set ignore_type of the modified field to true and ignore its type
 
+            4. use_tqdm: bool, whether to show a tqdm progress bar during preprocessing
+
+            5. tqdm_desc: str, when use_tqdm is True, the description shown on the tqdm bar
+
         :return Dict[str:Field]: returns a dict
         """
         # returns a dict; check whether the keys stay consistent
@@ -957,7 +975,9 @@ class DataSet(object):
         idx = -1
         try:
             results = {}
-            for idx, ins in enumerate(self._inner_iter()):
+            for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True,
+                                 desc=kwargs.get('tqdm_desc', ''),
+                                 leave=False, disable=not kwargs.get('use_tqdm', False)):
                 if "_apply_field" in kwargs:
                     res = func(ins[kwargs["_apply_field"]])
                 else:
@@ -1001,6 +1021,10 @@ class DataSet(object):
             2. is_target: bool, if True, set the field named `new_field_name` as target
 
             3. ignore_type: bool, if True, set ignore_type of the field named `new_field_name` to true and ignore its type
 
+            4. use_tqdm: bool, whether to show a tqdm progress bar during preprocessing
+
+            5. tqdm_desc: str, when use_tqdm is True, the description shown on the tqdm bar
+
         :return List[Any]: the elements are func's return values, so the list length equals the length of the DataSet
         """
@@ -1009,7 +1033,9 @@ class DataSet(object):
         idx = -1
         try:
             results = []
-            for idx, ins in enumerate(self._inner_iter()):
+            for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True, leave=False,
+                                 desc=kwargs.get('tqdm_desc', ''),
+                                 disable=not kwargs.get('use_tqdm', False)):
                 if "_apply_field" in kwargs:
                     results.append(func(ins[kwargs["_apply_field"]]))
                 else:
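A quick usage sketch of the new kwargs; the toy DataSet and description are illustrative:

    from fastNLP import DataSet

    ds = DataSet({'raw_chars': ['这个宾馆比较陈旧了', '怀着十分激动的心情']})
    # show a progress bar labelled 'char split' while the function runs over all instances
    ds.apply_field(list, field_name='raw_chars', new_field_name='chars',
                   use_tqdm=True, tqdm_desc='char split')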
@@ -1146,3 +1172,40 @@ class DataSet(object):
 
     def _collate_batch(self, ins_list):
         return self.collater.collate_batch(ins_list)
+
+    def concat(self, dataset, inplace=True, field_mapping=None):
+        """
+        Combine the current dataset with the input dataset into one larger dataset. Both datasets must contain the
+        same fields; the input/target settings and the collate_fn of the result follow the current dataset. Fields
+        that exist only in the other dataset are ignored; if the other dataset lacks fields that the current dataset
+        has, an error is raised.
+
+        :param DataSet, dataset: the dataset to concat with the current one
+        :param bool, inplace: whether to merge the other dataset directly into the current one
+        :param dict, field_mapping: when field names in the other dataset differ from the current one, field_mapping
+            maps them; it is a dict whose keys are field names in the other dataset and whose values are the names to
+            map them to
+        :return: DataSet
+        """
+        assert isinstance(dataset, DataSet), "Can only concat two datasets."
+
+        fns_in_this_dataset = set(self.get_field_names())
+        fns_in_other_dataset = dataset.get_field_names()
+        reverse_field_mapping = {}
+        if field_mapping is not None:
+            fns_in_other_dataset = [field_mapping.get(fn, fn) for fn in fns_in_other_dataset]
+            reverse_field_mapping = {v:k for k, v in field_mapping.items()}
+        fns_in_other_dataset = set(fns_in_other_dataset)
+        fn_not_seen = list(fns_in_this_dataset - fns_in_other_dataset)
+
+        if fn_not_seen:
+            raise RuntimeError(f"The following fields are not provided in the dataset:{fn_not_seen}")
+
+        if inplace:
+            ds = self
+        else:
+            ds = deepcopy(self)
+
+        for fn in fns_in_this_dataset:
+            ds.get_field(fn).content.extend(deepcopy(dataset.get_field(reverse_field_mapping.get(fn, fn)).content))
+
+        return ds
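A usage sketch of the new method with toy data; field_mapping renames the fields of the incoming dataset so both field sets line up:

    from fastNLP import DataSet

    ds1 = DataSet({'x': [1, 2], 'y': ['a', 'b']})
    ds2 = DataSet({'x': [3], 'z': ['c']})
    # map ds2's 'z' field onto ds1's 'y' field, then append ds2's rows in place
    ds1.concat(ds2, inplace=True, field_mapping={'z': 'y'})
    print(len(ds1))  # 3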
@@ -29,15 +29,10 @@ from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
 from .utils import _build_args
-from .utils import _check_fp16
+from .utils import _build_fp16_env
 from .utils import _get_func_signature
 from .utils import _move_dict_value_to_device
 
-try:
-    from apex import amp
-except:
-    amp = None
-
 __all__ = [
     'get_local_rank',
     'DistTrainer',
@@ -73,7 +68,7 @@ class DistTrainer(): | |||||
dev_data=None, metrics=None, metric_key=None, | dev_data=None, metrics=None, metric_key=None, | ||||
update_every=1, print_every=10, validate_every=-1, | update_every=1, print_every=10, validate_every=-1, | ||||
save_path=None, device='auto', | save_path=None, device='auto', | ||||
fp16='', use_tqdm=True): | |||||
fp16=False, use_tqdm=True, **kwargs): | |||||
r""" | r""" | ||||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | ||||
@@ -104,8 +99,15 @@ class DistTrainer():
         :param str,None save_path: path to save the model; the folder is created if it does not exist. If None, the model is not saved. If dev_data is None,
             the model of the last iteration is saved. Both the parameters and the model structure are saved; even with DataParallel, only the model itself is saved here.
         :param str device: the device, one of gpu, cpu or auto
-        :param str fp16: the Apex optimization level for half-precision training, one of O1, O2 or O3; an empty string disables half precision.
+        :param bool fp16: whether to train with half precision.
         :param bool use_tqdm: whether to show training progress with tqdm; if False, the loss is printed to the terminal.
+        :param kwargs: optional configuration
+            bool test_use_tqdm: whether to enable tqdm when validating on dev
+            Sampler test_sampler: the sampler used during evaluation
+            int dev_batch_size: the batch size used during evaluation
+            bool test_use_fp16: use fp16 at test time
+            bool set_grad_to_none: in zero_grad, set grads to None instead of 0
+            GradScaler gradscaler: a custom gradient scaler
         """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
         if device == 'auto':
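A hedged construction sketch using the new flag and the kwargs listed above; train_ds, dev_ds, model, optimizer, loss and metric stand in for real objects:

    trainer = DistTrainer(train_data=train_ds, model=model, optimizer=optimizer,
                          loss=loss, metrics=metric, dev_data=dev_ds,
                          fp16=True,              # native half precision instead of Apex levels
                          dev_batch_size=64,      # batch size used during evaluation
                          test_use_tqdm=False,    # no progress bar while validating on dev
                          set_grad_to_none=True)  # zero_grad sets grads to None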
@@ -144,14 +146,19 @@ class DistTrainer():
         self.use_tqdm = use_tqdm
 
         model.to(self.device)
-        optimizer = self._get_optimizer(optimizer)
 
         # init fp16, must before DataParallel init
-        if len(self.fp16):
-            assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']"
-            _check_fp16()
-            assert device == 'cuda', "Amp requires cuda device"
-            model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)
+        autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
+        self.auto_cast = autocast
+        user_grad_scaler = kwargs.get('gradscaler', None)  # kwargs is a dict, so .get rather than getattr
+        if user_grad_scaler is not None:
+            assert self.fp16, "must set fp16=True to enable gradscaler"
+            grad_scaler = user_grad_scaler
+        else:
+            grad_scaler = GradScaler()
+        self.grad_scaler = grad_scaler
+
+        self.set_grad_to_none = kwargs.get('set_grad_to_none', True)
 
         # init DataParallel
         if parse_version(torch.__version__)>=parse_version('1.1'):
@@ -162,17 +169,27 @@ class DistTrainer():
                                          output_device=self.local_rank)
         self.model = self.ddp_model.module
 
+        optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
-        self.sampler = DistributedSampler(self.train_data)
+        if isinstance(self.train_data, DataSet):
+            self.sampler = DistributedSampler(self.train_data)
         self.data_iterator = self._get_data_iter(self.train_data)
         self.batch_size = self.world_size * self.batch_size_per_gpu
         self.n_steps = self._get_n_steps()
 
+        self.dev_data = dev_data
+        self.metrics = metrics
+        self.test_use_tqdm = True
+        self.kwargs = kwargs
+        self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
+        dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)
+
         # for evaluation, only run eval on master proc
         if dev_data and metrics:
             cb = _TesterCallback(
                 dev_data, model, metrics,
-                batch_size=batch_size_per_gpu, num_workers=num_workers)
+                batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
+                use_tqdm=self.test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
 
         # Setup logging
@@ -190,11 +207,9 @@ class DistTrainer(): | |||||
self.logger = logger | self.logger = logger | ||||
self.logger.info("Setup Distributed Trainer") | self.logger.info("Setup Distributed Trainer") | ||||
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( | self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( | ||||
os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False)) | |||||
os.getpid(), self.rank, self.local_rank, self.device, self.fp16)) | |||||
self.logger.info("Num of processes: {}".format(self.world_size)) | self.logger.info("Num of processes: {}".format(self.world_size)) | ||||
self.logger.info("Use device: {}".format(device)) | self.logger.info("Use device: {}".format(device)) | ||||
self.logger.info("Training with fp16: {}, optimization level: {}".format( | |||||
len(self.fp16) > 0, self.fp16 if self.fp16 else None)) | |||||
def _maybe_no_sync(self): | def _maybe_no_sync(self): | ||||
""" | """ | ||||
@@ -232,8 +247,10 @@ class DistTrainer():
         elif optimizer is None:
             return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
         else:
-            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
+            if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
+                raise TypeError("optimizer must have a callable step() function.")
+            else:
+                return optimizer

     @property
     def is_master(self):
         r"""Whether this is the master process."""
@@ -334,28 +351,20 @@ class DistTrainer():
                     indices = data_iterator.get_batch_indices()
                     # negative sampling; replace unknown; re-weight batch_y
                     self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
-                    prediction = self._data_forward(self.ddp_model, batch_x)
-                    # edit prediction
-                    self.callback_manager.on_loss_begin(batch_y, prediction)
-                    loss = self._compute_loss(prediction, batch_y)
-                    avg_loss += loss.item()
+                    with self.auto_cast():
+                        prediction = self._data_forward(self.ddp_model, batch_x)
+                        # edit prediction
+                        self.callback_manager.on_loss_begin(batch_y, prediction)
+                        loss = self._compute_loss(prediction, batch_y)
+                        if self.update_every > 1:
+                            loss = loss / self.update_every
+                    avg_loss += loss.detach()

                     # Is loss NaN or inf? requires_grad = False
                     self.callback_manager.on_backward_begin(loss)
-                    # with self._maybe_no_sync():
-                    if self.fp16:
-                        with amp.scale_loss(loss, self.optimizer) as scale_loss:
-                            scale_loss.backward()
-                    else:
-                        loss.backward()
+                    self.grad_scaler.scale(loss).backward()
                     self.callback_manager.on_backward_end()
-                    self._update()
+                    if self.step % self.update_every == 0:
+                        self._update()
                     self.callback_manager.on_step_end()

                     if self.step % self.print_every == 0:
@@ -367,11 +376,11 @@ class DistTrainer():
                     self.callback_manager.on_batch_end()

-                    if (self.validate_every > 0 and self.step % self.validate_every == 0):
+                    if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks):
                         self._do_validation()

                 # ================= mini-batch end ==================== #
-                if self.validate_every < 0:
+                if self.validate_every < 0 and len(self.test_manager.callbacks):
                     self._do_validation()

             # lr decay; early stopping
@@ -381,13 +390,22 @@ class DistTrainer():
             self.pbar = None
         # ============ tqdm end ============== #

+    def _clear_grad_opt(self, optimizer):
+        if self.set_grad_to_none:
+            for group in optimizer.param_groups:
+                for p in group['params']:
+                    if p.grad is not None:
+                        p.grad = None
+        else:
+            optimizer.zero_grad()
+
     def _update(self):
         r"""Perform weight update on a model.
         """
-        if self.step % self.update_every == 0:
-            self.optimizer.step()
-            self.ddp_model.zero_grad()
+        self.grad_scaler.step(self.optimizer)
+        self.grad_scaler.update()
+        self._clear_grad_opt(self.optimizer)

     def _data_forward(self, network, x):
         x = _build_args(self._forward_func, **x)
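Setting gradients to `None` (the default for `set_grad_to_none`) skips the zero-filling kernel and lets the next backward pass allocate fresh `.grad` tensors instead of accumulating into zeros. A small sketch with a placeholder model; on PyTorch 1.7+ the built-in `optimizer.zero_grad(set_to_none=True)` has the same effect::

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # manual clearing, as _clear_grad_opt does above
    for group in optimizer.param_groups:
        for p in group['params']:
            p.grad = None

    # equivalent built-in on PyTorch >= 1.7:
    # optimizer.zero_grad(set_to_none=True)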
@@ -38,8 +38,286 @@ class AppendToTargetOrInputException(Exception):
         self.field_name = field_name  # the name of the field in question
+def _get_ele_type_and_dim(cell: Any, dim=0):
+    r"""
+    Identify the element type of `cell` and its number of dimensions.
+    numpy scalar types: https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
+
+    :param cell:
+    :param dim:
+    :return:
+    """
+    if isinstance(cell, (str, Number, np.bool_)):
+        if hasattr(cell, 'dtype'):
+            return cell.dtype.type, dim
+        return type(cell), dim
+    elif isinstance(cell, list):
+        dim += 1
+        res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
+        types = set([i for i, j in res])
+        dims = set([j for i, j in res])
+        if len(types) > 1:
+            raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
+        elif len(types) == 0:
+            raise SetInputOrTargetException("Empty value encountered.")
+        if len(dims) > 1:
+            raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
+        return types.pop(), dims.pop()
+    elif isinstance(cell, torch.Tensor):
+        return cell.dtype, cell.dim() + dim  # e.g. the result of torch.mean is 0-dim
+    elif isinstance(cell, np.ndarray):
+        if cell.dtype != np.dtype('O'):  # if it is not object, the array is well-formatted
+            return cell.dtype.type, cell.ndim + dim  # dtype.type returns np.int32, np.float etc.
+        # otherwise keep iterating downwards
+        dim += 1
+        res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
+        types = set([i for i, j in res])
+        dims = set([j for i, j in res])
+        if len(types) > 1:
+            raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
+        elif len(types) == 0:
+            raise SetInputOrTargetException("Empty value encountered.")
+        if len(dims) > 1:
+            raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
+        return types.pop(), dims.pop()
+    else:  # tuple, set, dict and any other types
+        raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")
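As a sketch of the helper's contract (assuming the module-level numpy/torch imports), it reports the innermost element type together with the nesting depth::

    import numpy as np
    import torch

    # expected results (illustrative; the exact numpy integer type is
    # platform-dependent, np.int64 on most 64-bit systems)
    print(_get_ele_type_and_dim(1))                   # (<class 'int'>, 0)
    print(_get_ele_type_and_dim([1, 2, 3]))           # (<class 'int'>, 1)
    print(_get_ele_type_and_dim([[1, 2], [3]]))       # (<class 'int'>, 2)
    print(_get_ele_type_and_dim(np.array([[1, 2]])))  # (<class 'numpy.int64'>, 2)
    print(_get_ele_type_and_dim(torch.zeros(2, 3)))   # (torch.float32, 2)
    # mixed element types such as [1, 'a'] raise SetInputOrTargetException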
+class Padder:
+    r"""
+    Every padder must inherit from this class and override the __call__ method.
+    Used for padding a batch. The passed-in elements are handled in place, so modifying them directly may change
+    the underlying data; it is recommended to deepcopy before any in-place modification.
+
+    .. py:function:: __call__(self, contents, field_name, field_ele_dtype):
+
+    """
+
+    def __init__(self, pad_val=0, **kwargs):
+        r"""
+        :param int pad_val: the value used for padding.
+        """
+        self.pad_val = pad_val
+
+    def set_pad_val(self, pad_val):
+        self.pad_val = pad_val
+
+    def get_pad_val(self):
+        return self.pad_val
+
+    @abstractmethod
+    def __call__(self, contents, field_name, field_ele_dtype, dim: int):
+        r"""
+        Receives the field contents as a List. Suppose we have the following DataSet.
+
+        :param List[Any] contents: the elements are passed in place, so modifying them directly may change the
+            underlying data; deepcopy before in-place modification is recommended.
+        :param str, field_name: the name of the field.
+        :param np.int64,np.float64,np.str,None, field_ele_dtype: the type of the innermost elements of this field.
+            If the field's ignore_type is True, this value is None.
+        :param dim: the number of dimensions of this field; None when ignore_type is True.
+        :return: np.array([padded_element])
+
+        Example::
+
+            from fastNLP import DataSet
+            from fastNLP import Instance
+            dataset = DataSet()
+            dataset.append(Instance(sent='this is a demo', length=4,
+                                    chars=[['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']]))
+            dataset.append(Instance(sent='another one', length=2,
+                                    chars=[['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']]))
+
+        If we call
+            batch = dataset.get([0, 1], pad=True)
+        the padder's __call__ for the `sent` field receives
+            [
+                'this is a demo',
+                'another one'
+            ]
+        the padder's __call__ for the `length` field receives
+            [4, 2]
+        and the padder's __call__ for the `chars` field receives
+            [
+                [['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']],
+                [['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']]
+            ]
+        i.e. the contents of a given field across instances are gathered into one List and passed in.
+        """
+        raise NotImplementedError
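A minimal custom Padder for illustration, assuming numpy is imported as in this module: it left-pads 1-d integer sequences instead of right-padding them::

    import numpy as np

    class LeftPadder(Padder):
        def __call__(self, contents, field_name, field_ele_dtype, dim: int):
            max_len = max(len(c) for c in contents)
            array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
            for i, c in enumerate(contents):
                array[i, max_len - len(c):] = c  # write each sequence flush right
            return array

    # e.g. LeftPadder()([[1, 2, 3], [4, 5]], 'words', int, 1) -> [[1, 2, 3], [0, 4, 5]]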
+class AutoPadder(Padder):
+    r"""
+    Decides automatically from the contents whether padding is needed.
+
+    1. If the element type (the type of the innermost elements of the field, available via FieldArray.dtype; e.g.
+       the element type of ['This', 'is', ...] is str, that of [[1, 2], ...] is int) is not numeric, no padding is done.
+
+    2. If the element type is numeric, e.g. np.int64, np.float64, int, float, torch.int64:
+
+        2.1 If the field's content is itself a number (int, float, ...), e.g. a seq_len, no padding is done.
+
+        2.2 If the field's content is equivalent to a 1-d list, the lists in the batch are padded to equal length.
+
+        2.3 If the field's content is equivalent to a 2-d list, padding follows the English character-padding scheme;
+            for character padding, :class: fastNLP.EngChar2DPadder is recommended.
+
+        2.4 If the field's content is equivalent to a 3-d list and every instance has the same size along each
+            dimension, a batch tensor is assembled and returned; this case typically corresponds to images.
+
+    3. Anything else is left untouched and returned as an np.array.
+    """
+
+    def __init__(self, pad_val=0):
+        super().__init__(pad_val=pad_val)
+
+    def __call__(self, contents, field_name, field_ele_dtype, dim):
+        if field_ele_dtype:
+            if dim > 3:
+                return np.array(contents)
+            if isinstance(field_ele_dtype, type) and \
+                    (issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)):
+                if dim == 0:
+                    array = np.array(contents, dtype=field_ele_dtype)
+                elif dim == 1:
+                    max_len = max(map(len, contents))
+                    array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
+                    for i, content_i in enumerate(contents):
+                        array[i, :len(content_i)] = content_i
+                elif dim == 2:
+                    max_len = max(map(len, contents))
+                    max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
+                                        content_i in contents])
+                    array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype)
+                    for i, content_i in enumerate(contents):
+                        for j, content_ii in enumerate(content_i):
+                            array[i, j, :len(content_ii)] = content_ii
+                else:
+                    shape = np.shape(contents)
+                    if len(shape) == 4:  # every dimension has the same size
+                        array = np.array(contents, dtype=field_ele_dtype)
+                    else:
+                        raise RuntimeError(
+                            f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
+                return array
+            elif str(field_ele_dtype).startswith('torch'):
+                if dim == 0:
+                    tensor = torch.tensor(contents).to(field_ele_dtype)
+                elif dim == 1:
+                    max_len = max(map(len, contents))
+                    tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype)
+                    for i, content_i in enumerate(contents):
+                        tensor[i, :len(content_i)] = content_i.clone().detach()
+                elif dim == 2:
+                    max_len = max(map(len, contents))
+                    max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
+                                        content_i in contents])
+                    tensor = torch.full((len(contents), max_len, max_word_len), fill_value=self.pad_val,
+                                        dtype=field_ele_dtype)
+                    for i, content_i in enumerate(contents):
+                        for j, content_ii in enumerate(content_i):
+                            tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
+                else:
+                    shapes = set([np.shape(content_i) for content_i in contents])
+                    if len(shapes) > 1:
+                        raise RuntimeError(
+                            f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
+                    shape = shapes.pop()
+                    if len(shape) == 3:
+                        tensor = torch.full([len(contents)] + list(shape), fill_value=self.pad_val,
+                                            dtype=field_ele_dtype)
+                        for i, content_i in enumerate(contents):
+                            tensor[i] = content_i.clone().detach().to(field_ele_dtype)
+                    else:
+                        raise RuntimeError(
+                            f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
+                return tensor
+            else:
+                return np.array(contents)  # leave untouched
+        else:
+            return np.array(contents)
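Illustrative behaviour of AutoPadder on a 1-d integer field (a sketch)::

    padder = AutoPadder(pad_val=0)
    batch = padder([[1, 2, 3], [4, 5]], field_name='words', field_ele_dtype=int, dim=1)
    # batch is now
    # array([[1, 2, 3],
    #        [4, 5, 0]])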
+class EngChar2DPadder(Padder):
+    r"""
+    Performs character-level 2D padding for English. The field contents should look like
+    [['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']]; this Padder only handles int indices.
+
+    The padded batch has shape (batch_size, max_sentence_length, max_word_length), where max_sentence_length is the
+    longest sentence in the batch and max_word_length is the longest word in the batch::
+
+        from fastNLP import DataSet
+        from fastNLP import EngChar2DPadder
+        from fastNLP import Vocabulary
+        dataset = DataSet({'sent': ['This is the first demo', 'This is the second demo']})
+        dataset.apply(lambda ins:[list(word) for word in ins['sent'].split()], new_field_name='chars')
+        vocab = Vocabulary()
+        vocab.from_dataset(dataset, field_name='chars')
+        vocab.index_dataset(dataset, field_name='chars')
+        dataset.set_input('chars')
+        padder = EngChar2DPadder()
+        dataset.set_padder('chars', padder)  # the chars field now uses EngChar2DPadder
+    """
+
+    def __init__(self, pad_val=0, pad_length=0):
+        r"""
+        :param pad_val: int, the index used at padded positions
+        :param pad_length: int, if 0, the longest word in the batch determines the padding length; if greater than 0,
+            every word is padded or truncated to that length.
+        """
+        super().__init__(pad_val=pad_val)
+        self.pad_length = pad_length
+
+    def __call__(self, contents, field_name, field_ele_dtype, dim):
+        r"""
+        Expects input of the form
+        [
+            [[0, 2], [2, 3, 4], ..],
+            [[9, 8, 2, 4], [1, 2,], ...],
+            ....
+        ]
+
+        :param contents:
+        :param field_name:
+        :param field_ele_dtype
+        :return:
+        """
+        if field_ele_dtype not in (np.int64, np.float64, int, float):
+            raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format(
+                field_name, field_ele_dtype
+            ))
+        assert dim == 2, f"Field:{field_name} has {dim} dimensions, EngChar2DPadder only supports input with 2 dimensions."
+        if self.pad_length < 1:
+            max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents])
+        else:
+            max_char_length = self.pad_length
+        max_sent_length = max(len(word_lst) for word_lst in contents)
+        batch_size = len(contents)
+        dtype = type(contents[0][0][0])
+
+        padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val,
+                               dtype=dtype)
+        for b_idx, word_lst in enumerate(contents):
+            for c_idx, char_lst in enumerate(word_lst):
+                chars = char_lst[:max_char_length]
+                padded_array[b_idx, c_idx, :len(chars)] = chars
+        return padded_array
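Continuing the docstring example: with the default `pad_length=0`, the padded batch takes the shape (batch_size, max_sentence_length, max_word_length). A quick sketch with hand-made char indices::

    padder = EngChar2DPadder(pad_val=0)
    chars = [[[1, 2], [3, 4, 5]], [[6]]]
    print(padder(chars, field_name='chars', field_ele_dtype=np.int64, dim=2).shape)
    # (2, 2, 3): 2 sentences, at most 2 words, longest word has 3 characters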
 class FieldArray:
-    def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False,
+    def __init__(self, name, content, is_target=False, is_input=False, padder=AutoPadder(), ignore_type=False,
                  use_1st_ins_infer_dim_type=True):
         if len(content) == 0:
             raise RuntimeError("Empty fieldarray is not allowed.")
@@ -58,34 +336,29 @@ class FieldArray:
         self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
         self._is_input = False
         self._is_target = False

         if is_input:
             self.is_input = is_input
         if is_target:
             self.is_target = is_target

-        if padder is None:
-            padder = AutoPadder(pad_val=0)
-        else:
-            assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder."
-            padder = deepcopy(padder)
         self.set_padder(padder)

     @property
     def ignore_type(self):
         return self._ignore_type

     @ignore_type.setter
     def ignore_type(self, value):
         if value:
             self._cell_ndim = None
             self.dtype = None
         self._ignore_type = value

     @property
     def is_input(self):
         return self._is_input

     @is_input.setter
     def is_input(self, value):
         r"""
@@ -100,11 +373,11 @@ class FieldArray:
             self.dtype = None
             self._cell_ndim = None
         self._is_input = value

     @property
     def is_target(self):
         return self._is_target

     @is_target.setter
     def is_target(self, value):
         r"""
@@ -118,7 +391,7 @@ class FieldArray:
             self.dtype = None
             self._cell_ndim = None
         self._is_target = value

     def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True):
         r"""
         Check that every element of the current content has the same type and the same number of dimensions. If they do,
         set the _cell_ndim and _ele_type attributes; if not
@@ -148,7 +421,7 @@ class FieldArray:
         except SetInputOrTargetException as e:
             e.index = index
             raise e

     def append(self, val: Any):
         r"""
         :param val: append this val to the fieldarray.
@@ -165,7 +438,7 @@ class FieldArray:
                 self.content.append(val)
         else:
             self.content.append(val)

     def pop(self, index):
         r"""
         Remove the element at position `index` of this field.
@@ -173,10 +446,10 @@ class FieldArray:
         :return:
         """
         self.content.pop(index)

     def __getitem__(self, indices):
         return self.get(indices, pad=False)

     def __setitem__(self, idx, val):
         assert isinstance(idx, int)
         if (self._is_target or self._is_input) and self.ignore_type is False:  # the type must be checked
@@ -188,7 +461,7 @@ class FieldArray:
                 raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with "
                                    f"previous values(dim:{self._cell_ndim}).")
         self.content[idx] = val

     def get(self, indices, pad=True):
         r"""
         Return the contents at the given indices.
@@ -208,7 +481,7 @@ class FieldArray:
             return self.pad(contents)
         else:
             return np.array(contents)

     def pad(self, contents):
         r"""
         Pad the given list `contents` with this field's padder; `contents` must have been taken from this FieldArray.
@@ -217,7 +490,7 @@ class FieldArray:
         :return:
         """
         return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)

     def set_padder(self, padder):
         r"""
         Set the padder used when this field is padded; if None, no padding is performed.
@@ -225,11 +498,11 @@ class FieldArray:
         :param padder: of type :class:`~fastNLP.Padder`; set to None to remove the padder.
         """
         if padder is not None:
-            assert isinstance(padder, Padder), "padder must be of type Padder."
+            assert isinstance(padder, Padder), "padder must be of type `fastNLP.core.Padder`."
             self.padder = deepcopy(padder)
         else:
             self.padder = None

     def set_pad_val(self, pad_val):
         r"""
         Modify the padder's pad_val.
@@ -239,7 +512,7 @@ class FieldArray:
         if self.padder is not None:
             self.padder.set_pad_val(pad_val)
         return self

     def __len__(self):
         r"""
         Returns the size of FieldArray.
@@ -247,7 +520,7 @@ class FieldArray:
         :return int length:
         """
         return len(self.content)

     def to(self, other):
         r"""
         Copy other's attributes to this FieldArray (other must be a FieldArray).
@@ -257,14 +530,14 @@ class FieldArray:
         :return: :class:`~fastNLP.FieldArray`
         """
         assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other))

         self.ignore_type = other.ignore_type
         self.is_input = other.is_input
         self.is_target = other.is_target
         self.padder = other.padder

         return self

     def split(self, sep: str = None, inplace: bool = True):
         r"""
         Call .split() on each element in turn; this is only useful when the field's elements are str. The return values
@@ -281,421 +554,143 @@ class FieldArray:
                 logger.error(f"Exception happens when process value in index {index}.")
                 raise e
         return self._after_process(new_contents, inplace=inplace)

     def int(self, inplace: bool = True):
         r"""
         Apply int(cell) to every value in this field. Two layouts are supported: (1) ['1', '2', ...] (every value of
         the field is a str), and (2) [['1', '2', ..], ['3', ..], ...] (every value is a list, whose items are
         converted one by one).

         :param inplace: if True, replace this field's content with the new values; otherwise return a list.
         :return: List[int], List[List[int]], self
         """
         new_contents = []
         for index, cell in enumerate(self.content):
             try:
                 if isinstance(cell, list):
                     new_contents.append([int(value) for value in cell])
                 else:
                     new_contents.append(int(cell))
             except Exception as e:
                 logger.error(f"Exception happens when process value in index {index}.")
                 raise e
         return self._after_process(new_contents, inplace=inplace)

     def float(self, inplace=True):
         r"""
         Apply float(cell) to every value in this field. Two layouts are supported: (1) ['1', '2', ...], and
         (2) [['1', '2', ..], ['3', ..], ...] (list items are converted one by one).

         :param inplace: if True, replace this field's content with the new values; otherwise return a list.
         :return:
         """
         new_contents = []
         for index, cell in enumerate(self.content):
             try:
                 if isinstance(cell, list):
                     new_contents.append([float(value) for value in cell])
                 else:
                     new_contents.append(float(cell))
             except Exception as e:
                 logger.error(f"Exception happens when process value in index {index}.")
                 raise e
         return self._after_process(new_contents, inplace=inplace)

     def bool(self, inplace=True):
         r"""
         Apply bool(cell) to every value in this field. Two layouts are supported: (1) ['1', '2', ...], and
         (2) [['1', '2', ..], ['3', ..], ...] (list items are converted one by one).

         :param inplace: if True, replace this field's content with the new values; otherwise return a list.
         :return:
         """
         new_contents = []
         for index, cell in enumerate(self.content):
             try:
                 if isinstance(cell, list):
                     new_contents.append([bool(value) for value in cell])
                 else:
                     new_contents.append(bool(cell))
             except Exception as e:
                 logger.error(f"Exception happens when process value in index {index}.")
                 raise e

         return self._after_process(new_contents, inplace=inplace)

     def lower(self, inplace=True):
         r"""
         Apply cell.lower() to every value in this field. Two layouts are supported: (1) ['1', '2', ...], and
         (2) [['1', '2', ..], ['3', ..], ...] (list items are converted one by one).

         :param inplace: if True, replace this field's content with the new values; otherwise return a list.
         :return: List[int], List[List[int]], self
         """
         new_contents = []
         for index, cell in enumerate(self.content):
             try:
                 if isinstance(cell, list):
                     new_contents.append([value.lower() for value in cell])
                 else:
                     new_contents.append(cell.lower())
             except Exception as e:
                 logger.error(f"Exception happens when process value in index {index}.")
                 raise e
         return self._after_process(new_contents, inplace=inplace)

     def upper(self, inplace=True):
         r"""
         Apply cell.upper() to every value in this field. Two layouts are supported: (1) ['1', '2', ...], and
         (2) [['1', '2', ..], ['3', ..], ...] (list items are converted one by one).

         :param inplace: if True, replace this field's content with the new values; otherwise return a list.
         :return: List[int], List[List[int]], self
         """
         new_contents = []
         for index, cell in enumerate(self.content):
             try:
                 if isinstance(cell, list):
                     new_contents.append([value.upper() for value in cell])
                 else:
                     new_contents.append(cell.upper())
             except Exception as e:
                 logger.error(f"Exception happens when process value in index {index}.")
                 raise e
         return self._after_process(new_contents, inplace=inplace)

     def value_count(self):
         r"""
         Return how often each distinct value occurs in this field; mostly used for counting labels.

         :return: Counter, whose keys are the labels and whose values are the occurrence counts
         """
         count = Counter()

         def cum(cell):
             if _is_iterable(cell) and not isinstance(cell, str):
                 for cell_ in cell:
                     cum(cell_)
             else:
                 count[cell] += 1

         for cell in self.content:
             cum(cell)
         return count

     def _after_process(self, new_contents, inplace):
         r"""
         After a processing function has run, decide whether the field's content should be replaced.

         :param new_contents:
         :param inplace:
         :return: self, or the generated content
         """
         if inplace:
             self.content = new_contents
             try:
                 self.is_input = self.is_input
                 self.is_target = self.is_target
             except SetInputOrTargetException as e:
                 logger.error("The newly generated field cannot be set as input or target.")
                 raise e
             return self
         else:
             return new_contents

 [The remainder of this hunk deletes the old copies of _get_ele_type_and_dim, Padder, AutoPadder and
  EngChar2DPadder that used to follow FieldArray; they were moved above FieldArray in the hunk at line 38.]
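A short sketch of the element-wise helpers above, using the public DataSet API::

    from fastNLP import DataSet

    ds = DataSet({'label': ['1', '0', '1']})
    ds.get_field('label').int()                 # content becomes [1, 0, 1]
    print(ds.get_field('label').value_count())  # Counter({1: 2, 0: 1})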
@@ -10,6 +10,7 @@ __all__ = [
     "CrossEntropyLoss",
     "BCELoss",
+    "BCEWithLogits",
     "L1Loss",
     "NLLLoss",
     "MSELoss",
@@ -216,7 +217,7 @@ class CrossEntropyLoss(LossBase):
        or (batch_size, num_classes, max_len); CrossEntropyLoss needs to know which dimension holds the classes in
        order to compute the loss. If -1, whether pred's second dimension equals target's second dimension decides
        whether pred's second and third dimensions are swapped, since target's second dimension is the length
        dimension; misjudgement is possible, so set this value explicitly in that case. Any value greater than 0
        names the class dimension.
-    :param padding_idx: the padding index; targets labelled padding_idx are ignored when computing the loss. Can be
+    :param ignore_idx: the padding index; targets labelled ignore_idx are ignored when computing the loss. Can be
         used in place of passing seq_len.
     :param str reduction: one of `mean`, `sum` and `none`.
@@ -226,10 +227,11 @@ class CrossEntropyLoss(LossBase):
     """

-    def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, reduction='mean'):
+    def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, ignore_idx=-100, reduction='mean', **kwargs):
         super(CrossEntropyLoss, self).__init__()
         self._init_param_map(pred=pred, target=target, seq_len=seq_len)
-        self.padding_idx = padding_idx
+        ignore_idx = kwargs.pop('padding_idx', ignore_idx)
+        self.ignore_idx = ignore_idx
         assert reduction in ('mean', 'sum', 'none')
         self.reduction = reduction
         self.class_in_dim = class_in_dim
@@ -237,7 +239,7 @@ class CrossEntropyLoss(LossBase):
     def get_loss(self, pred, target, seq_len=None):
         if seq_len is not None and target.dim()>1:
             mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False)
-            target = target.masked_fill(mask, self.padding_idx)
+            target = target.masked_fill(mask, self.ignore_idx)

         if pred.dim() > 2:
             if self.class_in_dim == -1:
@@ -249,7 +251,7 @@ class CrossEntropyLoss(LossBase):
             target = target.reshape(-1)

         return F.cross_entropy(input=pred, target=target,
-                               ignore_index=self.padding_idx, reduction=self.reduction)
+                               ignore_index=self.ignore_idx, reduction=self.reduction)
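The old `padding_idx` name is still accepted and silently folded into `ignore_idx` via kwargs, so existing callers keep working. A sketch of both spellings::

    from fastNLP import CrossEntropyLoss

    loss = CrossEntropyLoss(pred='pred', target='target', seq_len='seq_len', ignore_idx=-100)
    legacy = CrossEntropyLoss(padding_idx=0)  # kwargs.pop('padding_idx', ...) maps this to ignore_idx=0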
 class L1Loss(LossBase):
@@ -311,27 +313,79 @@ class BCELoss(LossBase):
         return F.binary_cross_entropy(input=pred, target=target, reduction=self.reduction)

+class BCEWithLogits(LossBase):
+    r"""
+    Binary cross-entropy loss; no sigmoid has to be applied to the input beforehand.
+
+    :param pred: mapping for `pred` in the parameter map; None means `pred` -> `pred`
+    :param target: mapping for `target` in the parameter map; None means `target` -> `target`
+    :param int class_in_dim: in sequence labelling, pred may have shape (batch_size, max_len, num_classes)
+        or (batch_size, num_classes, max_len); the loss needs to know which dimension holds the classes. If -1,
+        whether pred's second dimension equals target's second dimension decides whether pred's second and third
+        dimensions are swapped, since target's second dimension is the length dimension; misjudgement is possible,
+        so set this value explicitly in that case. Any value greater than 0 names the class dimension.
+    :param str reduction: one of `mean`, `sum` and `none`.
+    """
+
+    def __init__(self, pred=None, target=None, class_in_dim=-1, reduction='mean'):
+        super(BCEWithLogits, self).__init__()
+        self._init_param_map(pred=pred, target=target)
+        assert reduction in ('mean', 'sum', 'none')
+        self.reduction = reduction
+        self.class_in_dim = class_in_dim
+
+    def get_loss(self, pred, target):
+        if pred.dim() > 2:
+            if self.class_in_dim == -1:
+                if pred.size(1) != target.size(1):  # the order may have been swapped
+                    pred = pred.transpose(1, 2)
+            else:
+                pred = pred.transpose(-1, self.class_in_dim)
+            pred = pred.reshape(-1, pred.size(-1))
+            target = target.reshape(-1)
+
+        return F.binary_cross_entropy_with_logits(input=pred, target=target, reduction=self.reduction)
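A minimal check of the new loss (a sketch): raw logits go in, no sigmoid is applied beforehand::

    import torch

    loss_fn = BCEWithLogits()
    pred = torch.randn(4)    # logits
    target = torch.rand(4)   # float targets in [0, 1]
    print(loss_fn.get_loss(pred, target))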
 class NLLLoss(LossBase):
     r"""
     Negative log-likelihood loss.
     """

-    def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'):
+    def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, ignore_idx=-100, reduction='mean'):
         r"""
         :param pred: mapping for `pred` in the parameter map; None means `pred` -> `pred`
         :param target: mapping for `target` in the parameter map; None means `target` -> `target`
+        :param seq_len: sentence lengths; tokens beyond the length contribute no loss. Only needed for 3d output.
+        :param int class_in_dim: in sequence labelling, pred may have shape (batch_size, max_len, num_classes)
+            or (batch_size, num_classes, max_len); the loss needs to know which dimension holds the classes. If -1,
+            whether pred's second dimension equals target's second dimension decides whether pred's second and third
+            dimensions are swapped, since target's second dimension is the length dimension; misjudgement is possible,
+            so set this value explicitly in that case. Any value greater than 0 names the class dimension.
         :param ignore_idx: the index to ignore; targets labelled ignore_idx contribute no loss. Can be used in place
             of passing seq_len.
         :param str reduction: one of `mean`, `sum` and `none`.
         """
         super(NLLLoss, self).__init__()
-        self._init_param_map(pred=pred, target=target)
+        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
         assert reduction in ('mean', 'sum', 'none')
         self.reduction = reduction
         self.ignore_idx = ignore_idx
+        self.class_in_dim = class_in_dim

-    def get_loss(self, pred, target):
+    def get_loss(self, pred, target, seq_len=None):
+        if seq_len is not None and target.dim()>1:
+            mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False)
+            target = target.masked_fill(mask, self.ignore_idx)
+
+        if pred.dim() > 2:
+            if self.class_in_dim == -1:
+                if pred.size(1) != target.size(1):  # the order may have been swapped
+                    pred = pred.transpose(1, 2)
+            else:
+                pred = pred.transpose(-1, self.class_in_dim)
+            pred = pred.reshape(-1, pred.size(-1))
+            target = target.reshape(-1)
+
         return F.nll_loss(input=pred, target=target, ignore_index=self.ignore_idx, reduction=self.reduction)
@@ -112,6 +112,108 @@ class BucketSampler(Sampler): | |||||
return list(chain(*batchs)) | return list(chain(*batchs)) | ||||
+class ConstTokenNumSampler(Sampler):
+    """
+    Tries to keep the total number of input tokens in each batch roughly constant.
+
+    Example:
+        >>> # Assume tr_data already exists and has a field `seq_len` storing the token count of each instance
+        >>> from fastNLP import DataSetIter, Trainer
+        >>> sampler = ConstTokenNumSampler('src_seq_len', max_token=4096)
+        >>>
+        >>> # Pass the sampler directly to Trainer; the batch_size argument is then ignored
+        >>> trainer = Trainer(tr_data, model, optimizer=optimizer, loss=TranslationLoss(),
+        >>>                   batch_size=1, sampler=sampler, drop_last=False, update_every=1)
+    """
+    def __init__(self, seq_len_field_name, max_token=4096, max_sentence=-1, need_be_multiple_of=1, num_bucket=-1):
+        """
+        :param str seq_len_field_name: name of the field that stores each sample's length
+        :param int max_token: maximum number of tokens per batch
+        :param int max_sentence: maximum number of instances per batch; -1 means it is decided by max_token alone
+        :param int need_be_multiple_of: the number of instances in each generated batch must be a multiple of this
+            value (useful under DataParallel)
+        :param int num_bucket: split the data by length into num_bucket buckets; batches are composed within a
+            bucket as far as possible, which reduces padding
+        """
+        assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1
+        self.seq_len_field_name = seq_len_field_name
+        self.num_bucket = num_bucket
+        self.max_token = max_token
+        self._max_sentence = max_sentence
+        self.need_be_multiple_of = need_be_multiple_of
+
+    def __call__(self, data_set):
+        assert len(data_set) > self.num_bucket, "The number of samples should be larger than the number of buckets."
+        seq_len = data_set.get_field(self.seq_len_field_name)
+        self.seq_len = seq_len
+        seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
+        seq_len_indice.sort(key=lambda x: x[0])
+        indice_in_buckets = []
+        if self.num_bucket > 0:
+            sample_per_bucket = len(seq_len_indice) // self.num_bucket
+            i = 0
+            while i * sample_per_bucket < len(seq_len_indice):  # cover the remainder samples, without empty buckets
+                indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket])
+                i += 1
+        else:
+            indice_in_buckets = [seq_len_indice]
+        self.indice_in_buckets = indice_in_buckets
+        self.get_new_order()
+
+    @property
+    def max_sentence(self):
+        if self._max_sentence < 1:
+            return 100000000
+        return self._max_sentence
+
+    @max_sentence.setter
+    def max_sentence(self, max_sentence):
+        self._max_sentence = max_sentence
+
+    def get_new_order(self):
+        np.random.shuffle(self.indice_in_buckets)
+        for bucket in self.indice_in_buckets:
+            np.random.shuffle(bucket)
+        indices = list(chain(*self.indice_in_buckets))
+        batches = []
+        cur_max_len = 0
+        batch = []
+        for length, i in indices:
+            max_len = max(length, cur_max_len)
+            if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence:
+                left_sample = len(batch) % self.need_be_multiple_of
+                add_samples = batch.copy()
+                cur_max_len = length
+                if left_sample != 0:
+                    add_samples = add_samples[:-left_sample]
+                    batch = batch[-left_sample:]
+                    # batch holds sample indices, so look the lengths up rather than taking max(batch)
+                    cur_max_len = max(cur_max_len, max(self.seq_len[j] for j in batch))
+                else:
+                    batch = []
+                if len(add_samples) == 0:
+                    raise RuntimeError(
+                        f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
+                batches.append(add_samples)
+            else:
+                cur_max_len = max_len
+            batch.append(i)
+        if batch:
+            left_sample = len(batch) % self.need_be_multiple_of
+            add_samples = batch.copy()
+            if left_sample != 0:
+                add_samples = add_samples[:-left_sample].copy()
+            if add_samples:
+                batches.append(add_samples)
+        np.random.shuffle(batches)
+        self.batches = batches
+
+    def __iter__(self):
+        for batch in self.batches:
+            yield batch
+        self.get_new_order()
+
+    def __len__(self):
+        return len(self.batches)
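
The token-budget rule above is easier to see in isolation. Here is a minimal, self-contained sketch of the same greedy batching logic, under simplifying assumptions: hypothetical lengths and budget, no fastNLP imports, and num_bucket, need_be_multiple_of and shuffling left out.

    lengths = [3, 4, 4, 5, 9, 10]   # tokens per sample, already sorted ascending
    max_token = 20                  # per-batch budget, counted on the padded batch

    batches, batch, cur_max_len = [], [], 0
    for i, length in enumerate(lengths):
        max_len = max(length, cur_max_len)
        if max_len * (len(batch) + 1) > max_token:  # adding sample i would overflow the padded budget
            batches.append(batch)
            batch, cur_max_len = [], length
        else:
            cur_max_len = max_len
        batch.append(i)
    if batch:
        batches.append(batch)

    print(batches)  # [[0, 1, 2, 3], [4, 5]] -> padded sizes 5*4=20 and 10*2=20, both within budget

Because the cost is max_len * batch_size, sorting by length first keeps batches of similar-length samples together and wastes little budget on padding.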
 class ConstantTokenNumSampler:
     """
     Tries to keep the total number of input tokens in each batch roughly constant.
@@ -119,7 +221,7 @@ class ConstantTokenNumSampler:
     Example:
         >>> # Assume tr_data already exists and has a field `seq_len` storing the token count of each instance
         >>> from fastNLP import DataSetIter, Trainer
-        >>> sampler = BatchSampler(tr_data.get_field('seq_len').content, max_token=4096)
+        >>> sampler = ConstantTokenNumSampler(tr_data.get_field('seq_len').content, max_token=4096)
         >>> tr_iter = DataSetIter(tr_data,
         >>>                       batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False,
         >>>                       drop_last=False, timeout=0, worker_init_fn=None,
@@ -128,7 +230,6 @@ class ConstantTokenNumSampler:
         >>> # Pass tr_iter directly to Trainer; the batch_size argument is then ignored
         >>> trainer = Trainer(tr_iter, model, optimizer=optimizer, loss=TranslationLoss(),
         >>>                   batch_size=1, sampler=None, drop_last=False, update_every=1)
     """
     def __init__(self, seq_len, max_token=4096, max_sentence=-1, need_be_multiple_of=1, num_bucket=-1):
         """
@@ -221,7 +322,8 @@ class SortedSampler(Sampler):
     def __init__(self, seq_len_field_name='seq_len', descending=True):
         """
-        :param str seq_len_field_name: name of the field holding the sequence length
+        :param str seq_len_field_name: which field to sort by. If the field contains numbers, sort directly by
+            those numbers; if it does not contain numbers, sort by the length of the field's content
        :param bool descending: whether to sort in descending order
        """
        self.seq_len_field_name = seq_len_field_name
@@ -229,6 +331,11 @@ class SortedSampler(Sampler):
     def __call__(self, data_set):
         seq_lens = data_set.get_field(self.seq_len_field_name).content
+        try:
+            seq_lens = list(map(len, seq_lens))
+        except TypeError:  # the field already holds plain numbers, so sort by them directly
+            pass
         orders = np.argsort(seq_lens).tolist()  # ascending order
         if self.descending:
             orders = orders[::-1]
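
A quick illustration of the dual behaviour this try/except buys, with hypothetical field contents of both kinds:

    import numpy as np

    for field in ([7, 2, 5], [['a', 'b'], ['a'], ['a', 'b', 'c']]):  # numbers vs. token lists
        try:
            keys = list(map(len, field))  # sequence field: sort by length
        except TypeError:
            keys = field                  # numeric field: sort by value
        print(np.argsort(keys)[::-1].tolist())  # descending: [0, 2, 1] then [2, 0, 1]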
@@ -53,6 +53,8 @@ from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
+from .utils import _build_fp16_env
+from .utils import _can_use_fp16
 from ._parallel_utils import _data_parallel_wrapper
 from ._parallel_utils import _model_contains_inner_module
 from functools import partial
@@ -70,7 +72,7 @@ class Tester(object):
     """
 
     def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True,
-                 **kwargs):
+                 fp16=False, **kwargs):
         r"""
         :param ~fastNLP.DataSet,~fastNLP.BatchIter data: the dataset to evaluate on
@@ -93,7 +95,9 @@ class Tester(object):
             If the model predicts via predict(), multi-GPU (DataParallel) evaluation is unavailable and only the
            model on the first GPU is used.
         :param int verbose: if 0, print nothing; if 1, print the evaluation result.
         :param bool use_tqdm: whether to show evaluation progress with tqdm; if False, nothing is displayed.
-        :param kwargs: a sampler can be passed in to control the evaluation order
+        :param bool fp16: whether to run evaluation in float16
+        :param kwargs:
+            Sampler sampler: a sampler can be passed in to control the evaluation order
         """
         super(Tester, self).__init__()
@@ -147,7 +151,11 @@ class Tester(object):
         else:
             self._predict_func = self._model.forward
             self._predict_func_wrapper = self._model.forward
+        if fp16:
+            _can_use_fp16(model=model, device=device, func=self._predict_func)
+        self.auto_cast, _grad_scaler = _build_fp16_env(not fp16)
     def test(self):
         r"""Run the evaluation and return the results.
@@ -172,12 +180,13 @@
                 for batch_x, batch_y in data_iterator:
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
-                    pred_dict = self._data_forward(self._predict_func, batch_x)
-                    if not isinstance(pred_dict, dict):
-                        raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
-                                        f"must be `dict`, got {type(pred_dict)}.")
-                    for metric in self.metrics:
-                        metric(pred_dict, batch_y)
+                    with self.auto_cast():
+                        pred_dict = self._data_forward(self._predict_func, batch_x)
+                        if not isinstance(pred_dict, dict):
+                            raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
+                                            f"must be `dict`, got {type(pred_dict)}.")
+                        for metric in self.metrics:
+                            metric(pred_dict, batch_y)
 
                     if self.use_tqdm:
                         pbar.update()
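
For readers new to auto_cast: it is torch's autocast context when fp16 is on, and a no-op ExitStack otherwise (see _build_fp16_env further down). A minimal sketch of fp16 evaluation in plain PyTorch, assuming a CUDA device and torch >= 1.6:

    import torch
    from torch.cuda.amp import autocast

    model = torch.nn.Linear(8, 2).cuda().eval()
    batch = torch.randn(4, 8, device='cuda')

    with torch.no_grad(), autocast():
        logits = model(batch)   # matmuls run in float16 inside the context
    print(logits.dtype)         # torch.float16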
@@ -342,7 +342,7 @@ from .losses import _prepare_losser
 from .metrics import _prepare_metrics
 from .optimizer import Optimizer
 from .sampler import Sampler
-from .sampler import RandomSampler
+from .sampler import RandomSampler, ConstTokenNumSampler
 from .tester import Tester
 from .utils import _CheckError
 from .utils import _build_args
@@ -352,6 +352,8 @@ from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
+from .utils import _build_fp16_env
+from .utils import _can_use_fp16
 from ._parallel_utils import _model_contains_inner_module
 from ._logger import logger
@@ -373,7 +375,7 @@ class Trainer(object):
                  num_workers=0, n_epochs=10, print_every=5,
                  dev_data=None, metrics=None, metric_key=None,
                  validate_every=-1, save_path=None, use_tqdm=True, device=None,
-                 callbacks=None, check_code_level=0, **kwargs):
+                 callbacks=None, check_code_level=0, fp16=False, **kwargs):
         r"""
         :param train_data: the training set, a :class:`~fastNLP.DataSet` or a subclass of :class:`~fastNLP.BatchIter`
         :param nn.modules model: the model to train
@@ -422,9 +424,14 @@ class Trainer(object):
             report warnings; 2: raise an error whenever any field goes unused. The check works by running the code
            with a very small batch (2 samples by default); in theory this modifies no parameters, it only checks
            that the code runs. However, if (1) the model hard-codes batch_size to some fixed value, or (2) the model
            accumulates a forward-pass counter, one extra pass may be counted. In those cases set check_code_level to -1.
+        :param bool fp16: whether to train with fp16.
         :param kwargs: optional settings
             bool test_use_tqdm: whether to enable tqdm during validation on dev
             Sampler test_sampler: the sampler used during evaluation
+            bool test_use_fp16: whether to evaluate with fp16; defaults to the same value as fp16.
+            bool set_grad_to_none: whether zero_grad sets gradients to None instead of zero
+            GradScaler grad_scaler: only effective when fp16 is True; if you do not want the default
+                torch.cuda.amp.GradScaler initialization, pass in an already initialized grad_scaler.
         """
         super(Trainer, self).__init__()
         if not isinstance(model, nn.Module):
@@ -488,6 +495,15 @@ class Trainer(object):
             sampler = RandomSampler()
         elif hasattr(sampler, 'set_batch_size'):
             sampler.set_batch_size(batch_size)
+        if isinstance(sampler, ConstTokenNumSampler):  # use the constant-token-number sampler directly
+            assert isinstance(train_data, DataSet), f"When sampler is `ConstTokenNumSampler`, the train_data must" \
+                                                    f" be `DataSet`."
+            sampler(train_data)
+            train_data = DataSetIter(train_data,
+                                     batch_size=1, sampler=None, as_numpy=False, num_workers=num_workers,
+                                     pin_memory=False, drop_last=drop_last, timeout=0, worker_init_fn=None,
+                                     batch_sampler=sampler)
 
         if isinstance(train_data, DataSet):
             self.data_iterator = DataSetIter(dataset=train_data, batch_size=batch_size, sampler=sampler,
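
The batch_size=1 plus batch_sampler combination above follows PyTorch's batch_sampler convention: the sampler yields whole batches of indices, and batch_size, shuffle and drop_last must stay at their defaults. A sketch of the same convention with a plain DataLoader and hypothetical data:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    data = TensorDataset(torch.arange(10))
    batches = [[0, 1, 2], [3, 4], [5, 6, 7, 8, 9]]  # each element is one full batch of indices

    loader = DataLoader(data, batch_sampler=batches)
    for (xs,) in loader:
        print(xs.tolist())  # [0, 1, 2] then [3, 4] then [5, 6, 7, 8, 9]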
@@ -505,6 +521,23 @@ class Trainer(object):
             self._forward_func = self.model.module.forward
         else:
             self._forward_func = self.model.forward
+        self.fp16 = fp16
+        self.verbose = kwargs.get('verbose', 0)
+
+        # check the fp16-related settings
+        self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
+        if self.fp16:
+            _can_use_fp16(device=device, model=model, func=self._forward_func)
+        grad_scaler = kwargs.get('grad_scaler', None)
+        if grad_scaler is not None:
+            self.grad_scaler = grad_scaler
+        else:
+            self.grad_scaler = _grad_scaler()
+        self.test_use_fp16 = kwargs.get('test_use_fp16', fp16)
+        self.set_grad_to_none = kwargs.get('set_grad_to_none', True)
+
         if check_code_level > -1:
             # _check_code is how fastNLP checks your code for you. If you see this line in an error stack, carefully
             # check whether your field names match the model's input names
@@ -545,15 +578,15 @@ class Trainer(object):
         elif optimizer is None:
             self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
         else:
-            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
+            if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
+                raise TypeError("optimizer must have a callable step() function.")
+            else:
+                self.optimizer = optimizer
 
         self.logger = logger
         self.use_tqdm = use_tqdm
-        if 'test_use_tqdm' in kwargs:
-            self.test_use_tqdm = kwargs.get('test_use_tqdm')
-        else:
-            self.test_use_tqdm = self.use_tqdm
+        self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
         self.pbar = None
         self.print_every = abs(self.print_every)
         self.kwargs = kwargs
@@ -565,7 +598,8 @@ class Trainer(object):
                                  device=None,  # device is handled by the code above
                                  verbose=0,
                                  use_tqdm=self.test_use_tqdm,
-                                 sampler=kwargs.get('test_sampler', None))
+                                 sampler=kwargs.get('test_sampler', None),
+                                 fp16=self.test_use_fp16)
 
         self.start_time = None  # start timestamp
@@ -575,7 +609,7 @@ class Trainer(object):
         self.callback_manager = CallbackManager(env={"trainer": self},
                                                 callbacks=callbacks)
 
-    def train(self, load_best_model=True, on_exception='auto'):
+    def train(self, load_best_model=True, on_exception='auto', **kwargs):
         r"""
         Call this method to start training.
@@ -584,6 +618,8 @@ class Trainer(object):
         :param str on_exception: whether to keep raising an exception that occurred during training after it has been
             handled by on_exception() of :py:class:Callback. Supports 'ignore', 'raise', 'auto': 'ignore' catches
            the exception, and the code written after Trainer.train() keeps running; 'raise' re-raises the exception;
            'auto' ignores CallbackException and KeyboardInterrupt, and raises any other exception.
+        :param kwargs:
+            int verbose: if 1, print the dataset indices of the samples in the current batch when an exception occurs
         :return dict: returns a dict
             with the following contents::
@@ -596,6 +632,7 @@ class Trainer(object):
         """
         results = {}
+        verbose = kwargs.get('verbose', 0)
         if self.n_epochs <= 0:
             self.logger.info(f"training epoch is {self.n_epochs}, nothing was done.")
             results['seconds'] = 0.
@@ -617,6 +654,8 @@ class Trainer(object):
             except BaseException as e:
                 self.callback_manager.on_exception(e)
+                if verbose > 0:
+                    self.logger.info(f"The data indices for current batch are: "
+                                     f"{self.data_iterator.cur_batch_indices}.")
                 if on_exception == 'auto':
                     if not isinstance(e, (CallbackException, KeyboardInterrupt)):
                         raise e
@@ -674,7 +713,8 @@ class Trainer(object):
                 # edit prediction
                 self.callback_manager.on_loss_begin(batch_y, prediction)
-                loss = self._compute_loss(prediction, batch_y).mean()
+                with self.auto_cast():
+                    loss = self._compute_loss(prediction, batch_y).mean()
                 loss = loss / self.update_every
                 avg_loss += loss.item()
@@ -759,11 +799,13 @@ class Trainer(object):
         """
         if self.step % self.update_every == 0:
-            self.optimizer.step()
+            self.grad_scaler.step(self.optimizer)
+            self.grad_scaler.update()
 
     def _data_forward(self, network, x):
         x = _build_args(self._forward_func, **x)
-        y = network(**x)
+        with self.auto_cast():
+            y = network(**x)
         if not isinstance(y, dict):
             raise TypeError(
                 f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
@@ -777,8 +819,22 @@ class Trainer(object):
         For PyTorch, just do "loss.backward()"
         """
         if (self.step - 1) % self.update_every == 0:
-            self.model.zero_grad()
-        loss.backward()
+            self._clear_grad(self.optimizer, self.set_grad_to_none)
+        self.grad_scaler.scale(loss).backward()
+
+    def _clear_grad(self, optimizer, set_to_none=True):
+        param_groups = optimizer.param_groups
+        for group in param_groups:
+            for p in group['params']:
+                if p.grad is not None:
+                    if set_to_none:
+                        p.grad = None
+                    else:
+                        if p.grad.grad_fn is not None:
+                            p.grad.detach_()
+                        else:
+                            p.grad.requires_grad_(False)
+                        p.grad.zero_()
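
_data_forward, _grad_backward and _update together implement the standard torch.cuda.amp recipe. A condensed, self-contained sketch of one training step, with illustrative names rather than fastNLP API, assuming a CUDA device and torch >= 1.7 (for zero_grad(set_to_none=True)):

    import torch
    from torch.cuda.amp import autocast, GradScaler

    model = torch.nn.Linear(8, 2).cuda()
    optimizer = torch.optim.Adam(model.parameters())
    scaler = GradScaler()

    x = torch.randn(4, 8, device='cuda')
    y = torch.randint(0, 2, (4,), device='cuda')

    optimizer.zero_grad(set_to_none=True)  # same effect as _clear_grad(set_to_none=True)
    with autocast():                       # forward and loss in mixed precision
        loss = torch.nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()          # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)                 # unscales gradients, skips the step on inf/nan
    scaler.update()                        # adapts the scale factor for the next iteration

Setting gradients to None instead of zeroing them saves a kernel launch per parameter and lets optimizers skip parameters with no gradient.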
     def _compute_loss(self, predict, truth):
         r"""Compute loss given prediction and ground truth.
@@ -12,23 +12,20 @@ import inspect
 import os
 import warnings
 from collections import Counter, namedtuple
-from copy import deepcopy
 from typing import List
 
 import _pickle
 import numpy as np
-import torch
 import torch.nn as nn
 from prettytable import PrettyTable
 
 from ._logger import logger
 from ._parallel_utils import _model_contains_inner_module
 # from .vocabulary import Vocabulary
-
-try:
-    from apex import amp
-except:
-    amp = None
+import torch
+import contextlib
+from pkg_resources import parse_version
 
 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                      'varargs'])
@@ -271,7 +268,7 @@ def _prepare_cache_filepath(filepath):
         raise RuntimeError("The cache_file_path must be a file, not a directory.")
     cache_dir = os.path.dirname(_cache_filepath)
     if not os.path.exists(cache_dir):
-        os.makedirs(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
 
 
 def cache_results(_cache_fp, _refresh=False, _verbose=1):
@@ -1032,8 +1029,92 @@ def sub_column(string: str, c: int, c_size: int, title: str) -> str:
     return res
 
-def _check_fp16():
-    if amp is None:
-        raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-    if not torch.backends.cudnn.enabled:
-        raise RuntimeError("Amp requires cudnn backend to be enabled.")
+def _is_function_contains_autocast(func):
+    """
+    Check whether func uses autocast: (1) via an autocast decorator, or (2) via a `with autocast():` block.
+
+    :param func: the function to check
+    """
+    import re
+    source = inspect.getsource(func)
+    lines = source.split('\n')
+    for line in lines:
+        line = line.strip()
+        if re.search(r'@[\w\.]*autocast\(\w*\)', line):
+            raise RuntimeError("Please do not use `autocast()` decorator, use `with autocast():` instead. Please refer to"
+                               " https://pytorch.org/docs/stable/notes/amp_examples.html#dataparallel-in-a-single-process ")
+        if re.search(r'with [\w\.]*autocast\(\w*\):', line):
+            return True
+    return False
+class DummyGradScaler:
+    """
+    A no-op stand-in for pytorch's GradScaler, so the fp32 path can reuse the fp16 code
+    without a pile of if checks.
+    """
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def get_scale(self):
+        return 1.0
+
+    def is_enabled(self):
+        return False
+
+    def scale(self, outputs):
+        return outputs
+
+    def step(self, optimizer, *args, **kwargs):
+        optimizer.step(*args, **kwargs)
+
+    def update(self, new_scale=None):
+        pass
+
+    def unscale_(self, optimizer):
+        pass
+
+    def load_state_dict(self, state_dict):
+        pass
+
+    def state_dict(self):
+        return {}
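
DummyGradScaler is a null object: when fp16 is off, the training loop can still call scale/step/update unconditionally instead of branching on a flag. A standalone sketch of the pattern, with illustrative names, runnable on CPU and using the DummyGradScaler defined above:

    import torch

    def train_step(optimizer, loss, scaler):
        # identical call sequence whether scaler is a real GradScaler or the dummy
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = model(torch.randn(2, 4)).sum()
    train_step(optimizer, loss, DummyGradScaler())  # fp32 path, no GPU or amp required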
+def _build_fp16_env(dummy=False):
+    if dummy:
+        autocast = contextlib.ExitStack
+        GradScaler = DummyGradScaler
+    else:
+        if not torch.cuda.is_available():
+            raise RuntimeError("No cuda")
+        if torch.cuda.get_device_capability(0)[0] < 7:
+            warnings.warn(
+                "NOTE: your device does NOT support faster training with fp16, "
+                "please switch to FP32 which is likely to be faster"
+            )
+        try:
+            from torch.cuda.amp import autocast, GradScaler
+        except ImportError:
+            raise RuntimeError("torch version too low (less than 1.6)")
+    return autocast, GradScaler
+def _can_use_fp16(device, model, func):
+    if parse_version(torch.__version__) < parse_version('1.6'):
+        raise RuntimeError("Pytorch supports float16 after version 1.6, please upgrade your pytorch version.")
+    model_device = _get_model_device(model)
+    if device is None and model_device is not None and model_device.type != 'cuda':
+        raise RuntimeError("You have to run in cuda device to use fp16.")
+    if isinstance(device, str):
+        if device == 'cpu':
+            raise RuntimeError("You have to run in cuda device to use fp16.")
+    if isinstance(device, torch.device) and device.type == 'cpu':
+        raise RuntimeError("You have to run in cuda device to use fp16.")
+
+    if _model_contains_inner_module(model) or (isinstance(device, list) and len(device) > 1):
+        # remind the user
+        if not _is_function_contains_autocast(func):
+            raise RuntimeError("When use fp16 in Parallel Training, you have to set autocast() in your forward "
+                               "function as described in "
+                               "https://pytorch.org/docs/stable/notes/amp_examples.html#dataparallel-in-a-single-process")
@@ -125,7 +125,7 @@ class Vocabulary(object):
         r"""Increase, in order, the in-vocabulary frequency of each word in a sequence.
 
         :param list word_lst: a list of strings
-        :param bool no_create_entry: how to handle a word that is not found in the pretrained vocabulary when a pretrained model is loaded with fastNLP.TokenEmbedding.
+        :param bool no_create_entry: setting this to True is recommended for words that come from outside the training set. It controls how a word is handled when it is not found in the pretrained vocabulary while a pretrained model is loaded with fastNLP.TokenEmbedding.
             If True, no separate entry is created for the word and it always points to the representation of unk; if False, a separate
             entry is created for the word. If the word comes from dev or test, this is usually set to True; if it comes from train, usually False. Two cases: if a new
             word is added with no_create_entry=True, but the word was already in the Vocabulary and was not no_create_entry, a separate entry is still created for
@@ -142,7 +142,7 @@ class Vocabulary(object):
         Increase the in-vocabulary frequency of a new word
 
         :param str word: the new word
-        :param bool no_create_entry: how to handle a word that is not found in the pretrained vocabulary when a pretrained model is loaded with fastNLP.TokenEmbedding.
+        :param bool no_create_entry: setting this to True is recommended for words that come from outside the training set. It controls how a word is handled when it is not found in the pretrained vocabulary while a pretrained model is loaded with fastNLP.TokenEmbedding.
             If True, no separate entry is created for the word and it always points to the representation of unk; if False, a separate
             entry is created for the word. If the word comes from dev or test, this is usually set to True; if it comes from train, usually False. Two cases: if a new
             word is added with no_create_entry=True, but the word was already in the Vocabulary and was not no_create_entry, a separate entry is still created for
@@ -175,7 +175,7 @@ class Vocabulary(object):
         Increase the in-vocabulary frequency of a new word
 
         :param str word: the new word
-        :param bool no_create_entry: how to handle a word that is not found in the pretrained vocabulary when a pretrained model is loaded with fastNLP.TokenEmbedding.
+        :param bool no_create_entry: setting this to True is recommended for words that come from outside the training set. It controls how a word is handled when it is not found in the pretrained vocabulary while a pretrained model is loaded with fastNLP.TokenEmbedding.
             If True, no separate entry is created for the word and it always points to the representation of unk; if False, a separate
             entry is created for the word. If the word comes from dev or test, this is usually set to True; if it comes from train, usually False. Two cases: if a new
             word is added with no_create_entry=True, but the word was already in the Vocabulary and was not no_create_entry, a separate entry is still created for
@@ -190,7 +190,7 @@ class Vocabulary(object):
         Increase, in order, the in-vocabulary frequency of each word in a sequence
 
         :param list[str] word_lst: the sequence of words
-        :param bool no_create_entry: how to handle a word that is not found in the pretrained vocabulary when a pretrained model is loaded with fastNLP.TokenEmbedding.
+        :param bool no_create_entry: setting this to True is recommended for words that come from outside the training set. It controls how a word is handled when it is not found in the pretrained vocabulary while a pretrained model is loaded with fastNLP.TokenEmbedding.
             If True, no separate entry is created for the word and it always points to the representation of unk; if False, a separate
             entry is created for the word. If the word comes from dev or test, this is usually set to True; if it comes from train, usually False. Two cases: if a new
             word is added with no_create_entry=True, but the word was already in the Vocabulary and was not no_create_entry, a separate entry is still created for
@@ -344,7 +344,7 @@ class Vocabulary(object):
         :param str,List[str] field_name: may be a ``str`` or a ``List[str]``.
             The field(s) used to build the vocabulary; one or more fields are supported, and with multiple DataSets every DataSet must have these fields. Currently supported field structures
             : ``str``, ``List[str]``
-        :param no_create_entry_dataset: may be a DataSet, List[DataSet] or None (default); this option is for the case where the downstream model will use pretrained
+        :param no_create_entry_dataset: may be a DataSet, List[DataSet] or None (default); it is recommended to simply pass all non-training data here. This option is for the case where the downstream model will use pretrained
             embeddings (glove, word2vec, elmo and bert included) and finetune them. If the vocabulary is built from train data only, the data in test and dev
             cannot fully exploit the information in the pretrained embeddings, so taking test and dev into account when building the vocabulary improves the final results.
             If a word appears in train but not in the pretrained model, the embedding initializes it with unk, but as its own separate vector; if
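
A typical use of the recommendation above, sketched under the assumption that tr_data, dev_data and test_data are fastNLP DataSets with a 'words' field:

    from fastNLP import Vocabulary

    vocab = Vocabulary()
    vocab.from_dataset(tr_data, field_name='words',
                       no_create_entry_dataset=[dev_data, test_data])  # dev/test-only words stay unk-backed
    vocab.index_dataset(tr_data, dev_data, test_data, field_name='words')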
@@ -108,13 +108,14 @@ class BertEmbedding(ContextualEmbedding):
         self._word_sep_index = vocab['[SEP]']
         self._word_cls_index = -100
         if '[CLS]' in vocab:
-            self._word_cls_index = vocab['CLS']
+            self._word_cls_index = vocab['[CLS]']
 
-        min_freq = kwargs.get('min_freq', 1)
+        min_freq = kwargs.pop('min_freq', 1)
         self._min_freq = min_freq
         self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                    pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate)
+                                    pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate,
+                                    **kwargs)
 
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -367,32 +368,44 @@ class BertWordPieceEncoder(nn.Module):
 class _BertWordModel(nn.Module):
     def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
-                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
+                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
+                 **kwargs):
         super().__init__()
 
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
-            self.layers = list(map(int, layers.split(',')))
+            if layers.lower() == 'all':
+                self.layers = None
+            else:
+                self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
-        assert len(self.layers) > 0, "There is no layer selected!"
 
         neg_num_output_layer = -16384
         pos_num_output_layer = 0
-        for layer in self.layers:
-            if layer < 0:
-                neg_num_output_layer = max(layer, neg_num_output_layer)
-            else:
-                pos_num_output_layer = max(layer, pos_num_output_layer)
+        if self.layers is None:
+            neg_num_output_layer = -1
+        else:
+            for layer in self.layers:
+                if layer < 0:
+                    neg_num_output_layer = max(layer, neg_num_output_layer)
+                else:
+                    pos_num_output_layer = max(layer, pos_num_output_layer)
 
-        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.tokenizer = BertTokenizer.from_pretrained(model_dir_or_name)
         self.encoder = BertModel.from_pretrained(model_dir_or_name,
                                                  neg_num_output_layer=neg_num_output_layer,
-                                                 pos_num_output_layer=pos_num_output_layer)
+                                                 pos_num_output_layer=pos_num_output_layer,
+                                                 **kwargs)
         self._max_position_embeddings = self.encoder.config.max_position_embeddings
         # check that encoder_layer_number is sensible
         encoder_layer_number = len(self.encoder.encoder.layer)
+        if self.layers is None:
+            self.layers = [idx for idx in range(encoder_layer_number + 1)]
+            logger.info(f'Bert Model will return {len(self.layers)} layers (layer-0 '
+                        f'is embedding result): {self.layers}')
+        assert len(self.layers) > 0, "There is no layer selected!"
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
@@ -417,17 +430,17 @@ class _BertWordModel(nn.Module):
                 word = '[PAD]'
             elif index == vocab.unknown_idx:
                 word = '[UNK]'
-            elif vocab.word_count[word]<min_freq:
+            elif vocab.word_count[word] < min_freq:
                 word = '[UNK]'
-            word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
-            word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
+            word_pieces = self.tokenizer.wordpiece_tokenizer.tokenize(word)
+            word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces)
             word_to_wordpieces.append(word_pieces)
             word_pieces_lengths.append(len(word_pieces))
-        self._cls_index = self.tokenzier.vocab['[CLS]']
-        self._sep_index = self.tokenzier.vocab['[SEP]']
+        self._cls_index = self.tokenizer.vocab['[CLS]']
+        self._sep_index = self.tokenizer.vocab['[SEP]']
         self._word_pad_index = vocab.padding_idx
-        self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed for generating word_piece
-        self.word_to_wordpieces = np.array(word_to_wordpieces)
+        self._wordpiece_pad_index = self.tokenizer.vocab['[PAD]']  # needed for generating word_piece
+        self.word_to_wordpieces = np.array(word_to_wordpieces, dtype=object)
         self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
         logger.debug("Successfully generate word pieces.")
@@ -481,14 +494,15 @@ class _BertWordModel(nn.Module):
             token_type_ids = torch.zeros_like(word_pieces)
         # 2. get the hidden states and pool them according to word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
-        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
+        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids,
+                                                attention_mask=attn_masks,
                                                 output_all_encoded_layers=True)
         # output_layers = [self.layers]  # len(self.layers) x batch_size x real_word_piece_length x hidden_size
 
         if self.include_cls_sep:
             s_shift = 1
             outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
-                                                 bert_outputs[-1].size(-1))
+                                                 bert_outputs[-1].size(-1))
         else:
             s_shift = 0
@@ -552,7 +566,7 @@ class _BertWordModel(nn.Module):
         :param str folder:
         :return:
         """
-        self.tokenzier.save_pretrained(folder)
+        self.tokenizer.save_pretrained(folder)
         self.encoder.save_pretrained(folder)
 
@@ -565,7 +579,7 @@ class _BertWordPieceModel(nn.Module):
     def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool = False):
         super().__init__()
 
-        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.tokenizer = BertTokenizer.from_pretrained(model_dir_or_name)
         self.encoder = BertModel.from_pretrained(model_dir_or_name)
 
         # check that encoder_layer_number is sensible
         encoder_layer_number = len(self.encoder.encoder.layer)
@@ -585,10 +599,10 @@ class _BertWordPieceModel(nn.Module):
             assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
                                                   f"a bert model with {encoder_layer_number} layers."
 
-        self._cls_index = self.tokenzier.cls_index
-        self._sep_index = self.tokenzier.sep_index
-        self._wordpiece_unknown_index = self.tokenzier.unk_index
-        self._wordpiece_pad_index = self.tokenzier.pad_index  # needed for generating word_piece
+        self._cls_index = self.tokenizer.cls_index
+        self._sep_index = self.tokenizer.sep_index
+        self._wordpiece_unknown_index = self.tokenizer.unk_index
+        self._wordpiece_pad_index = self.tokenizer.pad_index  # needed for generating word_piece
         self.pooled_cls = pooled_cls
 
     def index_datasets(self, *datasets, field_name, add_cls_sep=True):
@@ -601,7 +615,7 @@ class _BertWordPieceModel(nn.Module):
         :return:
         """
-        encode_func = partial(self.tokenzier.encode, add_special_tokens=add_cls_sep)
+        encode_func = partial(self.tokenizer.encode, add_special_tokens=add_cls_sep)
 
         for index, dataset in enumerate(datasets):
             try:
@@ -640,5 +654,5 @@ class _BertWordPieceModel(nn.Module):
         :param folder:
         :return:
         """
-        self.tokenzier.save_pretrained(folder)
+        self.tokenizer.save_pretrained(folder)
         self.encoder.save_pretrained(folder)
@@ -423,7 +423,7 @@ class _GPT2Model(nn.Module):
         self._word_pad_index = vocab.padding_idx
         self._endoftext_index = self.tokenzier.encoder.get('<|endoftext|>')
         self._wordpiece_pad_index = self.tokenzier.encoder.get('<|endoftext|>')  # needed for generating word_piece
-        self.word_to_wordpieces = np.array(word_to_wordpieces)
+        self.word_to_wordpieces = np.array(word_to_wordpieces, dtype=object)
         self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
         logger.debug("Successfully generate word pieces.")
@@ -93,12 +93,13 @@ class RobertaEmbedding(ContextualEmbedding):
         if '<s>' in vocab:
             self._word_cls_index = vocab['<s>']
 
-        min_freq = kwargs.get('min_freq', 1)
+        min_freq = kwargs.pop('min_freq', 1)
         self._min_freq = min_freq
 
         self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                        pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                       pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq)
+                                       pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
+                                       **kwargs)
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -193,33 +194,45 @@ class RobertaEmbedding(ContextualEmbedding):
 class _RobertaWordModel(nn.Module):
     def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
-                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
+                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
+                 **kwargs):
         super().__init__()
 
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
-            self.layers = list(map(int, layers.split(',')))
+            if layers.lower() == 'all':
+                self.layers = None
+            else:
+                self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
-        assert len(self.layers) > 0, "There is no layer selected!"
 
         neg_num_output_layer = -16384
         pos_num_output_layer = 0
-        for layer in self.layers:
-            if layer < 0:
-                neg_num_output_layer = max(layer, neg_num_output_layer)
-            else:
-                pos_num_output_layer = max(layer, pos_num_output_layer)
+        if self.layers is None:
+            neg_num_output_layer = -1
+        else:
+            for layer in self.layers:
+                if layer < 0:
+                    neg_num_output_layer = max(layer, neg_num_output_layer)
+                else:
+                    pos_num_output_layer = max(layer, pos_num_output_layer)
 
         self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
         self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
                                                     neg_num_output_layer=neg_num_output_layer,
-                                                    pos_num_output_layer=pos_num_output_layer)
+                                                    pos_num_output_layer=pos_num_output_layer,
+                                                    **kwargs)
         # RobertaEmbedding sets padding_idx to 1 and uses a rather magical position computation, hence the -2
         self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
         # check that encoder_layer_number is sensible
         encoder_layer_number = len(self.encoder.encoder.layer)
+        if self.layers is None:
+            self.layers = [idx for idx in range(encoder_layer_number + 1)]
+            logger.info(f'RoBERTa Model will return {len(self.layers)} layers (layer-0 '
+                        f'is embedding result): {self.layers}')
+        assert len(self.layers) > 0, "There is no layer selected!"
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
@@ -241,7 +254,7 @@ class _RobertaWordModel(nn.Module):
                 word = '<pad>'
             elif index == vocab.unknown_idx:
                 word = '<unk>'
-            elif vocab.word_count[word]<min_freq:
+            elif vocab.word_count[word] < min_freq:
                 word = '<unk>'
             word_pieces = self.tokenizer.tokenize(word)
             word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces)
@@ -251,7 +264,7 @@ class _RobertaWordModel(nn.Module):
         self._sep_index = self.tokenizer.encoder['</s>']
         self._word_pad_index = vocab.padding_idx
         self._wordpiece_pad_index = self.tokenizer.encoder['<pad>']  # needed for generating word_piece
-        self.word_to_wordpieces = np.array(word_to_wordpieces)
+        self.word_to_wordpieces = np.array(word_to_wordpieces, dtype=object)
         self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
         logger.debug("Successfully generate word pieces.")
@@ -265,13 +278,15 @@ class _RobertaWordModel(nn.Module):
         batch_size, max_word_len = words.size()
         word_mask = words.ne(self._word_pad_index)  # positions equal to 1 hold a word
         seq_len = word_mask.sum(dim=-1)
-        batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False), 0)  # batch_size x max_len
+        batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
+                                                                               0)  # batch_size x max_len
         word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)  # batch_size
         max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # length of the word pieces, padding included
         if max_word_piece_length + 2 > self._max_position_embeddings:
             if self.auto_truncate:
                 word_pieces_lengths = word_pieces_lengths.masked_fill(
-                    word_pieces_lengths + 2 > self._max_position_embeddings, self._max_position_embeddings - 2)
+                    word_pieces_lengths + 2 > self._max_position_embeddings,
+                    self._max_position_embeddings - 2)
             else:
                 raise RuntimeError(
                     "After split words into word pieces, the lengths of word pieces are longer than the "
@@ -290,6 +305,7 @@ class _RobertaWordModel(nn.Module):
                 word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
             word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
             attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
+        # add <s> and </s>
         word_pieces[:, 0].fill_(self._cls_index)
         batch_indexes = torch.arange(batch_size).to(words)
         word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
@@ -362,6 +378,12 @@ class _RobertaWordModel(nn.Module):
         return outputs
 
     def save(self, folder):
+        """
+        Save pytorch_model.bin, config.json and vocab.txt into the given folder.
+
+        :param str folder:
+        :return:
+        """
         self.tokenizer.save_pretrained(folder)
         self.encoder.save_pretrained(folder)
@@ -13,6 +13,7 @@ import torch
 from torch import nn as nn
 
 from .embedding import TokenEmbedding
+from .utils import _check_vocab_has_same_index
 
 
 class StackEmbedding(TokenEmbedding):
@@ -44,8 +45,9 @@ class StackEmbedding(TokenEmbedding):
             vocabs.append(embed.get_word_vocab())
         _vocab = vocabs[0]
         for vocab in vocabs[1:]:
-            assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."
+            if _vocab != vocab:
+                _check_vocab_has_same_index(_vocab, vocab)
 
         super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
         assert isinstance(embeds, list)
         for embed in embeds:
@@ -60,6 +62,7 @@ class StackEmbedding(TokenEmbedding):
         :return:
         """
         assert isinstance(embed, TokenEmbedding)
+        _check_vocab_has_same_index(self.get_word_vocab(), embed.get_word_vocab())
         self._embed_size += embed.embed_size
         self.embeds.append(embed)
         return self
@@ -81,7 +81,7 @@ class StaticEmbedding(TokenEmbedding):
                  init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
         r"""
 
-        :param vocab: Vocabulary. If None, every embedding in the file is loaded.
+        :param Vocabulary vocab: the vocabulary. StaticEmbedding only loads vectors for words contained in the
+            vocabulary; words not found among the pretrained vectors are randomly initialized
         :param model_dir_or_name: a pretrained static embedding can be loaded in two ways: either pass an embedding
            folder (the folder should contain a single file with a .txt suffix) or a file path; or pass the name of
            an embedding, in which case the cache is checked for the model and it is downloaded automatically if absent.
             If None, an embedding of dimension embedding_dim is randomly initialized.
@@ -281,7 +281,9 @@ class StaticEmbedding(TokenEmbedding):
                         if word in vocab:
                             index = vocab.to_index(word)
                             if index in matrix:
-                                warnings.warn(f"Word:{word} occurs again in line:{idx}(starts from 0)")
+                                warnings.warn(f"Word has more than one vector in embedding file. Set logger level to "
+                                              f"DEBUG for detail.")
+                                logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)")
                             matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
                             if self.only_norm_found_vector:
                                 matrix[index] = matrix[index] / np.linalg.norm(matrix[index])
@@ -65,6 +65,8 @@ class TransformersEmbedding(ContextualEmbedding):
             for tasks that classify from that position, set auto_truncate to True.
         :param kwargs:
             int min_freq: words occurring fewer times than this are replaced by unk; default 1
+            dict tokenizer_kwargs: extra arguments passed to the tokenizer's tokenize() method; e.g. RoBERTaTokenizer
+                needs {'add_prefix_space': True}
         """
         super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)
@@ -82,9 +84,10 @@ class TransformersEmbedding(ContextualEmbedding):
         min_freq = kwargs.get('min_freq', 1)
         self._min_freq = min_freq
+        tokenizer_kwargs = kwargs.get('tokenizer_kwargs', {})
         self.model = _TransformersWordModel(tokenizer=tokenizer, model=model, vocab=vocab, layers=layers,
-                                            pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                            auto_truncate=auto_truncate, min_freq=min_freq)
+                                            pool_method=pool_method, include_cls_sep=include_cls_sep,
+                                            auto_truncate=auto_truncate, min_freq=min_freq,
+                                            tokenizer_kwargs=tokenizer_kwargs)
 
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * model.config.hidden_size
@@ -237,7 +240,7 @@ class TransformersWordPieceEncoder(nn.Module):
 class _TransformersWordModel(nn.Module):
     def __init__(self, tokenizer, model, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
-                 include_cls_sep: bool = False, auto_truncate: bool = False, min_freq=2):
+                 include_cls_sep: bool = False, auto_truncate: bool = False, min_freq=2, tokenizer_kwargs={}):
         super().__init__()
 
         self.tokenizer = tokenizer
@@ -283,7 +286,7 @@ class _TransformersWordModel(nn.Module):
                 word = tokenizer.unk_token
             elif vocab.word_count[word]<min_freq:
                 word = tokenizer.unk_token
-            word_pieces = self.tokenizer.tokenize(word, add_prefix_space=True)
+            word_pieces = self.tokenizer.tokenize(word, **tokenizer_kwargs)
             word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces)
             word_to_wordpieces.append(word_pieces)
             word_pieces_lengths.append(len(word_pieces))
@@ -291,7 +294,7 @@ class _TransformersWordModel(nn.Module):
         self._sep_index = self.tokenizer.sep_token_id
         self._word_pad_index = vocab.padding_idx
         self._wordpiece_pad_index = self.tokenizer.pad_token_id  # needed for generating word_piece
-        self.word_to_wordpieces = np.array(word_to_wordpieces)
+        self.word_to_wordpieces = np.array(word_to_wordpieces, dtype=object)
         self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
         logger.debug("Successfully generate word pieces.")
@@ -89,3 +89,16 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
     return torch.FloatTensor(sinusoid_table)
+
+
+def _check_vocab_has_same_index(vocab, other_vocab):
+    """
+    Check that two vocabularies assign the same index to every word.
+
+    :param Vocabulary vocab:
+    :param Vocabulary other_vocab:
+    :return:
+    """
+    if other_vocab != vocab:
+        for word, word_ix in vocab:
+            other_word_idx = other_vocab.to_index(word)
+            assert other_word_idx == word_ix, f"Word {word} has different index in vocabs, {word_ix} Vs. {other_word_idx}."
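
This is what lets StackEmbedding accept vocabularies that compare unequal (for example, different word counts) as long as they index identically. A hedged sketch, assuming fastNLP's Vocabulary iterates as (word, index) pairs and assigns indices in insertion order for equal frequencies:

    from fastNLP import Vocabulary

    v1 = Vocabulary().add_word_lst(['a', 'b'])
    v2 = Vocabulary().add_word_lst(['a', 'b'])
    _check_vocab_has_same_index(v1, v2)  # passes: same words, same indices

    v3 = Vocabulary().add_word_lst(['b', 'a'])
    _check_vocab_has_same_index(v1, v3)  # raises AssertionError if insertion order changed the indices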
@@ -321,8 +321,15 @@ class DataBundle:
             2. is_target: bool, if True, the field named `new_field_name` is set as target
             3. ignore_type: bool, if True, ignore_type of the field named `new_field_name` is set to True, ignoring its type
+            4. use_tqdm: bool, whether to show a tqdm progress bar
+            5. tqdm_desc: str, when use_tqdm is True, the name shown for the current tqdm bar
         """
+        tqdm_desc = kwargs.get('tqdm_desc', '')
         for name, dataset in self.datasets.items():
+            if tqdm_desc != '':
+                kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
             if dataset.has_field(field_name=field_name):
                 dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs)
             elif not ignore_miss_dataset:
@@ -350,10 +357,17 @@ class DataBundle:
             3. ignore_type: bool, if True, ignore_type of the modified field is set to True, ignoring its type
+            4. use_tqdm: bool, whether to show a tqdm progress bar
+            5. tqdm_desc: str, when use_tqdm is True, the name shown for the current tqdm bar
         :return Dict[str:Dict[str:Field]]: a nested dict; the first-level key is the dataset name, the second-level key is the field name
         """
         res = {}
+        tqdm_desc = kwargs.get('tqdm_desc', '')
         for name, dataset in self.datasets.items():
+            if tqdm_desc != '':
+                kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
             if dataset.has_field(field_name=field_name):
                 res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs)
             elif not ignore_miss_dataset:
@@ -376,8 +390,16 @@ class DataBundle:
             2. is_target: bool, if True, the field named `new_field_name` is set as target
             3. ignore_type: bool, if True, ignore_type of the field named `new_field_name` is set to True, ignoring its type
+            4. use_tqdm: bool, whether to show a tqdm progress bar
+            5. tqdm_desc: str, when use_tqdm is True, the name shown for the current tqdm bar
         """
+        tqdm_desc = kwargs.get('tqdm_desc', '')
         for name, dataset in self.datasets.items():
+            if tqdm_desc != '':
+                kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
             dataset.apply(func, new_field_name=new_field_name, **kwargs)
         return self
@@ -399,10 +421,17 @@ class DataBundle:
             3. ignore_type: bool, if True, ignore_type of the modified field is set to True, ignoring its type
+            4. use_tqdm: bool, whether to show a tqdm progress bar
+            5. tqdm_desc: str, when use_tqdm is True, the name shown for the current tqdm bar
         :return Dict[str:Dict[str:Field]]: a nested dict; the first-level key is the dataset name, the second-level key is the field name
         """
         res = {}
+        tqdm_desc = kwargs.get('tqdm_desc', '')
         for name, dataset in self.datasets.items():
+            if tqdm_desc!='':
+                kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
             res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs)
         return res
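All four apply variants now share the same `tqdm_desc` handling: the caller's description is suffixed with the dataset name. A usage sketch with illustrative field names:

    # each dataset in the bundle gets its own labelled bar,
    # e.g. "tokenize for `train`", "tokenize for `dev`", ...
    data_bundle.apply_field(lambda raw: raw.split(), field_name='raw_words',
                            new_field_name='words', use_tqdm=True,
                            tqdm_desc='tokenize')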
@@ -259,8 +259,8 @@ def _get_base_url(name):
             return url + '/'
     else:
         URLS = {
-            'embedding': "http://212.129.155.247/embedding/",
-            "dataset": "http://212.129.155.247/dataset/"
+            'embedding': "http://download.fastnlp.top/embedding/",
+            "dataset": "http://download.fastnlp.top/dataset/"
         }
         if name.lower() not in URLS:
             raise KeyError(f"{name} is not recognized.")
@@ -312,7 +312,8 @@ def _read_extend_url_file(filename, name)->str:
             return parts[1]
     return None
-def _get_dataset_url(name):
+def _get_dataset_url(name, dataset_dir: dict = None):
     r"""
     Given the name of a dataset, return its download url
@@ -323,8 +324,9 @@ def _get_dataset_url(name):
     url = _read_extend_url_file(FASTNLP_EXTEND_DATASET_URL, name)
     if url:
         return url
-    filename = DATASET_DIR.get(name, None)
+    dataset_dir = DATASET_DIR if dataset_dir is None else dataset_dir
+    filename = dataset_dir.get(name, None)
     if filename:
         url = _get_base_url('dataset') + filename
         return url
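With the new `dataset_dir` parameter, an internal caller can resolve names against its own name-to-filename table instead of the module-level DATASET_DIR. A hedged sketch; the mapping below is invented for illustration:

    url = _get_dataset_url('my-corpus', dataset_dir={'my-corpus': 'my-corpus.zip'})
    # resolves to _get_base_url('dataset') + 'my-corpus.zip'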
@@ -12,6 +12,7 @@ import warnings
 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
 from ...core._logger import logger
+from pkg_resources import parse_version
 def iob2(tags: List[str]) -> List[str]:
@@ -82,7 +83,10 @@ def get_tokenizer(tokenize_method: str, lang='en'):
         spacy.prefer_gpu()
         if lang != 'en':
             raise RuntimeError("Spacy only supports en right now.")
-        en = spacy.load(lang)
+        if parse_version(spacy.__version__) >= parse_version('3.0'):
+            en = spacy.load('en_core_web_sm')
+        else:
+            en = spacy.load(lang)
         tokenizer = lambda x: [w.text for w in en.tokenizer(x)]
     elif tokenize_method in tokenizer_dict:
         tokenizer = tokenizer_dict[tokenize_method]
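The version gate is needed because spaCy 3.x removed shortcut names such as spacy.load('en'); models must be loaded by their package name. A hedged sketch of the resulting call, assuming en_core_web_sm is installed:

    tokenizer = get_tokenizer('spacy')  # branch chosen by spacy.__version__
    tokenizer("fastNLP supports spaCy tokenization.")
    # roughly ['fastNLP', 'supports', 'spaCy', 'tokenization', '.']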
@@ -11,7 +11,8 @@ __all__ = ['SequenceGeneratorModel']
 class SequenceGeneratorModel(nn.Module):
     """
-    Wraps a Seq2SeqModel so that it can be used for generation tasks
+    Wrapping a seq2seq_model with this model makes it usable both for training and for generation. During training,
+    this model's forward function is called; during generation, its predict function is called.
     """
@@ -46,7 +47,7 @@ class SequenceGeneratorModel(nn.Module):
     def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
         """
-        Pass-through call to seq2seq_model's forward
+        Pass-through call to seq2seq_model's forward.
         :param torch.LongTensor src_tokens: bsz x max_len
         :param torch.LongTensor tgt_tokens: bsz x max_len'
@@ -58,7 +59,7 @@ class SequenceGeneratorModel(nn.Module):
     def predict(self, src_tokens, src_seq_len=None):
         """
-        Given the source content, output the generated content
+        Given the source content, output the generated content.
         :param torch.LongTensor src_tokens: bsz x max_len
         :param torch.LongTensor src_seq_len: bsz
@@ -18,10 +18,16 @@ __all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel']
 class Seq2SeqModel(nn.Module):
     def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder):
         """
-        A Seq2Seq model that can be trained with Trainer. Normally, after inheriting this class you only need to implement the classmethod build_model.
-        :param encoder: Encoder
-        :param decoder: Decoder
+        A Seq2Seq model that can be trained with Trainer. Normally, after inheriting this class you only need to
+        implement the classmethod build_model. To use this model for generation, pass it into
+        :class:`~fastNLP.models.SequenceGeneratorModel`. In this model, forward() feeds the encoded result into the
+        decoder and returns the decoder's output.
+        :param encoder: a Seq2SeqEncoder object. It must implement a forward() function taking two arguments: the first
+            is bsz x max_len source tokens, the second is bsz source lengths. It must return two tensors:
+            encoder_outputs (bsz x max_len x hidden_size) and encoder_mask (bsz x max_len, positions marked 1 should be
+            attended). If the encoder's inputs or outputs differ, override this model's prepare_state() or forward().
+        :param decoder: a Seq2SeqDecoder object. It must implement an init_state() function taking two arguments: the
+            first is the encoder output (bsz x max_len x hidden_size); the second is the encoder mask (bsz x max_len,
+            where 0 marks padding). If the decoder needs more inputs, override this model's prepare_state() or forward().
         """
         super().__init__()
         self.encoder = encoder
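A hedged wiring sketch of the contract the new docstring describes; `my_encoder`/`my_decoder` are placeholders and the generator's extra arguments (bos/eos ids, beam size, ...) are omitted:

    model = Seq2SeqModel(encoder=my_encoder, decoder=my_decoder)
    output = model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)  # training path
    generator = SequenceGeneratorModel(model)                         # generation path
    prediction = generator.predict(src_tokens, src_seq_len)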
@@ -16,7 +16,7 @@ __all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder']
 class Seq2SeqDecoder(nn.Module):
     """
-    Base class for Sequence-to-Sequence decoders. forward must be implemented; the remaining functions are implemented as needed. Every Seq2SeqDecoder should have a corresponding State object
+    Base class for Sequence-to-Sequence decoders. forward and decode must be implemented; the remaining functions are implemented as needed. Every Seq2SeqDecoder should have a corresponding State object
     used to carry the encoder outputs the decoder needs and the history the decoder must keep (e.g. the LSTM hidden state).
     """
@@ -61,7 +61,7 @@ class Seq2SeqDecoder(nn.Module):
         """
         Continue generation based on the contents of states and of tokens.
-        :param torch.LongTensor tokens: bsz x max_len, the token output of the previous step.
+        :param torch.LongTensor tokens: bsz x max_len, all token outputs up to the previous step.
         :param State state: records the encoder output and the decoder's past state
         :return: torch.FloatTensor: bsz x vocab_size, the distribution for the next step
         """
@@ -184,21 +184,23 @@ class DistilBertEmbeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, input_ids, token_type_ids):
+    def forward(self, input_ids, token_type_ids, position_ids=None):
         r"""
         Parameters
         ----------
         input_ids: torch.tensor(bs, max_seq_length)
             The token ids to embed.
         token_type_ids: not used.
+        position_ids: optional pre-computed position ids; derived from the sequence length when None.
         Outputs
         -------
         embeddings: torch.tensor(bs, max_seq_length, dim)
             The embedded tokens (plus position embeddings, no token_type embeddings)
         """
         seq_length = input_ids.size(1)
-        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
-        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+        if position_ids is None:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
         word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
@@ -374,20 +376,18 @@ class BertEncoder(nn.Module):
         self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
         if self.num_output_layer + 1 < len(self.layer):
             logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
-                        f'(start from 0)!')
+                        f'(layer 0 means embedding layer)!')
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
         for idx, layer_module in enumerate(self.layer):
-            if idx > self.num_output_layer:
+            if idx >= self.num_output_layer:
                 break
             hidden_states = layer_module(hidden_states, attention_mask)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-        if len(all_encoder_layers) == 0:
-            all_encoder_layers.append(hidden_states)
         return all_encoder_layers
@@ -445,8 +445,8 @@ class BertModel(nn.Module):
         self.hidden_size = self.config.hidden_size
         self.model_type = 'bert'
         neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
-        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
-        self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
+        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers)
+        self.num_output_layer = max(neg_num_output_layer + 1 + self.config.num_hidden_layers, pos_num_output_layer)
         if hasattr(config, 'sinusoidal_pos_embds'):
             self.model_type = 'distilbert'
         elif 'model_type' in kwargs:
@@ -464,6 +464,24 @@ class BertModel(nn.Module):
             logger.info('DistilBert has no pooler, will use hidden states of [CLS] token as pooled output.')
         self.apply(self.init_bert_weights)
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
     def init_bert_weights(self, module):
         r""" Initialize the weights.
         """
@@ -477,7 +495,8 @@ class BertModel(nn.Module):
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True,
+                position_ids=None):
         """
         :param torch.LongTensor input_ids: bsz x max_len input ids
@@ -485,6 +504,7 @@ class BertModel(nn.Module):
         :param attention_mask: 1 for positions to attend, 0 otherwise
         :param bool output_all_encoded_layers: whether to output all layers. By default the token embedding (bpe, position and type embeddings combined)
             and every layer's hidden states are returned. If False, only the last layer's result is returned
+        :param torch.LongTensor position_ids: bsz x max_len, position ids
         :return: encode_layers: if output_all_encoded_layers is True, a list (num_layers+1 elements in total) whose elements are
             bsz x max_len x hidden_size; otherwise a bsz x max_len x hidden_size tensor;
             pooled_output: bsz x hidden_size, the representation of cls, usable for sentence classification
@@ -506,13 +526,16 @@ class BertModel(nn.Module):
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        # this will cause an issue with DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
+        # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = extended_attention_mask.to(self.dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-        embedding_output = self.embeddings(input_ids, token_type_ids)
+        embedding_output = self.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        encoded_layers.insert(0, embedding_output)
         sequence_output = encoded_layers[-1]
         if self.model_type != 'distilbert':
             pooled_output = self.pooler(sequence_output)
@@ -520,8 +543,6 @@ class BertModel(nn.Module):
             pooled_output = sequence_output[:, 0]
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
-        else:
-            encoded_layers.insert(0, embedding_output)
         return encoded_layers, pooled_output
     @classmethod
@@ -787,6 +787,24 @@ class GPT2Model(GPT2PreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
     def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None,
                 head_mask=None, output_attentions=True):
         """
@@ -834,7 +852,9 @@ class GPT2Model(GPT2PreTrainedModel):
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        # this will cause an issue with DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
+        # attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        attention_mask = attention_mask.to(self.dtype)
         attention_mask = (1.0 - attention_mask) * -10000.0
         # attention_mask = attention_mask.masked_fill(attention_mask.eq(0), -10000.0)
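Both models gain the identical `dtype` property because, under nn.DataParallel in PyTorch 1.5, replicas expose parameters as plain tensor attributes and `next(self.parameters())` can raise StopIteration (see the linked issue). A hedged sketch of what the property provides:

    model = BertModel(config).half()  # config construction elided
    mask = torch.ones(2, 8).to(model.dtype)  # torch.float16, safe on DataParallel replicas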
@@ -70,7 +70,7 @@ class LSTM(nn.Module):
             x = x[sort_idx]
         else:
             x = x[:, sort_idx]
-        x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first)
+        x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first)
         output, hx = self.lstm(x, hx)  # -> [N,L,C]
         output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
         _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
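The added `.cpu()` matters because recent PyTorch releases require the `lengths` argument of pack_padded_sequence to be a CPU int64 tensor; CUDA lengths raise a RuntimeError. A self-contained sketch:

    import torch
    from torch.nn.utils import rnn

    x = torch.randn(4, 7, 16)          # batch_first padded batch
    lens = torch.tensor([7, 5, 3, 2])  # lengths must live on the CPU
    packed = rnn.pack_padded_sequence(x, lens.cpu(), batch_first=True)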
@@ -39,7 +39,7 @@ class RobertaEmbeddings(BertEmbeddings):
             config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
         )
-    def forward(self, input_ids, token_type_ids, words_embeddings=None):
+    def forward(self, input_ids, token_type_ids, words_embeddings=None, **kwargs):
         position_ids = self.create_position_ids_from_input_ids(input_ids)
         return super().forward(
@@ -12,9 +12,11 @@ import torch.nn.functional as F
 from ...core.utils import _get_model_device
 from functools import partial
 class SequenceGenerator:
     """
-    Given a Seq2SeqDecoder, decode sentences
+    Given a Seq2SeqDecoder, decode sentences. The decoder object passed in must have a decode() function whose first
+    argument is everything decoded so far and whose second argument is the state. SequenceGenerator performs no
+    operation on the state.
     """
     def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1,
@@ -65,7 +67,8 @@ class SequenceGenerator:
         """
         :param State state: the State produced by the encoder, used together with the paired Decoder
-        :param torch.LongTensor,None tokens: batch_size x length, the starting tokens
+        :param torch.LongTensor,None tokens: batch_size x length, the starting tokens. If None, bos_token_id is added
+            as the starting token for generation by default.
         :return: bsz x max_length' the generated token sequences. If eos_token_id is not None, every sequence is guaranteed to end with eos_token_id
         """
@@ -168,6 +171,8 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
         _eos_token_id = eos_token_id
     scores = decoder.decode(tokens=tokens, state=state)  # mainly to update the state
+    if _eos_token_id!=-1:  # prevent eos at the first position
+        scores[:, _eos_token_id] = -1e12
     next_tokens = scores.argmax(dim=-1, keepdim=True)
     token_ids = torch.cat([tokens, next_tokens], dim=1)
     cur_len = token_ids.size(1)
@@ -261,6 +266,8 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         _eos_token_id = eos_token_id
     scores = decoder.decode(tokens=tokens, state=state)  # the full sentence length must be passed in here
+    if _eos_token_id!=-1:  # prevent eos at the first position
+        scores[:, _eos_token_id] = -1e12
     vocab_size = scores.size(1)
     assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."
@@ -321,7 +328,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             max_len_eos_mask = max_lengths.eq(cur_len+1)
             eos_scores = scores[:, _eos_token_id]
             # once the maximum length is reached, boost the eos score
-            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)
+            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e32, eos_scores)
         if do_sample:
             if temperature > 0 and temperature != 1:
@@ -355,9 +362,9 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         # Now assemble the results for the next batch.
         # Decide which candidates to keep
-        next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
-        next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
-        from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)
+        # next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
+        # next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
+        # from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)
         not_eos_mask = next_tokens.ne(_eos_token_id)  # positions marked 1 are not eos
         keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)  # positions marked 1 are kept
@@ -412,7 +419,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             break
     # select the best hypotheses
-    tgt_len = token_ids.new(batch_size)
+    tgt_len = token_ids.new_zeros(batch_size)
     best = []
     for i, hypotheses in enumerate(hypos):
@@ -424,7 +431,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         best.append(best_hyp)
     # generate target batch
-    decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
+    decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id)
     for i, hypo in enumerate(best):
         decoded[i, :tgt_len[i]] = hypo
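Tensor.new(...) allocates uninitialized memory, so any slot not overwritten later holds garbage, while new_zeros(...) is deterministic (here `decoded` is immediately filled with pad_token_id, so the switch is mainly defensive). A tiny sketch of the difference:

    t = torch.empty(2, 2)
    t.new(3)        # arbitrary uninitialized values
    t.new_zeros(3)  # tensor([0., 0., 0.])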
@@ -3,6 +3,5 @@ torch>=1.0.0
 tqdm>=4.28.1
 prettytable>=0.7.2
 requests
-spacy
 prettytable>=0.7.2
 regex!=2019.12.17
@@ -445,7 +445,7 @@ class TestCase1(unittest.TestCase):
             sample_count = 0
             for batch_x, batch_y in data_iter:
                 sample_count += len(batch_x['seq_len'])
-                self.assertTrue(sum(batch_x['seq_len'])<120)
+                self.assertTrue(sum(batch_x['seq_len'])<=120)
             self.assertEqual(sample_count, num_samples)
         """
@@ -136,6 +136,14 @@ class TestDataSetMethods(unittest.TestCase):
         ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True)
         # expect no exception raised
+    def test_apply_tqdm(self):
+        import time
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
+        def do_nothing(ins):
+            time.sleep(0.01)
+        ds.apply(do_nothing, use_tqdm=True)
+        ds.apply_field(do_nothing, field_name='x', use_tqdm=True)
     def test_apply_cannot_modify_instance(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         def modify_inplace(instance):
@@ -268,6 +276,74 @@ class TestDataSetMethods(unittest.TestCase):
         with self.assertRaises(RuntimeError) as RE:
             ds.add_field('test', [])
+    def test_concat(self):
+        """
+        Test whether two datasets concat correctly
+        """
+        ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
+        ds2 = DataSet({"x": [[4,3,2,1] for i in range(10)], "y": [[6,5] for i in range(10)]})
+        ds3 = ds1.concat(ds2)
+        self.assertEqual(len(ds3), 20)
+        self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4])
+        self.assertListEqual(ds1[10]['x'], [4,3,2,1])
+        ds2[0]['x'][0] = 100
+        self.assertEqual(ds3[10]['x'][0], 4)  # the copied field is not changed
+        ds3[10]['x'][0] = -100
+        self.assertEqual(ds2[0]['x'][0], 100)  # the original field is not changed
+        # test inplace
+        ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
+        ds2 = DataSet({"x": [[4, 3, 2, 1] for i in range(10)], "y": [[6, 5] for i in range(10)]})
+        ds3 = ds1.concat(ds2, inplace=True)
+        ds2[0]['x'][0] = 100
+        self.assertEqual(ds3[10]['x'][0], 4)  # the copied field is not changed
+        ds3[10]['x'][0] = -100
+        self.assertEqual(ds2[0]['x'][0], 100)  # the original field is not changed
+        ds3[0]['x'][0] = 100
+        self.assertEqual(ds1[0]['x'][0], 100)  # the original field is changed (inplace)
+        # test field mapping
+        ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
+        ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]})
+        ds3 = ds1.concat(ds2, field_mapping={'X':'x', 'Y':'y'})
+        self.assertEqual(len(ds3), 20)
+        # test that extra fields are ignored
+        ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
+        ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)], 'Z':[0]*10})
+        ds3 = ds1.concat(ds2, field_mapping={'X':'x', 'Y':'y'})
+        # test error raising
+        ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
+        ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]})
+        with self.assertRaises(RuntimeError):
+            ds3 = ds1.concat(ds2, field_mapping={'X':'x'})
+    def test_no_padder(self):
+        ds = DataSet()
+        ds.add_field('idx', [1, 2, 3], padder=None)
+        self.assertEqual(ds['idx'].padder, None)  # should be None, but AutoPadder
+    def test_copy_padder(self):
+        from fastNLP.core.field import AutoPadder
+        ds = DataSet()
+        ds.add_field('idx', [1, 2, 3])
+        ds['idx'].set_padder(None)  # workaround of problem 1
+        ds.apply_field(lambda x: x, 'idx', 'idx')
+        self.assertEqual(ds['idx'].padder, None)  # should be None, but AutoPadder
+        ds = DataSet()
+        ds.add_field('idx', [1, 2, 3])
+        ds.apply_field(lambda x: x, 'idx', 'idx')
+        self.assertTrue(isinstance(ds.get_field('idx').padder, AutoPadder))  # should be None, but AutoPadder
 class TestDataSetIter(unittest.TestCase):
     def test__repr__(self):
@@ -1,4 +1,7 @@
 import os
+# have to add this, otherwise cannot import fastNLP when check_call()
+import sys
+sys.path.append(os.sep.join(os.path.abspath(__file__).split(os.sep)[:-3]))
 import shutil
 import subprocess
 import unittest
@@ -6,13 +9,14 @@ from argparse import ArgumentParser
 import numpy as np
 import torch.cuda
+import torch.distributed as dist
 from fastNLP import AccuracyMetric
 from fastNLP import CrossEntropyLoss, BCELoss
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SGD
-from fastNLP.core.callback import EchoCallback
+from fastNLP.core.callback import EchoCallback, GradientClipCallback
 from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
 from fastNLP.models.base_model import NaiveClassifier
@@ -103,7 +107,7 @@ class TestDistTrainer(unittest.TestCase):
             model=model, train_data=data_set, optimizer=SGD(lr=0.1),
             loss=CrossEntropyLoss(pred="predict", target="y"),
             batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path,
-            fp16='O1'
+            fp16=True
         )
         trainer.train()
         """
@@ -113,18 +117,20 @@ class TestDistTrainer(unittest.TestCase):
             shutil.rmtree(self.save_path)
     def run3(self):
+        # test callbacks, especially clip-norm
         set_rng_seed(100)
         data_set, model = prepare_env()
         trainer = DistTrainer(
             data_set, model, optimizer=None,
             loss=BCELoss(pred="predict", target="y"),
             n_epochs=3, print_every=50,
-            callbacks_all=[EchoCallback('callbacks_all')],
+            callbacks_all=[GradientClipCallback()],
             callbacks_master=[EchoCallback('callbacks_master')]
         )
         trainer.train()
     def run4(self):
+        # test metrics, save, and others
         set_rng_seed(100)
         data_set, model = prepare_env()
@@ -173,4 +179,5 @@ if __name__ == '__main__':
     parser.add_argument('--test', type=int)
     args, _ = parser.parse_known_args()
     if args.test and hasattr(runner, 'run%s' % args.test):
+        dist.init_process_group("nccl")
         getattr(runner, 'run%s' % args.test)()
@@ -44,3 +44,11 @@ class TestSampler(unittest.TestCase):
         indices = sampler(data_set)
         self.assertEqual(len(indices), 10)
         # just needs to run; the effect is not verified
+    def test_ConstantTokenNumSampler(self):
+        # what to check: whether the token counts per batch are close
+        pass
+    def test_ConstTokenNumSampler(self):
+        # what to check: whether it can run directly
+        pass
@@ -9,13 +9,17 @@ import torch
 from fastNLP import DataSet
 from fastNLP import Instance
-from fastNLP import BCELoss
+from fastNLP import BCELoss, BCEWithLogits
 from fastNLP import CrossEntropyLoss
 from fastNLP import AccuracyMetric
 from fastNLP import SGD
 from fastNLP import Trainer
 from fastNLP.models.base_model import NaiveClassifier
 from fastNLP import TorchLoaderIter
+from fastNLP.models import BaseModel
+from fastNLP.modules import MLP
+from pkg_resources import parse_version
 def prepare_fake_dataset():
@@ -575,3 +579,148 @@ class TrainerTestGround(unittest.TestCase):
         )
         trainer.train()
         """
+class NaiveClassifier2(BaseModel):
+    r"""
+    A simple classifier example, usable in all kinds of tests
+    """
+    def __init__(self, in_feature_dim, out_feature_dim):
+        super(NaiveClassifier2, self).__init__()
+        self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
+    def forward(self, x):
+        return {"predict": self.mlp(x)}
+    def predict(self, x):
+        return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
+class Fp16TrainerTest(unittest.TestCase):
+    def test_raise_error(self):
+        data_set = prepare_fake_dataset()
+        data_set.set_input("x", flag=True)
+        data_set.set_target("y", flag=True)
+        train_set, dev_set = data_set.split(0.3)
+        model = NaiveClassifier2(2, 1)
+        with self.assertRaises(RuntimeError):
+            trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
+                              batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                              metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+                              use_tqdm=True, check_code_level=2, fp16=True)
+        with self.assertRaises(RuntimeError):
+            trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
+                              batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                              metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+                              use_tqdm=True, check_code_level=2, fp16=True, device='cpu')
+        with self.assertRaises(RuntimeError):
+            trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
+                              batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                              metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+                              use_tqdm=True, check_code_level=2, fp16=True, device=torch.device('cpu'))
+    @unittest.skipIf(torch.cuda.is_available()==False or parse_version(torch.__version__) < parse_version('1.6'), "Skip when no cuda device detected")
+    def test_run_fp16(self):
+        data_set = prepare_fake_dataset()
+        data_set.set_input("x", flag=True)
+        data_set.set_target("y", flag=True)
+        train_set, dev_set = data_set.split(0.3)
+        model = NaiveClassifier2(2, 1)
+        trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
+                          batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+                          use_tqdm=True, check_code_level=2, fp16=True, device=0)
+        trainer.train(load_best_model=False)
+        model = NaiveClassifier2(2, 1)
+        trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
+                          batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+                          use_tqdm=True, check_code_level=2, fp16=True, device=0, test_use_fp16=False)
+        trainer.train(load_best_model=False)
+    @unittest.skipIf(torch.cuda.device_count()<2 or parse_version(torch.__version__) < parse_version('1.6'), "Skip when fewer than 2 gpus.")
+    def test_run_data_parallel(self):
+        data_set = prepare_fake_dataset()
+        data_set.set_input("x", flag=True)
+        data_set.set_target("y", flag=True)
+        train_set, dev_set = data_set.split(0.3)
+        class NaiveClassifier2(BaseModel):
+            r"""
+            A simple classifier example, usable in all kinds of tests
+            """
+            def __init__(self, in_feature_dim, out_feature_dim):
+                super(NaiveClassifier2, self).__init__()
+                self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
+            def forward(self, x):
+                return {"predict": self.mlp(x)}
+            def predict(self, x):
+                return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
+        model = NaiveClassifier2(2, 1)
+        with self.assertRaises(RuntimeError):
+            trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
+                              batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                              metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+                              use_tqdm=True, check_code_level=2, fp16=True, device=[0, 1])
+        with self.assertRaises(RuntimeError):
+            class NaiveClassifier3(BaseModel):
+                r"""
+                A simple classifier example, usable in all kinds of tests
+                """
+                def __init__(self, in_feature_dim, out_feature_dim):
+                    super(NaiveClassifier3, self).__init__()
+                    self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
+                @torch.cuda.amp.autocast()
+                def forward(self, x):
+                    return {"predict": self.mlp(x)}
+                @torch.cuda.amp.autocast()
+                def predict(self, x):
+                    return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
+            model = NaiveClassifier3(2, 1)
+            trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
+                              batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                              metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+                              use_tqdm=True, check_code_level=2, fp16=True, device=[0, 1], test_use_fp16=True)
+        class NaiveClassifier4(BaseModel):
+            r"""
+            A simple classifier example, usable in all kinds of tests
+            """
+            def __init__(self, in_feature_dim, out_feature_dim):
+                super(NaiveClassifier4, self).__init__()
+                self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
+            def forward(self, x):
+                with torch.cuda.amp.autocast():
+                    return {"predict": self.mlp(x)}
+            def predict(self, x):
+                with torch.cuda.amp.autocast():
+                    return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
+        model = NaiveClassifier4(2, 1)
+        trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
+                          batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+                          use_tqdm=True, check_code_level=2, fp16=True, device=[0, 1], test_use_fp16=True)
+        trainer.train(load_best_model=False)
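The multi-GPU cases track a torch.cuda.amp caveat: nn.DataParallel runs forward in per-device side threads and autocast state does not propagate into them, so autocast has to be entered inside the model's own forward. The tests above expect the decorator form to raise during Trainer construction, while the with-block form (NaiveClassifier4) trains. A condensed sketch of the working pattern:

    class AmpSafeModel(nn.Module):
        def forward(self, x):
            with torch.cuda.amp.autocast():
                # autocast is (re-)entered inside every DataParallel worker thread
                return self.net(x)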
@@ -31,29 +31,33 @@ class TestDownload(unittest.TestCase):
 class TestBertEmbedding(unittest.TestCase):
     def test_bert_embedding_1(self):
-        vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1)
-        requires_grad = embed.requires_grad
-        embed.requires_grad = not requires_grad
-        embed.train()
-        words = torch.LongTensor([[2, 3, 4, 0]])
-        result = embed(words)
-        self.assertEqual(result.size(), (1, 4, 16))
-        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1)
-        embed.eval()
-        words = torch.LongTensor([[2, 3, 4, 0]])
-        result = embed(words)
-        self.assertEqual(result.size(), (1, 4, 16))
-        # auto-truncate without raising
-        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
-                              auto_truncate=True)
-        words = torch.LongTensor([[2, 3, 4, 1]*10,
-                                  [2, 3]+[0]*38])
-        result = embed(words)
-        self.assertEqual(result.size(), (2, 40, 16))
+        for pool_method in ['first', 'last', 'max', 'avg']:
+            with self.subTest(pool_method=pool_method):
+                vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
+                embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
+                                      pool_method=pool_method)
+                requires_grad = embed.requires_grad
+                embed.requires_grad = not requires_grad
+                embed.train()
+                words = torch.LongTensor([[2, 3, 4, 0]])
+                result = embed(words)
+                self.assertEqual(result.size(), (1, 4, 16))
+                embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
+                                      pool_method=pool_method)
+                embed.eval()
+                words = torch.LongTensor([[2, 3, 4, 0]])
+                result = embed(words)
+                self.assertEqual(result.size(), (1, 4, 16))
+                # auto-truncate without raising
+                embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
+                                      auto_truncate=True, pool_method=pool_method)
+                words = torch.LongTensor([[2, 3, 4, 1]*10,
+                                          [2, 3]+[0]*38])
+                result = embed(words)
+                self.assertEqual(result.size(), (2, 40, 16))
     def test_save_load(self):
         bert_save_test = 'bert_save_test'
@@ -18,3 +18,16 @@ class TestCharEmbed(unittest.TestCase):
         y = embed(x)
         self.assertEqual(tuple(y.size()), (2, 3, 130))
+    def test_case_2(self):
+        # test that embeddings can be stacked as long as the vocabs share the same indices
+        ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['hello', 'Jack'])])
+        vocab1 = Vocabulary().from_dataset(ds, field_name='words')
+        vocab2 = Vocabulary().from_dataset(ds, field_name='words')
+        self.assertEqual(len(vocab1), 5)
+        cnn_embed = CNNCharEmbedding(vocab1, embed_size=60)
+        lstm_embed = LSTMCharEmbedding(vocab2, embed_size=70)
+        embed = StackEmbedding([cnn_embed, lstm_embed])
+        x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
+        y = embed(x)
+        self.assertEqual(tuple(y.size()), (2, 3, 130))
@@ -74,6 +74,7 @@ class TestRunMatchingPipe(unittest.TestCase):
             name, vocabs = y
             self.assertEqual(x + 1 if name == 'words' else x, len(vocabs))
+    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_spacy(self):
         data_set_dict = {
             'Quora': ('tests/data_for_tests/io/Quora', QuoraPipe, QuoraBertPipe, (2, 2, 2), (93, 2)),