Browse Source

update advance_tutorial jupyter notebook

tags/v0.4.10
xuyige 5 years ago
parent
commit
76f9bbf5f1
1 changed files with 94 additions and 77 deletions
  1. +94
    -77
      tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb

+ 94
- 77
tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb View File

@@ -20,16 +20,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/remote-home/ygxu/anaconda3/envs/no-fastnlp/lib/python3.7/site-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
" \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n"
]
}
],
"outputs": [],
"source": [ "source": [
"# 声明部件\n", "# 声明部件\n",
"import torch\n", "import torch\n",
@@ -179,11 +170,11 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"DataSet({'image': tensor([[ 2.1747, -1.0147, -1.3853, 0.0216, -0.4957],\n",
" [ 0.8138, -0.2933, -0.1217, -0.6027, 0.3932],\n",
" [ 0.6750, -1.1136, -1.3371, -0.0185, -0.3206],\n",
" [-0.5076, -0.3822, 0.1719, -0.6447, -0.5702],\n",
" [ 0.3804, 0.0889, 0.8027, -0.7121, -0.7320]]) type=torch.Tensor,\n",
"DataSet({'image': tensor([[ 4.7106e-01, -1.2246e+00, 3.1234e-01, -1.6781e+00, -8.7967e-01],\n",
" [ 1.1454e+00, 1.2236e-01, 3.0258e-01, -1.5454e+00, 8.9201e-01],\n",
" [-5.7143e-03, 3.9488e-01, 2.0287e-01, -1.5726e+00, 9.3171e-01],\n",
" [ 6.8914e-01, -2.6302e-01, -8.2694e-01, 9.5942e-01, -5.2589e-01],\n",
" [-5.7798e-03, -9.1621e-03, 1.0077e-03, 9.1716e-02, 1.0565e+00]]) type=torch.Tensor,\n",
"'label': 0 type=int})" "'label': 0 type=int})"
] ]
}, },
@@ -644,20 +635,20 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"({'premise': [2, 145, 146, 80, 147, 26, 148, 2, 104, 149, 150, 2, 151, 5, 55, 152, 105, 3] type=list,\n",
" 'hypothesis': [22, 80, 8, 1, 1, 20, 1, 3] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'label': 2 type=int},\n",
" {'premise': [11, 5, 18, 5, 24, 6, 2, 10, 59, 52, 14, 9, 2, 53, 29, 60, 54, 45, 6, 46, 5, 7, 61, 3] type=list,\n",
" 'hypothesis': [22, 11, 1, 45, 3] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1] type=list,\n",
"({'premise': [2, 10, 9, 2, 15, 115, 6, 11, 5, 132, 17, 2, 76, 9, 77, 55, 3] type=list,\n",
" 'hypothesis': [1, 2, 56, 17, 1, 4, 13, 49, 123, 12, 6, 11, 3] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'label': 0 type=int},\n",
" {'premise': [50, 124, 10, 7, 68, 91, 92, 38, 2, 55, 3] type=list,\n",
" 'hypothesis': [21, 10, 5, 2, 55, 7, 99, 64, 48, 1, 22, 1, 3] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'label': 1 type=int},\n", " 'label': 1 type=int},\n",
" {'premise': [2, 11, 8, 14, 16, 7, 15, 50, 2, 66, 4, 76, 2, 10, 8, 98, 9, 58, 67, 3] type=list,\n",
" 'hypothesis': [22, 27, 50, 3] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1] type=list,\n",
" {'premise': [13, 24, 4, 14, 29, 5, 25, 4, 8, 39, 9, 14, 34, 4, 40, 41, 4, 16, 12, 2, 11, 4, 30, 28, 2, 42, 8, 2, 43, 44, 17, 2, 45, 35, 26, 31, 27, 5, 6, 32, 3] type=list,\n",
" 'hypothesis': [37, 49, 123, 30, 28, 2, 55, 12, 2, 11, 3] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'label': 0 type=int})" " 'label': 0 type=int})"
] ]
}, },
@@ -718,15 +709,15 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"({'premise': [1037, 2210, 2223, 2136, 5363, 2000, 4608, 1037, 5479, 8058, 2046, 1037, 2918, 1999, 2019, 5027, 2208, 1012] type=list,\n",
" 'hypothesis': [100, 2136, 2003, 2652, 3598, 2006, 100, 1012] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'label': 2 type=int},\n",
" {'premise': [2450, 1999, 2317, 1999, 100, 1998, 1037, 2158, 3621, 2369, 3788, 2007, 1037, 3696, 2005, 2198, 100, 10733, 1998, 100, 1999, 1996, 4281, 1012] type=list,\n",
" 'hypothesis': [100, 2450, 13063, 10733, 1012] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1] type=list,\n",
"({'premise': [1037, 2158, 1998, 1037, 2450, 2892, 1996, 2395, 1999, 2392, 1997, 1037, 10733, 1998, 100, 4825, 1012] type=list,\n",
" 'hypothesis': [100, 1037, 3232, 1997, 7884, 1010, 2048, 2111, 3328, 2408, 1996, 2395, 1012] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'label': 0 type=int},\n",
" {'premise': [2019, 3080, 2158, 2003, 5948, 4589, 10869, 2012, 1037, 4825, 1012] type=list,\n",
" 'hypothesis': [100, 2158, 1999, 1037, 4825, 2003, 3403, 2005, 2010, 7954, 2000, 7180, 1012] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'label': 1 type=int})" " 'label': 1 type=int})"
] ]
}, },
@@ -769,7 +760,7 @@
" 'num_classes': 3,\n", " 'num_classes': 3,\n",
" 'gpu': True,\n", " 'gpu': True,\n",
" 'batch_size': 32,\n", " 'batch_size': 32,\n",
" 'vocab_size': 165}"
" 'vocab_size': 156}"
] ]
}, },
"execution_count": 26, "execution_count": 26,
@@ -797,7 +788,7 @@
"ESIM(\n", "ESIM(\n",
" (drop): Dropout(p=0.3)\n", " (drop): Dropout(p=0.3)\n",
" (embedding): Embedding(\n", " (embedding): Embedding(\n",
" (embed): Embedding(165, 300, padding_idx=0)\n",
" (embed): Embedding(156, 300, padding_idx=0)\n",
" (dropout): Dropout(p=0.3)\n", " (dropout): Dropout(p=0.3)\n",
" )\n", " )\n",
" (embedding_layer): Linear(\n", " (embedding_layer): Linear(\n",
@@ -821,7 +812,6 @@
" )\n", " )\n",
" (output): Linear(in_features=300, out_features=3, bias=True)\n", " (output): Linear(in_features=300, out_features=3, bias=True)\n",
" (dropout): Dropout(p=0.3)\n", " (dropout): Dropout(p=0.3)\n",
" (hidden_active): Tanh()\n",
" )\n", " )\n",
")" ")"
] ]
@@ -848,7 +838,7 @@
"text/plain": [ "text/plain": [
"CNNText(\n", "CNNText(\n",
" (embed): Embedding(\n", " (embed): Embedding(\n",
" (embed): Embedding(165, 50, padding_idx=0)\n",
" (embed): Embedding(156, 50, padding_idx=0)\n",
" (dropout): Dropout(p=0.0)\n", " (dropout): Dropout(p=0.0)\n",
" )\n", " )\n",
" (conv_pool): ConvMaxpool(\n", " (conv_pool): ConvMaxpool(\n",
@@ -1019,43 +1009,49 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"training epochs started 2019-01-09 00-08-17\n",
"[tester] \n",
"AccuracyMetric: acc=0.206897\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/remote-home/ygxu/anaconda3/envs/no-fastnlp/lib/python3.7/site-packages/torch/nn/functional.py:1320: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n",
" warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[tester] \n",
"AccuracyMetric: acc=0.206897\n",
"[tester] \n",
"AccuracyMetric: acc=0.206897\n",
"[tester] \n",
"AccuracyMetric: acc=0.206897\n",
"[tester] \n",
"AccuracyMetric: acc=0.206897\n",
"training epochs started 2019-04-14-23-22-28\n",
"[epoch: 1 step: 1] train loss: 1.51372 time: 0:00:00\n",
"[epoch: 1 step: 2] train loss: 1.26874 time: 0:00:00\n",
"[epoch: 1 step: 3] train loss: 1.49786 time: 0:00:00\n",
"[epoch: 1 step: 4] train loss: 1.37505 time: 0:00:00\n",
"Evaluation at Epoch 1/5. Step:4/20. AccuracyMetric: acc=0.344828\n",
"\n",
"[epoch: 2 step: 5] train loss: 1.21877 time: 0:00:00\n",
"[epoch: 2 step: 6] train loss: 1.14183 time: 0:00:00\n",
"[epoch: 2 step: 7] train loss: 1.15934 time: 0:00:00\n",
"[epoch: 2 step: 8] train loss: 1.55148 time: 0:00:00\n",
"Evaluation at Epoch 2/5. Step:8/20. AccuracyMetric: acc=0.344828\n",
"\n", "\n",
"In Epoch:1/Step:4, got best dev performance:AccuracyMetric: acc=0.206897\n",
"[epoch: 3 step: 9] train loss: 1.1457 time: 0:00:00\n",
"[epoch: 3 step: 10] train loss: 1.0547 time: 0:00:00\n",
"[epoch: 3 step: 11] train loss: 1.40139 time: 0:00:00\n",
"[epoch: 3 step: 12] train loss: 0.551445 time: 0:00:00\n",
"Evaluation at Epoch 3/5. Step:12/20. AccuracyMetric: acc=0.275862\n",
"\n",
"[epoch: 4 step: 13] train loss: 1.07965 time: 0:00:00\n",
"[epoch: 4 step: 14] train loss: 1.04118 time: 0:00:00\n",
"[epoch: 4 step: 15] train loss: 1.11719 time: 0:00:00\n",
"[epoch: 4 step: 16] train loss: 1.09861 time: 0:00:00\n",
"Evaluation at Epoch 4/5. Step:16/20. AccuracyMetric: acc=0.275862\n",
"\n",
"[epoch: 5 step: 17] train loss: 1.10795 time: 0:00:00\n",
"[epoch: 5 step: 18] train loss: 1.26715 time: 0:00:00\n",
"[epoch: 5 step: 19] train loss: 1.19875 time: 0:00:00\n",
"[epoch: 5 step: 20] train loss: 1.09862 time: 0:00:00\n",
"Evaluation at Epoch 5/5. Step:20/20. AccuracyMetric: acc=0.37931\n",
"\n",
"\n",
"In Epoch:5/Step:20, got best dev performance:AccuracyMetric: acc=0.37931\n",
"Reloaded the best model.\n" "Reloaded the best model.\n"
] ]
}, },
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'best_eval': {'AccuracyMetric': {'acc': 0.206897}},\n",
" 'best_epoch': 1,\n",
" 'best_step': 4,\n",
" 'seconds': 0.79}"
"{'best_eval': {'AccuracyMetric': {'acc': 0.37931}},\n",
" 'best_epoch': 5,\n",
" 'best_step': 20,\n",
" 'seconds': 0.5}"
] ]
}, },
"execution_count": 29, "execution_count": 29,
@@ -1070,8 +1066,8 @@
"trainer = Trainer(\n", "trainer = Trainer(\n",
" train_data=train_data,\n", " train_data=train_data,\n",
" model=model,\n", " model=model,\n",
" loss=CrossEntropyLoss(pred='pred', target='label'),\n",
" metrics=AccuracyMetric(),\n",
" loss=CrossEntropyLoss(pred='pred', target='label'), # 模型预测值通过'pred'来取得,目标值(ground truth)由'label'取得\n",
" metrics=AccuracyMetric(target='label'), # 目标值(ground truth)由'label'取得\n",
" n_epochs=5,\n", " n_epochs=5,\n",
" batch_size=16,\n", " batch_size=16,\n",
" print_every=-1,\n", " print_every=-1,\n",
@@ -1113,13 +1109,13 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[tester] \n", "[tester] \n",
"AccuracyMetric: acc=0.263158\n"
"AccuracyMetric: acc=0.368421\n"
] ]
}, },
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'AccuracyMetric': {'acc': 0.263158}}"
"{'AccuracyMetric': {'acc': 0.368421}}"
] ]
}, },
"execution_count": 30, "execution_count": 30,
@@ -1131,12 +1127,33 @@
"tester = Tester(\n", "tester = Tester(\n",
" data=test_data,\n", " data=test_data,\n",
" model=model,\n", " model=model,\n",
" metrics=AccuracyMetric(),\n",
" metrics=AccuracyMetric(target='label'),\n",
" batch_size=args[\"batch_size\"],\n", " batch_size=args[\"batch_size\"],\n",
")\n", ")\n",
"tester.test()" "tester.test()"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@@ -1161,7 +1178,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.6.7"
"version": "3.7.0"
} }
}, },
"nbformat": 4, "nbformat": 4,


Loading…
Cancel
Save