@@ -11,25 +11,14 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 1 ,
"execution_count": 2 ,
"metadata": {},
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n"
]
}
],
"outputs": [],
"source": [
"source": [
"from fastNLP.io import SST2Pipe\n",
"from fastNLP.io import SST2Pipe\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
"from fastNLP.models import CNNText\n",
"from fastNLP.models import CNNText\n",
"from fastNLP import CrossEntropyLoss\n",
"import torch\n",
"import torch\n",
"from torch.optim import Adam\n",
"from fastNLP import AccuracyMetric\n",
"\n",
"\n",
"databundle = SST2Pipe().process_from_file()\n",
"databundle = SST2Pipe().process_from_file()\n",
"vocab = databundle.get_vocab('words')\n",
"vocab = databundle.get_vocab('words')\n",
@@ -40,7 +29,6 @@
"model = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n",
"model = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n",
"loss = CrossEntropyLoss()\n",
"loss = CrossEntropyLoss()\n",
"metric = AccuracyMetric()\n",
"metric = AccuracyMetric()\n",
"optimizer = Adam(model.parameters(), lr=0.001)\n",
"device = 0 if torch.cuda.is_available() else 'cpu'"
"device = 0 if torch.cuda.is_available() else 'cpu'"
]
]
},
},
@@ -53,7 +41,7 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 2 ,
"execution_count": 3 ,
"metadata": {
"metadata": {
"scrolled": true
"scrolled": true
},
},
@@ -63,12 +51,12 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"input fields after batch(if batch size is 2):\n",
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13 ]) \n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4 ]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"\n",
"training epochs started 2020-02-28-00-11-51 \n"
"training epochs started 2020-02-28-00-37-08 \n"
]
]
},
},
{
{
@@ -104,11 +92,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.28 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"\r",
"AccuracyMetric: acc=0.722 477\n",
"AccuracyMetric: acc=0.747706 \n",
"\n"
"\n"
]
]
},
},
@@ -131,11 +119,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.36 seconds!\n",
"Evaluate data in 0.17 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"\r",
"AccuracyMetric: acc=0.762615 \n",
"AccuracyMetric: acc=0.745413 \n",
"\n"
"\n"
]
]
},
},
@@ -158,11 +146,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.19 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"\r",
"AccuracyMetric: acc=0.771789 \n",
"AccuracyMetric: acc=0.74656 \n",
"\n"
"\n"
]
]
},
},
@@ -185,11 +173,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.44 seconds!\n",
"Evaluate data in 0.15 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"\r",
"AccuracyMetric: acc=0.759174 \n",
"AccuracyMetric: acc=0.762615 \n",
"\n"
"\n"
]
]
},
},
@@ -212,11 +200,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.29 seconds!\n",
"Evaluate data in 0.4 2 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"\r",
"AccuracyMetric: acc=0.75344 \n",
"AccuracyMetric: acc=0.736239 \n",
"\n"
"\n"
]
]
},
},
@@ -239,11 +227,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.33 seconds!\n",
"Evaluate data in 0.16 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"\r",
"AccuracyMetric: acc=0.75 \n",
"AccuracyMetric: acc=0.761468 \n",
"\n"
"\n"
]
]
},
},
@@ -266,11 +254,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.19 seconds!\n",
"Evaluate data in 0.42 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"\r",
"AccuracyMetric: acc=0.74197 2\n",
"AccuracyMetric: acc=0.727064 \n",
"\n"
"\n"
]
]
},
},
@@ -293,11 +281,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.49 seconds!\n",
"Evaluate data in 0.21 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"\r",
"AccuracyMetric: acc=0.740826 \n",
"AccuracyMetric: acc=0.731651 \n",
"\n"
"\n"
]
]
},
},
@@ -320,11 +308,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.1 5 seconds!\n",
"Evaluate data in 0.52 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"\r",
"AccuracyMetric: acc=0.75\n",
"AccuracyMetric: acc=0.752294 \n",
"\n"
"\n"
]
]
},
},
@@ -347,36 +335,35 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.44 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"\r",
"AccuracyMetric: acc=0.752294 \n",
"AccuracyMetric: acc=0.760321 \n",
"\n",
"\n",
"\r\n",
"\r\n",
"In Epoch:3/Step:462 , got best dev performance:\n",
"AccuracyMetric: acc=0.771789 \n",
"In Epoch:4/Step:616 , got best dev performance:\n",
"AccuracyMetric: acc=0.762615 \n",
"Reloaded the best model.\n"
"Reloaded the best model.\n"
]
]
},
},
{
{
"data": {
"data": {
"text/plain": [
"text/plain": [
"{'best_eval': {'AccuracyMetric': {'acc': 0.771789 }},\n",
" 'best_epoch': 3 ,\n",
" 'best_step': 462 ,\n",
" 'seconds': 30.04 }"
"{'best_eval': {'AccuracyMetric': {'acc': 0.762615 }},\n",
" 'best_epoch': 4 ,\n",
" 'best_step': 616 ,\n",
" 'seconds': 32.63 }"
]
]
},
},
"execution_count": 2 ,
"execution_count": 3 ,
"metadata": {},
"metadata": {},
"output_type": "execute_result"
"output_type": "execute_result"
}
}
],
],
"source": [
"source": [
"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
" metrics=metric, device=device)\n",
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=metric)\n",
"trainer.train()"
"trainer.train()"
]
]
},
},
@@ -432,7 +419,7 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 3 ,
"execution_count": 4 ,
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
@@ -464,7 +451,7 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 4 ,
"execution_count": 5 ,
"metadata": {
"metadata": {
"scrolled": true
"scrolled": true
},
},
@@ -474,12 +461,12 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"input fields after batch(if batch size is 2):\n",
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13 ]) \n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4 ]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"\n",
"training epochs started 2020-02-28-00-12-2 1\n"
"training epochs started 2020-02-28-00-37-4 1\n"
]
]
},
},
{
{
@@ -515,11 +502,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.33 seconds!\n",
"Evaluate data in 0.27 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7419724770642202 \n",
"AccMetric: acc=0.7431192660550459 \n",
"\n"
"\n"
]
]
},
},
@@ -542,11 +529,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.19 seconds!\n",
"Evaluate data in 0.42 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7660550458715596 \n",
"AccMetric: acc=0.7522935779816514 \n",
"\n"
"\n"
]
]
},
},
@@ -569,11 +556,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.27 seconds!\n",
"Evaluate data in 0.51 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.75 \n",
"AccMetric: acc=0.7477064220183486 \n",
"\n"
"\n"
]
]
},
},
@@ -596,11 +583,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.2 4 seconds!\n",
"Evaluate data in 0.48 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7534403669724771 \n",
"AccMetric: acc=0.7442660550458715 \n",
"\n"
"\n"
]
]
},
},
@@ -623,11 +610,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.29 seconds!\n",
"Evaluate data in 0.5 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7488532110091743 \n",
"AccMetric: acc=0.7362385321100917 \n",
"\n"
"\n"
]
]
},
},
@@ -650,11 +637,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.1 4 seconds!\n",
"Evaluate data in 0.45 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7488532110091743 \n",
"AccMetric: acc=0.7293577981651376 \n",
"\n"
"\n"
]
]
},
},
@@ -677,11 +664,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.27 seconds!\n",
"Evaluate data in 0.33 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7568807339449541 \n",
"AccMetric: acc=0.7190366972477065 \n",
"\n"
"\n"
]
]
},
},
@@ -704,11 +691,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.4 2 seconds!\n",
"Evaluate data in 0.29 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7488532110091743 \n",
"AccMetric: acc=0.7419724770642202 \n",
"\n"
"\n"
]
]
},
},
@@ -731,11 +718,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.34 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7408256880733946 \n",
"AccMetric: acc=0.7350917431192661 \n",
"\n"
"\n"
]
]
},
},
@@ -758,36 +745,35 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.2 8 seconds!\n",
"Evaluate data in 0.1 8 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7408256880733946 \n",
"AccMetric: acc=0.6846330275229358 \n",
"\n",
"\n",
"\r\n",
"\r\n",
"In Epoch:2/Step:308, got best dev performance:\n",
"In Epoch:2/Step:308, got best dev performance:\n",
"AccMetric: acc=0.7660550458715596 \n",
"AccMetric: acc=0.7522935779816514 \n",
"Reloaded the best model.\n"
"Reloaded the best model.\n"
]
]
},
},
{
{
"data": {
"data": {
"text/plain": [
"text/plain": [
"{'best_eval': {'AccMetric': {'acc': 0.7660550458715596 }},\n",
"{'best_eval': {'AccMetric': {'acc': 0.7522935779816514 }},\n",
" 'best_epoch': 2,\n",
" 'best_epoch': 2,\n",
" 'best_step': 308,\n",
" 'best_step': 308,\n",
" 'seconds': 29 .74 }"
" 'seconds': 4 2.7}"
]
]
},
},
"execution_count": 4 ,
"execution_count": 5 ,
"metadata": {},
"metadata": {},
"output_type": "execute_result"
"output_type": "execute_result"
}
}
],
],
"source": [
"source": [
"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
" metrics=AccMetric(), device=device)\n",
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=AccMetric())\n",
"trainer.train()"
"trainer.train()"
]
]
},
},
@@ -802,7 +788,7 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 5 ,
"execution_count": 6 ,
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
@@ -841,7 +827,7 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 6 ,
"execution_count": 7 ,
"metadata": {
"metadata": {
"scrolled": true
"scrolled": true
},
},
@@ -851,12 +837,12 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"input fields after batch(if batch size is 2):\n",
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13 ]) \n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4 ]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"\n",
"training epochs started 2020-02-28-00-12-51 \n"
"training epochs started 2020-02-28-00-38-24 \n"
]
]
},
},
{
{
@@ -892,11 +878,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.24 seconds!\n",
"Evaluate data in 0.3 2 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.754587155963302 7\n",
"AccMetric: acc=0.751146788990825 7\n",
"\n"
"\n"
]
]
},
},
@@ -919,11 +905,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.24 seconds!\n",
"Evaluate data in 0.29 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7534403669724771 \n",
"AccMetric: acc=0.7454128440366973 \n",
"\n"
"\n"
]
]
},
},
@@ -946,11 +932,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.18 seconds!\n",
"Evaluate data in 0.42 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.755733944954128 5\n",
"AccMetric: acc=0.722477064220183 5\n",
"\n"
"\n"
]
]
},
},
@@ -973,11 +959,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.11 seconds!\n",
"Evaluate data in 0.4 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7511467889908257 \n",
"AccMetric: acc=0.7534403669724771 \n",
"\n"
"\n"
]
]
},
},
@@ -1000,11 +986,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.19 seconds!\n",
"Evaluate data in 0.4 1 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7465596330275229 \n",
"AccMetric: acc=0.7396788990825688 \n",
"\n"
"\n"
]
]
},
},
@@ -1027,11 +1013,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.14 seconds!\n",
"Evaluate data in 0.22 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7454128440366973 \n",
"AccMetric: acc=0.7442660550458715 \n",
"\n"
"\n"
]
]
},
},
@@ -1054,11 +1040,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.43 seconds!\n",
"Evaluate data in 0.45 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7488532110091743 \n",
"AccMetric: acc=0.6903669724770642 \n",
"\n"
"\n"
]
]
},
},
@@ -1081,11 +1067,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.21 seconds!\n",
"Evaluate data in 0.25 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7431192660550459 \n",
"AccMetric: acc=0.7293577981651376 \n",
"\n"
"\n"
]
]
},
},
@@ -1108,11 +1094,11 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.1 seconds!\n",
"Evaluate data in 0.4 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7477064220183486 \n",
"AccMetric: acc=0.7006880733944955 \n",
"\n"
"\n"
]
]
},
},
@@ -1135,39 +1121,59 @@
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"\r",
"\r",
"Evaluate data in 0.29 seconds!\n",
"Evaluate data in 0.48 seconds!\n",
"\r",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"\r",
"AccMetric: acc=0.7465596330275229 \n",
"AccMetric: acc=0.7339449541284404 \n",
"\n",
"\n",
"\r\n",
"\r\n",
"In Epoch:3/Step:462 , got best dev performance:\n",
"AccMetric: acc=0.7557339449541285 \n",
"In Epoch:4/Step:616 , got best dev performance:\n",
"AccMetric: acc=0.7534403669724771 \n",
"Reloaded the best model.\n"
"Reloaded the best model.\n"
]
]
},
},
{
{
"data": {
"data": {
"text/plain": [
"text/plain": [
"{'best_eval': {'AccMetric': {'acc': 0.7557339449541285 }},\n",
" 'best_epoch': 3 ,\n",
" 'best_step': 462 ,\n",
" 'seconds': 28.68 }"
"{'best_eval': {'AccMetric': {'acc': 0.7534403669724771 }},\n",
" 'best_epoch': 4 ,\n",
" 'best_step': 616 ,\n",
" 'seconds': 34.74 }"
]
]
},
},
"execution_count": 6 ,
"execution_count": 7 ,
"metadata": {},
"metadata": {},
"output_type": "execute_result"
"output_type": "execute_result"
}
}
],
],
"source": [
"source": [
"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
" metrics=AccMetric(pred=\"pred\", target=\"target\"), device=device)\n",
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=AccMetric())\n",
"trainer.train()"
"trainer.train()"
]
]
},
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"``MetricBase`` 将会在输入的字典 ``pred_dict`` 和 ``target_dict`` 中进行检查.\n",
"``pred_dict`` 是模型当中 ``forward()`` 函数或者 ``predict()`` 函数的返回值.\n",
"``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的 ``is_target`` 被设置为True.\n",
"\n",
"``MetricBase`` 会进行以下的类型检测:\n",
"\n",
"1. ``self.evaluate`` 当中是否有 varargs, 这是不支持的.\n",
"2. ``self.evaluate`` 当中所需要的参数是否既不在 ``pred_dict`` 也不在 ``target_dict`` (即参数缺失).\n",
"3. ``self.evaluate`` 当中所需要的参数是否既在 ``pred_dict`` 也在 ``target_dict`` (即参数重复, 存在歧义).\n",
"\n",
"除此以外, 在参数被传入 ``self.evaluate`` 以前, ``MetricBase`` 会检测 ``pred_dict`` 和 ``target_dict`` 当中没有被用到的参数;\n",
"如果 kwargs 是 ``self.evaluate`` 的参数, 则不会进行这项检测.\n",
"\n",
"``self.evaluate`` 将计算一个批次(batch)的评价指标, 并累计; 该函数没有返回值.\n",
"``self.get_metric`` 将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称, value是指标的值.\n"
]
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": null,
"execution_count": null,