|
|
@@ -136,7 +136,7 @@ |
|
|
|
"在`fastNLP 0.8`中,使用`pytorch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n", |
|
|
|
"\n", |
|
|
|
"  添加`pytorch`要求的`forward`方法外,还需要添加 **`train_step`** 和 **`evaluate_step`** 这两个方法\n", |
|
|
|
"***\n", |
|
|
|
"\n", |
|
|
|
"```python\n", |
|
|
|
"class Model(torch.nn.Module):\n", |
|
|
|
" def __init__(self):\n", |
|
|
@@ -177,9 +177,7 @@ |
|
|
|
"\n", |
|
|
|
"  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n", |
|
|
|
"\n", |
|
|
|
"***\n", |
|
|
|
"\n", |
|
|
|
"![fastNLP 0.8 中,Trainer 和 Evaluator 的关系图](./figures/T0-fig-trainer-and-evaluator.png)" |
|
|
|
"<img src=\"./figures/T0-fig-trainer-and-evaluator.png\" width=\"80%\" height=\"80%\" align=\"center\"></img>" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
@@ -206,7 +204,7 @@ |
|
|
|
"  而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n", |
|
|
|
"\n", |
|
|
|
"    `fastNLP 0.8`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n", |
|
|
|
"***\n", |
|
|
|
"\n", |
|
|
|
"```python\n", |
|
|
|
"class Dataset(torch.utils.data.Dataset):\n", |
|
|
|
" def __init__(self, x, y):\n", |
|
|
@@ -253,7 +251,7 @@ |
|
|
|
"id": "5314482b", |
|
|
|
"metadata": { |
|
|
|
"pycharm": { |
|
|
|
"is_executing": false |
|
|
|
"is_executing": true |
|
|
|
} |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
@@ -641,11 +639,11 @@ |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/html": [ |
|
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.43</span><span style=\"font-weight: bold\">}</span>\n", |
|
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.29</span><span style=\"font-weight: bold\">}</span>\n", |
|
|
|
"</pre>\n" |
|
|
|
], |
|
|
|
"text/plain": [ |
|
|
|
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.43\u001b[0m\u001b[1m}\u001b[0m\n" |
|
|
|
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.29\u001b[0m\u001b[1m}\u001b[0m\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
@@ -654,7 +652,7 @@ |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"{'acc#acc': 0.43}" |
|
|
|
"{'acc#acc': 0.29}" |
|
|
|
] |
|
|
|
}, |
|
|
|
"execution_count": 9, |
|
|
|