|
|
@@ -0,0 +1,768 @@ |
|
|
|
{ |
|
|
|
"cells": [ |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"fastNLP上手教程\n", |
|
|
|
"-------\n", |
|
|
|
"\n", |
|
|
|
"fastNLP提供方便的数据预处理,训练和测试模型的功能" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 3, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import sys\n", |
|
|
|
"sys.path.append('/Users/yh/Desktop/fastNLP/fastNLP/')" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"DataSet & Instance\n", |
|
|
|
"------\n", |
|
|
|
"\n", |
|
|
|
"fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n", |
|
|
|
"\n", |
|
|
|
"DataSet提供了一些read_*方法(例如下面用到的read_csv),可以轻松从文件读取数据,存成DataSet。" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 4, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"8529\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"from fastNLP import DataSet\n", |
|
|
|
"from fastNLP import Instance\n", |
|
|
|
"\n", |
|
|
|
"# 从csv读取数据到DataSet\n", |
|
|
|
"dataset = DataSet.read_csv('../sentence.csv', headers=('raw_sentence', 'label'), sep='\\t')\n", |
|
|
|
"print(len(dataset))" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 5, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", |
|
|
|
"'label': 1}\n", |
|
|
|
"{'raw_sentence': -LRB- Tries -RRB- to parody a genre that 's already a joke in the United States .,\n", |
|
|
|
"'label': 1}\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# 使用数字索引[k],获取第k个样本\n", |
|
|
|
"print(dataset[0])\n", |
|
|
|
"\n", |
|
|
|
"# 索引也可以是负数\n", |
|
|
|
"print(dataset[-3])" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"## Instance\n", |
|
|
|
"Instance表示一个样本,由一个或多个field(域,属性,特征)组成,每个field有名字和值。\n", |
|
|
|
"\n", |
|
|
|
"在初始化Instance时即可定义它包含的域,使用 \"field_name=field_value\"的写法。" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 6, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"{'raw_sentence': fake data,\n", |
|
|
|
"'label': 0}" |
|
|
|
] |
|
|
|
}, |
|
|
|
"execution_count": 6, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "execute_result" |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# DataSet.append(Instance)加入新数据\n", |
|
|
|
"dataset.append(Instance(raw_sentence='fake data', label='0'))\n", |
|
|
|
"dataset[-1]" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"## DataSet.apply方法\n", |
|
|
|
"数据预处理利器" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 7, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", |
|
|
|
"'label': 1}\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# 将所有字母转为小写\n", |
|
|
|
"dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n", |
|
|
|
"print(dataset[0])" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 8, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", |
|
|
|
"'label': 1}\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# label转int\n", |
|
|
|
"dataset.apply(lambda x: int(x['label']), new_field_name='label')\n", |
|
|
|
"print(dataset[0])" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 9, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"ename": "RuntimeError", |
|
|
|
"evalue": "Cannot create FieldArray with an empty list.", |
|
|
|
"output_type": "error", |
|
|
|
"traceback": [ |
|
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
|
|
|
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", |
|
|
|
"\u001b[0;32m<ipython-input-9-d70cf5545af4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msplit_sent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mins\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'raw_sentence'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplit_sent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_field_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'words'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
|
|
|
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/dataset.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, func, new_field_name, **kwargs)\u001b[0m\n\u001b[1;32m 265\u001b[0m **extra_param)\n\u001b[1;32m 266\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 267\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_field\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_field_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfields\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mextra_param\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 268\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
|
|
|
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/dataset.py\u001b[0m in \u001b[0;36madd_field\u001b[0;34m(self, name, fields, padding_val, is_input, is_target)\u001b[0m\n\u001b[1;32m 158\u001b[0m f\"Dataset size {len(self)} != field size {len(fields)}\")\n\u001b[1;32m 159\u001b[0m self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target,\n\u001b[0;32m--> 160\u001b[0;31m is_input=is_input)\n\u001b[0m\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdelete_field\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
|
|
|
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, content, padding_val, is_target, is_input)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_input\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_input\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mis_input\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_target\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_target\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mis_target\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
|
|
|
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36mis_input\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mis_input\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mis_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpytype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_map_to_np_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpytype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_input\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
|
|
|
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m_type_detection\u001b[0;34m(self, content)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# strict check 2-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Please provide 2-D list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_set\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mint\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfloat\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0;34m{\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
|
|
|
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# strict check 2-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Please provide 2-D list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_set\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mint\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfloat\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
|
|
|
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m_type_detection\u001b[0;34m(self, content)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;31m# content is a 1-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Cannot create FieldArray with an empty list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", |
|
|
|
"\u001b[0;31mRuntimeError\u001b[0m: Cannot create FieldArray with an empty list." |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# 使用空格分割句子\n", |
|
|
|
"def split_sent(ins):\n", |
|
|
|
" return ins['raw_sentence'].split()\n", |
|
|
|
"dataset.apply(split_sent, new_field_name='words')\n", |
|
|
|
"print(dataset[0])" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 17, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n", |
|
|
|
"'label': 1,\n", |
|
|
|
"'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n", |
|
|
|
"'seq_len': 37}\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# 增加长度信息\n", |
|
|
|
"dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n", |
|
|
|
"print(dataset[0])" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"## DataSet.drop\n", |
|
|
|
"筛选数据" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 19, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"38\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"dataset.drop(lambda x: x['seq_len'] <= 3)\n", |
|
|
|
"print(len(dataset))" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"## 配置DataSet\n", |
|
|
|
"1. 哪些域是特征,哪些域是标签\n", |
|
|
|
"2. 切分训练集/验证集" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 20, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# 设置DataSet中,哪些field要转为tensor\n", |
|
|
|
"\n", |
|
|
|
"# set target,loss或evaluate中的golden,计算loss,模型评估时使用\n", |
|
|
|
"dataset.set_target(\"label\")\n", |
|
|
|
"# set input,模型forward时使用\n", |
|
|
|
"dataset.set_input(\"words\")" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 21, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"27\n", |
|
|
|
"11" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# 分出测试集、训练集\n", |
|
|
|
"\n", |
|
|
|
"test_data, train_data = dataset.split(0.3)\n", |
|
|
|
"print(len(test_data))\n", |
|
|
|
"print(len(train_data))" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"Vocabulary\n", |
|
|
|
"------\n", |
|
|
|
"\n", |
|
|
|
"fastNLP中的Vocabulary可以轻松构建词表,并将词转成数字" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 22, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"{'raw_sentence': that the chuck norris `` grenade gag '' occurs about 7 times during windtalkers is a good indication of how serious-minded the film is .,\n", |
|
|
|
"'label': 2,\n", |
|
|
|
"'words': [6, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 8, 24, 1, 5, 1, 1, 2, 15, 10, 3],\n", |
|
|
|
"'seq_len': 25}\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"from fastNLP import Vocabulary\n", |
|
|
|
"\n", |
|
|
|
"# 构建词表, Vocabulary.add(word)\n", |
|
|
|
"vocab = Vocabulary(min_freq=2)\n", |
|
|
|
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n", |
|
|
|
"vocab.build_vocab()\n", |
|
|
|
"\n", |
|
|
|
"# index句子, Vocabulary.to_index(word)\n", |
|
|
|
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n", |
|
|
|
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"print(test_data[0])" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"# Model\n", |
|
|
|
"定义一个PyTorch模型" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 23, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"CNNText(\n", |
|
|
|
" (embed): Embedding(\n", |
|
|
|
" (embed): Embedding(32, 50, padding_idx=0)\n", |
|
|
|
" (dropout): Dropout(p=0.0)\n", |
|
|
|
" )\n", |
|
|
|
" (conv_pool): ConvMaxpool(\n", |
|
|
|
" (convs): ModuleList(\n", |
|
|
|
" (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n", |
|
|
|
" (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n", |
|
|
|
" (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n", |
|
|
|
" )\n", |
|
|
|
" )\n", |
|
|
|
" (dropout): Dropout(p=0.1)\n", |
|
|
|
" (fc): Linear(\n", |
|
|
|
" (linear): Linear(in_features=12, out_features=5, bias=True)\n", |
|
|
|
" )\n", |
|
|
|
")" |
|
|
|
] |
|
|
|
}, |
|
|
|
"execution_count": 23, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "execute_result" |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"from fastNLP.models import CNNText\n", |
|
|
|
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n", |
|
|
|
"model" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"这是上述模型的forward方法。如果你不知道什么是forward方法,请参考我们的PyTorch教程。\n", |
|
|
|
"\n", |
|
|
|
"注意两点:\n", |
|
|
|
"1. forward参数名字叫**word_seq**,请记住。\n", |
|
|
|
"2. forward的返回值是一个**dict**,其中有个key的名字叫**output**。\n", |
|
|
|
"\n", |
|
|
|
"```Python\n", |
|
|
|
" def forward(self, word_seq):\n", |
|
|
|
" \"\"\"\n", |
|
|
|
"\n", |
|
|
|
" :param word_seq: torch.LongTensor, [batch_size, seq_len]\n", |
|
|
|
"        :return output: dict of torch.FloatTensor, [batch_size, num_classes]\n", |
|
|
|
" \"\"\"\n", |
|
|
|
" x = self.embed(word_seq) # [N,L] -> [N,L,C]\n", |
|
|
|
" x = self.conv_pool(x) # [N,L,C] -> [N,C]\n", |
|
|
|
" x = self.dropout(x)\n", |
|
|
|
" x = self.fc(x) # [N,C] -> [N, N_class]\n", |
|
|
|
" return {'output': x}\n", |
|
|
|
"```" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"这是上述模型的predict方法,是用来直接输出该任务的预测结果,与forward目的不同。\n", |
|
|
|
"\n", |
|
|
|
"注意两点:\n", |
|
|
|
"1. predict参数名也叫**word_seq**。\n", |
|
|
|
"2. predict的返回值也是一个**dict**,其中有个key的名字叫**predict**。\n", |
|
|
|
"\n", |
|
|
|
"```Python\n", |
|
|
|
" def predict(self, word_seq):\n", |
|
|
|
" \"\"\"\n", |
|
|
|
"\n", |
|
|
|
" :param word_seq: torch.LongTensor, [batch_size, seq_len]\n", |
|
|
|
"        :return predict: dict of torch.LongTensor, [batch_size]\n", |
|
|
|
" \"\"\"\n", |
|
|
|
" output = self(word_seq)\n", |
|
|
|
" _, predict = output['output'].max(dim=1)\n", |
|
|
|
" return {'predict': predict}\n", |
|
|
|
"```" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"Trainer & Tester\n", |
|
|
|
"------\n", |
|
|
|
"\n", |
|
|
|
"使用fastNLP的Trainer训练模型" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 25, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"from fastNLP import Trainer\n", |
|
|
|
"from copy import deepcopy\n", |
|
|
|
"from fastNLP.core.losses import CrossEntropyLoss\n", |
|
|
|
"from fastNLP.core.metrics import AccuracyMetric\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"# 更改DataSet中对应field的名称,与模型的forward的参数名一致\n", |
|
|
|
"# 因为forward的参数叫word_seq, 所以要把原本叫words的field改名为word_seq\n", |
|
|
|
"# 这里的演示是让你了解这种**命名规则**\n", |
|
|
|
"train_data.rename_field('words', 'word_seq')\n", |
|
|
|
"test_data.rename_field('words', 'word_seq')\n", |
|
|
|
"\n", |
|
|
|
"# 顺便把label换名为label_seq\n", |
|
|
|
"train_data.rename_field('label', 'label_seq')\n", |
|
|
|
"test_data.rename_field('label', 'label_seq')" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### loss\n", |
|
|
|
"训练模型需要提供一个损失函数\n", |
|
|
|
"\n", |
|
|
|
"下面提供了一个在分类问题中常用的交叉熵损失。注意它的**初始化参数**。\n", |
|
|
|
"\n", |
|
|
|
"pred参数对应的是模型的forward返回的dict的一个key的名字,这里是\"output\"。\n", |
|
|
|
"\n", |
|
|
|
"target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 26, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"loss = CrossEntropyLoss(pred=\"output\", target=\"label_seq\")" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### Metric\n", |
|
|
|
"定义评价指标\n", |
|
|
|
"\n", |
|
|
|
"这里使用准确率。参数的“命名规则”跟上面类似。\n", |
|
|
|
"\n", |
|
|
|
"pred参数对应的是模型的predict方法返回的dict的一个key的名字,这里是\"predict\"。\n", |
|
|
|
"\n", |
|
|
|
"target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 27, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"metric = AccuracyMetric(pred=\"predict\", target=\"label_seq\")" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 30, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"training epochs started 2018-12-04 22:51:24\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.407407\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.518519\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.481481\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.592593\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# 实例化Trainer,传入模型和数据,进行训练\n", |
|
|
|
"# 先在test_data上过拟合(训练集和验证集都使用test_data)\n", |
|
|
|
"copy_model = deepcopy(model)\n", |
|
|
|
"overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,\n", |
|
|
|
" losser=loss,\n", |
|
|
|
" metrics=metric,\n", |
|
|
|
" save_path=None,\n", |
|
|
|
" batch_size=32,\n", |
|
|
|
" n_epochs=5)\n", |
|
|
|
"overfit_trainer.train()" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 31, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"training epochs started 2018-12-04 22:52:01\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.222222\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.259259\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.296296\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.259259\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
" \r" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Train finished!\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# 用train_data训练,在test_data验证\n", |
|
|
|
"trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n", |
|
|
|
" losser=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n", |
|
|
|
" metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n", |
|
|
|
" save_path=None,\n", |
|
|
|
" batch_size=32,\n", |
|
|
|
" n_epochs=5)\n", |
|
|
|
"trainer.train()\n", |
|
|
|
"print('Train finished!')" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 33, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"[tester] \n", |
|
|
|
"AccuracyMetric: acc=0.259259\n", |
|
|
|
"{'AccuracyMetric': {'acc': 0.259259}}\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"# 调用Tester在test_data上评价效果\n", |
|
|
|
"from fastNLP import Tester\n", |
|
|
|
"\n", |
|
|
|
"tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n", |
|
|
|
" batch_size=4)\n", |
|
|
|
"acc = tester.test()\n", |
|
|
|
"print(acc)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [] |
|
|
|
} |
|
|
|
], |
|
|
|
"metadata": { |
|
|
|
"kernelspec": { |
|
|
|
"display_name": "Python 3", |
|
|
|
"language": "python", |
|
|
|
"name": "python3" |
|
|
|
}, |
|
|
|
"language_info": { |
|
|
|
"codemirror_mode": { |
|
|
|
"name": "ipython", |
|
|
|
"version": 3 |
|
|
|
}, |
|
|
|
"file_extension": ".py", |
|
|
|
"mimetype": "text/x-python", |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.6.7" |
|
|
|
} |
|
|
|
}, |
|
|
|
"nbformat": 4, |
|
|
|
"nbformat_minor": 2 |
|
|
|
} |