
Simplified the code of tutorial_7

tags/v0.5.5
ChenXin 5 years ago
commit 9cb7cdb532
2 changed files with 121 additions and 116 deletions
  1. +3 -4 docs/source/tutorials/tutorial_7_metrics.rst
  2. +118 -112 tutorials/tutorial_7_metrics.ipynb

+3 -4 docs/source/tutorials/tutorial_7_metrics.rst

@@ -7,9 +7,8 @@
 
 .. code-block:: python
 
-    trainer = Trainer(train_data=train_data, model=model, loss=loss,
-                      optimizer=optimizer, batch_size=32, dev_data=dev_data,
-                      metrics=metric, device=device)
+    trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,
+                      loss=loss, device=device, metrics=metric)
     trainer.train()
 
 Besides :class:`~fastNLP.AccuracyMetric`, :class:`~fastNLP.SpanFPreRecMetric` is also a very common evaluation metric,
@@ -89,7 +88,7 @@
             super().__init__()
 
             # If these names are not registered, the behavior is the same as in Version 1
-            self._init_param_map(pred=pred, target=target)  # This method registers label and pred. Only the parameter names used by evaluate() need to be registered
+            self._init_param_map(pred=pred, target=target)  # This method registers pred and target. Only the parameter names used by evaluate() need to be registered
 
             # Customize the statistics for your own metric
             self.total = 0
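The comment fixed in the hunk above belongs to the tutorial's custom-metric example. For context, a minimal sketch of such a metric, assuming fastNLP 0.5's MetricBase API as used in the tutorial (the class name AccMetric, the accumulator fields, and the argmax step for score-shaped pred are illustrative, not part of this commit):

    from fastNLP import MetricBase

    class AccMetric(MetricBase):
        def __init__(self, pred=None, target=None):
            super().__init__()
            # Register only the parameter names that evaluate() uses;
            # without this call the behavior is the same as Version 1.
            self._init_param_map(pred=pred, target=target)
            # Accumulators for the custom statistic
            self.total = 0
            self.acc_count = 0

        def evaluate(self, pred, target):
            # Called once per batch: accumulate statistics, return nothing.
            # Assumes pred holds class scores of shape [batch, n_classes].
            self.total += target.size(0)
            self.acc_count += (pred.argmax(dim=-1) == target).sum().item()

        def get_metric(self, reset=True):
            # Called after evaluation: must return a dict of metric values.
            acc = self.acc_count / (self.total + 1e-12)
            if reset:
                self.total = 0
                self.acc_count = 0
            return {'acc': acc}

Because _init_param_map registers the names that evaluate() expects, Trainer(..., metrics=AccMetric()) can look pred up in the model's forward output and target among the DataSet's target fields, which is why the explicit AccMetric(pred="pred", target="target") call is dropped in the notebook diff below.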


+118 -112 tutorials/tutorial_7_metrics.ipynb

@@ -11,25 +11,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 2,
 "metadata": {},
-"outputs": [
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n"
-]
-}
-],
+"outputs": [],
 "source": [
 "from fastNLP.io import SST2Pipe\n",
 "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
 "from fastNLP.models import CNNText\n",
-"from fastNLP import CrossEntropyLoss\n",
 "import torch\n",
-"from torch.optim import Adam\n",
-"from fastNLP import AccuracyMetric\n",
 "\n",
 "databundle = SST2Pipe().process_from_file()\n",
 "vocab = databundle.get_vocab('words')\n",
@@ -40,7 +29,6 @@
 "model = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n",
 "loss = CrossEntropyLoss()\n",
 "metric = AccuracyMetric()\n",
-"optimizer = Adam(model.parameters(), lr=0.001)\n",
 "device = 0 if torch.cuda.is_available() else 'cpu'"
 ]
 },
@@ -53,7 +41,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 3,
 "metadata": {
 "scrolled": true
 },
@@ -63,12 +51,12 @@
 "output_type": "stream",
 "text": [
 "input fields after batch(if batch size is 2):\n",
-"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n",
+"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
 "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
 "target fields after batch(if batch size is 2):\n",
 "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
 "\n",
-"training epochs started 2020-02-28-00-11-51\n"
+"training epochs started 2020-02-28-00-37-08\n"
 ]
 },
 {
@@ -104,11 +92,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.16 seconds!\n",
+"Evaluate data in 0.28 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
 "\r",
-"AccuracyMetric: acc=0.722477\n",
+"AccuracyMetric: acc=0.747706\n",
 "\n"
 ]
 },
@@ -131,11 +119,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.36 seconds!\n",
+"Evaluate data in 0.17 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
 "\r",
-"AccuracyMetric: acc=0.762615\n",
+"AccuracyMetric: acc=0.745413\n",
 "\n"
 ]
 },
@@ -158,11 +146,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.16 seconds!\n",
+"Evaluate data in 0.19 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
 "\r",
-"AccuracyMetric: acc=0.771789\n",
+"AccuracyMetric: acc=0.74656\n",
 "\n"
 ]
 },
@@ -185,11 +173,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.44 seconds!\n",
+"Evaluate data in 0.15 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
 "\r",
-"AccuracyMetric: acc=0.759174\n",
+"AccuracyMetric: acc=0.762615\n",
 "\n"
 ]
 },
@@ -212,11 +200,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.29 seconds!\n",
+"Evaluate data in 0.42 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
 "\r",
-"AccuracyMetric: acc=0.75344\n",
+"AccuracyMetric: acc=0.736239\n",
 "\n"
 ]
 },
@@ -239,11 +227,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.33 seconds!\n",
+"Evaluate data in 0.16 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
 "\r",
-"AccuracyMetric: acc=0.75\n",
+"AccuracyMetric: acc=0.761468\n",
 "\n"
 ]
 },
@@ -266,11 +254,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.19 seconds!\n",
+"Evaluate data in 0.42 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
 "\r",
-"AccuracyMetric: acc=0.741972\n",
+"AccuracyMetric: acc=0.727064\n",
 "\n"
 ]
 },
@@ -293,11 +281,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.49 seconds!\n",
+"Evaluate data in 0.21 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
 "\r",
-"AccuracyMetric: acc=0.740826\n",
+"AccuracyMetric: acc=0.731651\n",
 "\n"
 ]
 },
@@ -320,11 +308,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.15 seconds!\n",
+"Evaluate data in 0.52 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
 "\r",
-"AccuracyMetric: acc=0.75\n",
+"AccuracyMetric: acc=0.752294\n",
 "\n"
 ]
 },
@@ -347,36 +335,35 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.16 seconds!\n",
+"Evaluate data in 0.44 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
 "\r",
-"AccuracyMetric: acc=0.752294\n",
+"AccuracyMetric: acc=0.760321\n",
 "\n",
 "\r\n",
-"In Epoch:3/Step:462, got best dev performance:\n",
-"AccuracyMetric: acc=0.771789\n",
+"In Epoch:4/Step:616, got best dev performance:\n",
+"AccuracyMetric: acc=0.762615\n",
 "Reloaded the best model.\n"
 ]
 },
 {
 "data": {
 "text/plain": [
-"{'best_eval': {'AccuracyMetric': {'acc': 0.771789}},\n",
-" 'best_epoch': 3,\n",
-" 'best_step': 462,\n",
-" 'seconds': 30.04}"
+"{'best_eval': {'AccuracyMetric': {'acc': 0.762615}},\n",
+" 'best_epoch': 4,\n",
+" 'best_step': 616,\n",
+" 'seconds': 32.63}"
 ]
 },
-"execution_count": 2,
+"execution_count": 3,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
-"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
-" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
-" metrics=metric, device=device)\n",
+"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
+" loss=loss, device=device, metrics=metric)\n",
 "trainer.train()"
 ]
 },
@@ -432,7 +419,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 4,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -464,7 +451,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 5,
 "metadata": {
 "scrolled": true
 },
@@ -474,12 +461,12 @@
 "output_type": "stream",
 "text": [
 "input fields after batch(if batch size is 2):\n",
-"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n",
+"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
 "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
 "target fields after batch(if batch size is 2):\n",
 "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
 "\n",
-"training epochs started 2020-02-28-00-12-21\n"
+"training epochs started 2020-02-28-00-37-41\n"
 ]
 },
 {
@@ -515,11 +502,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.33 seconds!\n",
+"Evaluate data in 0.27 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
 "\r",
-"AccMetric: acc=0.7419724770642202\n",
+"AccMetric: acc=0.7431192660550459\n",
 "\n"
 ]
 },
@@ -542,11 +529,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.19 seconds!\n",
+"Evaluate data in 0.42 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
 "\r",
-"AccMetric: acc=0.7660550458715596\n",
+"AccMetric: acc=0.7522935779816514\n",
 "\n"
 ]
 },
@@ -569,11 +556,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.27 seconds!\n",
+"Evaluate data in 0.51 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
 "\r",
-"AccMetric: acc=0.75\n",
+"AccMetric: acc=0.7477064220183486\n",
 "\n"
 ]
 },
@@ -596,11 +583,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.24 seconds!\n",
+"Evaluate data in 0.48 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
 "\r",
-"AccMetric: acc=0.7534403669724771\n",
+"AccMetric: acc=0.7442660550458715\n",
 "\n"
 ]
 },
@@ -623,11 +610,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.29 seconds!\n",
+"Evaluate data in 0.5 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
 "\r",
-"AccMetric: acc=0.7488532110091743\n",
+"AccMetric: acc=0.7362385321100917\n",
 "\n"
 ]
 },
@@ -650,11 +637,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.14 seconds!\n",
+"Evaluate data in 0.45 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
 "\r",
-"AccMetric: acc=0.7488532110091743\n",
+"AccMetric: acc=0.7293577981651376\n",
 "\n"
 ]
 },
@@ -677,11 +664,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.27 seconds!\n",
+"Evaluate data in 0.33 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
 "\r",
-"AccMetric: acc=0.7568807339449541\n",
+"AccMetric: acc=0.7190366972477065\n",
 "\n"
 ]
 },
@@ -704,11 +691,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.42 seconds!\n",
+"Evaluate data in 0.29 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
 "\r",
-"AccMetric: acc=0.7488532110091743\n",
+"AccMetric: acc=0.7419724770642202\n",
 "\n"
 ]
 },
@@ -731,11 +718,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.16 seconds!\n",
+"Evaluate data in 0.34 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
 "\r",
-"AccMetric: acc=0.7408256880733946\n",
+"AccMetric: acc=0.7350917431192661\n",
 "\n"
 ]
 },
@@ -758,36 +745,35 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.28 seconds!\n",
+"Evaluate data in 0.18 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
 "\r",
-"AccMetric: acc=0.7408256880733946\n",
+"AccMetric: acc=0.6846330275229358\n",
 "\n",
 "\r\n",
 "In Epoch:2/Step:308, got best dev performance:\n",
-"AccMetric: acc=0.7660550458715596\n",
+"AccMetric: acc=0.7522935779816514\n",
 "Reloaded the best model.\n"
 ]
 },
 {
 "data": {
 "text/plain": [
-"{'best_eval': {'AccMetric': {'acc': 0.7660550458715596}},\n",
+"{'best_eval': {'AccMetric': {'acc': 0.7522935779816514}},\n",
 " 'best_epoch': 2,\n",
 " 'best_step': 308,\n",
-" 'seconds': 29.74}"
+" 'seconds': 42.7}"
 ]
 },
-"execution_count": 4,
+"execution_count": 5,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
-"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
-" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
-" metrics=AccMetric(), device=device)\n",
+"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
+" loss=loss, device=device, metrics=AccMetric())\n",
 "trainer.train()"
 ]
 },
@@ -802,7 +788,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 6,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -841,7 +827,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 7,
 "metadata": {
 "scrolled": true
 },
@@ -851,12 +837,12 @@
 "output_type": "stream",
 "text": [
 "input fields after batch(if batch size is 2):\n",
-"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n",
+"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
 "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
 "target fields after batch(if batch size is 2):\n",
 "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
 "\n",
-"training epochs started 2020-02-28-00-12-51\n"
+"training epochs started 2020-02-28-00-38-24\n"
 ]
 },
 {
@@ -892,11 +878,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.24 seconds!\n",
+"Evaluate data in 0.32 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
 "\r",
-"AccMetric: acc=0.7545871559633027\n",
+"AccMetric: acc=0.7511467889908257\n",
 "\n"
 ]
 },
@@ -919,11 +905,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.24 seconds!\n",
+"Evaluate data in 0.29 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
 "\r",
-"AccMetric: acc=0.7534403669724771\n",
+"AccMetric: acc=0.7454128440366973\n",
 "\n"
 ]
 },
@@ -946,11 +932,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.18 seconds!\n",
+"Evaluate data in 0.42 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
 "\r",
-"AccMetric: acc=0.7557339449541285\n",
+"AccMetric: acc=0.7224770642201835\n",
 "\n"
 ]
 },
@@ -973,11 +959,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.11 seconds!\n",
+"Evaluate data in 0.4 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
 "\r",
-"AccMetric: acc=0.7511467889908257\n",
+"AccMetric: acc=0.7534403669724771\n",
 "\n"
 ]
 },
@@ -1000,11 +986,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.19 seconds!\n",
+"Evaluate data in 0.41 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
 "\r",
-"AccMetric: acc=0.7465596330275229\n",
+"AccMetric: acc=0.7396788990825688\n",
 "\n"
 ]
 },
@@ -1027,11 +1013,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.14 seconds!\n",
+"Evaluate data in 0.22 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
 "\r",
-"AccMetric: acc=0.7454128440366973\n",
+"AccMetric: acc=0.7442660550458715\n",
 "\n"
 ]
 },
@@ -1054,11 +1040,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.43 seconds!\n",
+"Evaluate data in 0.45 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
 "\r",
-"AccMetric: acc=0.7488532110091743\n",
+"AccMetric: acc=0.6903669724770642\n",
 "\n"
 ]
 },
@@ -1081,11 +1067,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.21 seconds!\n",
+"Evaluate data in 0.25 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
 "\r",
-"AccMetric: acc=0.7431192660550459\n",
+"AccMetric: acc=0.7293577981651376\n",
 "\n"
 ]
 },
@@ -1108,11 +1094,11 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.1 seconds!\n",
+"Evaluate data in 0.4 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
 "\r",
-"AccMetric: acc=0.7477064220183486\n",
+"AccMetric: acc=0.7006880733944955\n",
 "\n"
 ]
 },
@@ -1135,39 +1121,59 @@
 "output_type": "stream",
 "text": [
 "\r",
-"Evaluate data in 0.29 seconds!\n",
+"Evaluate data in 0.48 seconds!\n",
 "\r",
 "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
 "\r",
-"AccMetric: acc=0.7465596330275229\n",
+"AccMetric: acc=0.7339449541284404\n",
 "\n",
 "\r\n",
-"In Epoch:3/Step:462, got best dev performance:\n",
-"AccMetric: acc=0.7557339449541285\n",
+"In Epoch:4/Step:616, got best dev performance:\n",
+"AccMetric: acc=0.7534403669724771\n",
 "Reloaded the best model.\n"
 ]
 },
 {
 "data": {
 "text/plain": [
-"{'best_eval': {'AccMetric': {'acc': 0.7557339449541285}},\n",
-" 'best_epoch': 3,\n",
-" 'best_step': 462,\n",
-" 'seconds': 28.68}"
+"{'best_eval': {'AccMetric': {'acc': 0.7534403669724771}},\n",
+" 'best_epoch': 4,\n",
+" 'best_step': 616,\n",
+" 'seconds': 34.74}"
 ]
 },
-"execution_count": 6,
+"execution_count": 7,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
-"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
-" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
-" metrics=AccMetric(pred=\"pred\", target=\"target\"), device=device)\n",
+"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
+" loss=loss, device=device, metrics=AccMetric())\n",
 "trainer.train()"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"``MetricBase`` performs its checks on the input dicts ``pred_dict`` and ``target_dict``.\n",
+"``pred_dict`` is the return value of the model's ``forward()`` or ``predict()`` function.\n",
+"``target_dict`` holds the ground truth from the DataSet; a field counts as ground truth when its ``is_target`` flag is set to True.\n",
+"\n",
+"``MetricBase`` performs the following type checks:\n",
+"\n",
+"1. whether self.evaluate has varargs, which is not supported.\n",
+"2. whether a parameter required by self.evaluate is in neither ``pred_dict`` nor ``target_dict``.\n",
+"3. whether a parameter required by self.evaluate is in both ``pred_dict`` and ``target_dict``.\n",
+"\n",
+"In addition, before the parameters are passed to self.evaluate, this function checks ``pred_dict`` and ``target_dict`` for parameters that are not used;\n",
+"if kwargs is a parameter of self.evaluate, this check is skipped.\n",
+"\n",
+"self.evaluate computes the metric for one batch and accumulates it; it has no return value.\n",
+"self.get_metric aggregates the accumulated statistics and returns the result; the return value must be a dict whose keys are metric names and whose values are the metric values.\n"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,

