diff --git a/mindinsight/mindconverter/tutorial/pytorch_bert_migration_tutorial.ipynb b/mindinsight/mindconverter/tutorial/pytorch_bert_migration_tutorial.ipynb new file mode 100644 index 00000000..88f2379e --- /dev/null +++ b/mindinsight/mindconverter/tutorial/pytorch_bert_migration_tutorial.ipynb @@ -0,0 +1,407 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "headed-output", + "metadata": {}, + "source": [ + "# PyTorch BERT迁移案例\n", + "PyTorch模型转换为MindSpore脚本+权重,首先需要将PyTorch模型导出为ONNX模型,然后使用MindConverter CLI工具进行脚本+权重迁移。\n", + "HuggingFace Transformers是PyTorch框架下主流的自然语言处理三方库,我们以Transformer中的BertForMaskedLM为例,演示迁移过程。" + ] + }, + { + "cell_type": "markdown", + "id": "sustained-touch", + "metadata": {}, + "source": [ + "## 1. ONNX模型导出\n", + "\n", + "首先实例化HuggingFace中的BertForMaskedLM,以及相应的分词器(首次使用需要下降模型权重、词表、模型配置等数据)。\n", + "\n", + "关于HuggingFace的使用,本文不做过多介绍,详细使用请参考[HuggingFace使用文档](https://huggingface.co/transformers/model_doc)。\n", + "\n", + "该模型可对句子中被掩蔽(mask)的词进行预测。" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "interpreted-trunk", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "from transformers.models.bert import BertForMaskedLM, BertTokenizer\n", + "\n", + "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n", + "model = BertForMaskedLM.from_pretrained(\"bert-base-uncased\")" + ] + }, + { + "cell_type": "markdown", + "id": "bronze-authentication", + "metadata": {}, + "source": [ + "我们使用该模型进行推理,生成若干组测试用例,以验证模型迁移的正确性。\n", + "\n", + "这里我们以一条句子为例`china is a poworful country, its capital is beijing.`。\n", + "\n", + "我们对`beijing`进行掩蔽(mask),输入`china is a poworful country, its capital is [MASK].`至模型,模型预期输出应为`beijing`。" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "legendary-seven", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MASK TOKEN id: 12\n", + "Tokens: [[ 101 2859 2003 1037 23776 16347 5313 2406 1010 2049 3007 2003\n", + " 103 1012 102]]\n", + "Attention mask: [[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]\n", + "Token type ids: [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]\n", + "Pred id: 7211\n", + "Pred token: beijing\n" + ] + } + ], + "source": [ + "text = \"china is a poworful country, its capital is [MASK].\"\n", + "tokenized_sentence = tokenizer(text)\n", + "\n", + "mask_idx = tokenized_sentence[\"input_ids\"].index(tokenizer.convert_tokens_to_ids(\"[MASK]\"))\n", + "input_ids = np.array([tokenized_sentence[\"input_ids\"]])\n", + "attention_mask = np.array([tokenized_sentence[\"attention_mask\"]])\n", + "token_type_ids = np.array([tokenized_sentence[\"token_type_ids\"]])\n", + "\n", + "# Get [MASK] token id.\n", + "print(f\"MASK TOKEN id: {mask_idx}\")\n", + "print(f\"Tokens: {input_ids}\") \n", + "print(f\"Attention mask: {attention_mask}\")\n", + "print(f\"Token type ids: {token_type_ids}\")\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " predictions = model(input_ids=torch.tensor(input_ids),\n", + " attention_mask=torch.tensor(attention_mask),\n", + " token_type_ids=torch.tensor(token_type_ids))\n", + " predicted_index = torch.argmax(predictions[0][0][mask_idx])\n", + " predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]\n", + " print(f\"Pred id: {predicted_index}\")\n", + " print(f\"Pred token: {predicted_token}\")\n", + " assert predicted_token == \"beijing\"" + ] + }, + { + "cell_type": "markdown", + "id": "opponent-validity", + "metadata": {}, + "source": [ + "HuggingFace提供了导出ONNX模型的工具,可使用如下方法将HuggingFace的预训练模型导出为ONNX模型:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ethical-radiation", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from transformers.convert_graph_to_onnx import convert\n", + "\n", + "\n", + "saved_onnx_path = \"./exported_bert_base_uncased/bert_base_uncased.onnx\"\n", + "convert(\"pt\", model, Path(saved_onnx_path), 11, tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "southeast-response", + "metadata": {}, + "source": [ + "根据打印的信息,我们可以看到导出的ONNX模型输入节点有3个:`input_ids`,`token_type_ids`,`attention_mask`,以及相应的输入轴,\n", + "输出节点有一个`output_0`。\n", + "\n", + "至此ONNX模型导出成功,接下来对导出的ONNX模型精度进行验证。" + ] + }, + { + "cell_type": "markdown", + "id": "historic-business", + "metadata": {}, + "source": [ + "## 2. ONNX模型验证\n" + ] + }, + { + "cell_type": "markdown", + "id": "naval-virgin", + "metadata": {}, + "source": [ + "我们仍然使用PyTorch模型推理时的句子`china is a poworful country, its capital is [MASK].`作为输入,观测ONNX模型表现是否符合预期。" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "satisfactory-embassy", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ONNX Pred id: 7211\n", + "ONNX Pred token: beijing\n" + ] + } + ], + "source": [ + "import onnx\n", + "import onnxruntime as ort\n", + "\n", + "model = onnx.load(saved_onnx_path)\n", + "sess = ort.InferenceSession(bytes(model.SerializeToString()))\n", + "result = sess.run(\n", + " output_names=None,\n", + " input_feed={\"input_ids\": input_ids, \n", + " \"attention_mask\": attention_mask,\n", + " \"token_type_ids\": token_type_ids}\n", + ")[0]\n", + "predicted_index = np.argmax(result[0][mask_idx])\n", + "predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]\n", + "\n", + "print(f\"ONNX Pred id: {predicted_index}\")\n", + "print(f\"ONNX Pred token: {predicted_token}\")\n", + "assert predicted_token == \"beijing\"" + ] + }, + { + "cell_type": "markdown", + "id": "legal-consensus", + "metadata": {}, + "source": [ + "可以看到,导出的ONNX模型功能与原PyTorch模型完全一致,接下来可以使用MindConverter进行脚本+权重迁移了!" + ] + }, + { + "cell_type": "markdown", + "id": "adverse-coverage", + "metadata": {}, + "source": [ + "## 3. MindConverter进行模型脚本+权重迁移" + ] + }, + { + "cell_type": "markdown", + "id": "vanilla-nature", + "metadata": {}, + "source": [ + "MindConverter进行模型转换时,需要给定模型路径(`--model_file`)、输入节点(`--input_nodes`)、输入节点尺寸(`--shape`)、输出节点(`--output_nodes`)。\n", + "\n", + "生成的脚本输出路径(`--output`)、转换报告路径(`--report`)为可选参数,默认为当前路径下的output目录,若输出目录不存在将自动创建。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "metallic-wright", + "metadata": {}, + "outputs": [], + "source": [ + "!mindconverter --help" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "horizontal-heater", + "metadata": {}, + "outputs": [], + "source": [ + "!mindconverter --model_file ./exported_bert_base_uncased/bert_base_uncased.onnx --shape 1,128 1,128 1,128 \\\n", + " --input_nodes input_ids,token_type_ids,attention_mask\n", + " --output_nodes output_0\n", + " --output ./converted_bert_base_uncased\n", + " --report ./converted_bert_base_uncased" + ] + }, + { + "cell_type": "markdown", + "id": "blind-forty", + "metadata": {}, + "source": [ + "转换完成后,该目录下生成如下文件:\n", + "- 模型定义脚本(后缀为.py)\n", + "- 权重ckpt文件(后缀为.ckpt)\n", + "- 迁移前后权重映射(后缀为.json)\n", + "- 转换报告(后缀为.txt)\n", + "\n", + "通过ls命令检查一下转换结果。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "blocked-teens", + "metadata": {}, + "outputs": [], + "source": [ + "!ls ./converted_bert_base_uncased" + ] + }, + { + "cell_type": "markdown", + "id": "improving-difference", + "metadata": {}, + "source": [ + "可以看到所有文件已生成。\n", + "\n", + "迁移完成,接下来我们对迁移后模型精度进行验证。" + ] + }, + { + "cell_type": "markdown", + "id": "dimensional-driver", + "metadata": {}, + "source": [ + "## 4. MindSpore模型验证\n", + "我们仍然使用`china is a poworful country, its capital is [MASK].`作为输入,观测迁移后模型表现是否符合预期。" + ] + }, + { + "cell_type": "markdown", + "id": "unexpected-permit", + "metadata": {}, + "source": [ + "由于工具在转换时,需要将模型尺寸冻结,因此在使用MindSpore进行推理验证时,需要将句子补齐(Pad)到固定长度,可通过如下函数实现句子补齐。\n", + "\n", + "推理时,句子长度需小于转换时的最大句长(这里我们最长句子长度为128)。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "going-fields", + "metadata": {}, + "outputs": [], + "source": [ + "def padding(input_ids, attn_mask, token_type_ids, target_len=128):\n", + " length = len(input_ids)\n", + " for i in range(target_len - length):\n", + " input_ids.append(0)\n", + " attn_mask.append(0)\n", + " token_type_ids.append(0)\n", + " return np.array([input_ids]), np.array([attn_mask]), np.array([token_type_ids])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "intense-carrier", + "metadata": {}, + "outputs": [], + "source": [ + "from converted_bert_base_uncased.bert_base_uncased import Model as MsBert\n", + "from mindspore import load_checkpoint, load_param_into_net, context, Tensor\n", + "\n", + "\n", + "context.set_context(mode=context.GRAPH_MODE, device_target=\"GPU\")\n", + "padded_input_ids, padded_attention_mask, padded_token_type = padding(tokenized_sentence[\"input_ids\"], \n", + " tokenized_sentence[\"attention_mask\"], \n", + " tokenized_sentence[\"token_type_ids\"], \n", + " target_len=128)\n", + "padded_input_ids = Tensor(padded_input_ids)\n", + "padded_attention_mask = Tensor(padded_attention_mask)\n", + "padded_token_type = Tensor(padded_token_type)\n", + "\n", + "model = MsBert()\n", + "param_dict = load_checkpoint(\"./converted_bert_base_uncased/bert_base_uncased.ckpt\")\n", + "not_load_params = load_param_into_net(model, param_dict)\n", + "output = model(padded_attention_mask, padded_input_ids, padded_token_type)\n", + "\n", + "assert not not_load_params\n", + "\n", + "predicted_index = np.argmax(output.asnumpy()[0][mask_idx])\n", + "print(f\"ONNX Pred id: {predicted_index}\")\n", + "assert predicted_index == 7211" + ] + }, + { + "cell_type": "markdown", + "id": "national-norfolk", + "metadata": {}, + "source": [ + "至此,使用MindConverter进行脚本+权重迁移完成。\n", + "\n", + "用户可根据使用场景编写训练、推理、部署脚本,实现个人业务逻辑。" + ] + }, + { + "cell_type": "markdown", + "id": "capital-joint", + "metadata": {}, + "source": [ + "## 5. 其他问题" + ] + }, + { + "cell_type": "markdown", + "id": "magnetic-collective", + "metadata": {}, + "source": [ + "1. 如何修改迁移后脚本的批次大小(Batch size)、句子长度(Sequence length),实现模型可支持任意的尺寸的数据推理、训练?\n", + "\n", + "> 答:迁移后脚本存在shape限制,通常是由于Reshape算子导致,或其他涉及张量排布变化的算子导致。以上述Bert迁移为例,首先创建两个全局变量,作为预期的批次大小、句子长度的表示,而后将Reshape操作的目标尺寸进行修改,相应的替换成批次大小、句子长度的全局变量即可。\n", + "> ./converted_bert_base_uncased/modified_bert_base_uncased.py为修改后的可支持任意尺寸数据训练、推理的脚本,该脚本脚本展示了相应的修改。" + ] + }, + { + "cell_type": "markdown", + "id": "proprietary-yugoslavia", + "metadata": {}, + "source": [ + "2. 生成后的脚本中类名的定义不符合开发者的习惯,如`class Module0(nn.Cell)`,人工修改是否会影响转换后的权重加载?\n", + "\n", + "> 答:权重的加载仅与变量名、类结构有关,因此类名可以修改,不影响权重加载。若需要调整类的结构,则相应的权重命名需要同步修改以适应迁移后模型的结构。" + ] + }, + { + "cell_type": "markdown", + "id": "selective-flight", + "metadata": {}, + "source": [ + "## 6. 其他参考教程\n", + "1. [MindConverter高阶使用教程-自定义生成代码结构]()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}