
update example-12 lxr 220527

tags/v1.0.0alpha
lxr-tech 2 years ago
commit 19a48c7101
2 changed files with 218 additions and 608 deletions
  1. +10 -1 tutorials/fastnlp_tutorial_e1.ipynb
  2. +208 -607 tutorials/fastnlp_tutorial_e2.ipynb

+10 -1 tutorials/fastnlp_tutorial_e1.ipynb

@@ -233,7 +233,7 @@
} }
], ],
"source": [ "source": [
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n",
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n",
"\n", "\n",
"model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
"\n", "\n",
@@ -881,6 +881,15 @@
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.13" "version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
} }
}, },
"nbformat": 4, "nbformat": 4,


+208 -607 tutorials/fastnlp_tutorial_e2.ipynb

@@ -4,7 +4,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# E2. 使用 PrefixTuning 完成 SST2 分类"
"# E2. 使用 continuous prompt 完成 SST2 分类"
] ]
}, },
{ {
@@ -35,10 +35,12 @@
], ],
"source": [ "source": [
"import torch\n", "import torch\n",
"import torch.nn as nn\n",
"from torch.optim import AdamW\n", "from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader, Dataset\n", "from torch.utils.data import DataLoader, Dataset\n",
"\n", "\n",
"import torch.nn as nn\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"import transformers\n", "import transformers\n",
"from transformers import AutoTokenizer\n", "from transformers import AutoTokenizer\n",
"from transformers import AutoModelForSequenceClassification\n", "from transformers import AutoModelForSequenceClassification\n",
@@ -69,180 +71,226 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n",
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "253d79d7a67e4dc88338448b5bcb3fb9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset, load_metric\n",
"\n",
"dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
]
}
],
"outputs": [],
"source": [ "source": [
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
"class PromptEncoder(nn.Module):\n",
" def __init__(self, template, hidden_size):\n",
" nn.Module.__init__(self)\n",
" self.template = template\n",
" self.hidden_size = hidden_size\n",
" self.cloze_mask = [[1] * self.template[0] + [1] * self.template[1]]\n",
" self.cloze_mask = torch.LongTensor(self.cloze_mask).bool()\n",
"\n",
" self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0]))))\n",
" # embed\n",
" self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), hidden_size)\n",
" # LSTM\n",
" self.lstm_head = torch.nn.LSTM(input_size=hidden_size,\n",
" hidden_size=hidden_size // 2,\n",
" num_layers=2, dropout=0.0,\n",
" bidirectional=True, batch_first=True)\n",
" # MLP\n",
" self.mlp_head = nn.Sequential(nn.Linear(hidden_size, hidden_size),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_size, hidden_size))\n",
" print(\"init prompt encoder...\")\n",
"\n", "\n",
"print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))"
" def forward(self, device):\n",
" input_embeds = self.embedding(self.seq_indices.to(device)).unsqueeze(0)\n",
" output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()\n",
" return output_embeds"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"task_to_keys = {\n",
" \"cola\": (\"sentence\", None),\n",
" \"mnli\": (\"premise\", \"hypothesis\"),\n",
" \"mnli-mm\": (\"premise\", \"hypothesis\"),\n",
" \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
" \"qnli\": (\"question\", \"sentence\"),\n",
" \"qqp\": (\"question1\", \"question2\"),\n",
" \"rte\": (\"sentence1\", \"sentence2\"),\n",
" \"sst2\": (\"sentence\", None),\n",
" \"stsb\": (\"sentence1\", \"sentence2\"),\n",
" \"wnli\": (\"sentence1\", \"sentence2\"),\n",
"}\n",
"class ClassModel(nn.Module):\n",
" def __init__(self, num_labels, model_checkpoint, pseudo_token='[PROMPT]', template=(3, 3)):\n",
" nn.Module.__init__(self)\n",
" self.template = template\n",
" self.num_labels = num_labels\n",
" self.spell_length = sum(template)\n",
" self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
" self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
" num_labels=num_labels)\n",
" for param in self.back_bone.parameters():\n",
" param.requires_grad = False\n",
" self.embeddings = self.back_bone.get_input_embeddings()\n",
" \n",
" self.hidden_size = self.embeddings.embedding_dim\n",
" self.tokenizer.add_special_tokens({'additional_special_tokens': [pseudo_token]})\n",
" self.pseudo_token_id = self.tokenizer.get_vocab()[pseudo_token]\n",
" self.pad_token_id = self.tokenizer.pad_token_id\n",
" \n",
" self.prompt_encoder = PromptEncoder(self.template, self.hidden_size)\n",
"\n",
" self.loss_fn = nn.CrossEntropyLoss()\n",
"\n",
" def get_query(self, query):\n",
" device = query.device\n",
" return torch.cat([torch.tensor([self.tokenizer.cls_token_id]).to(device), # [CLS]\n",
" torch.tensor([self.pseudo_token_id] * self.template[0]).to(device), # [PROMPT]\n",
" torch.tensor([self.tokenizer.mask_token_id]).to(device), # [MASK] \n",
" torch.tensor([self.pseudo_token_id] * self.template[1]).to(device), # [PROMPT]\n",
" query, \n",
" torch.tensor([self.tokenizer.sep_token_id]).to(device)], dim=0) # [SEP]\n",
"\n",
" def forward(self, input_ids):\n",
" input_ids = torch.stack([self.get_query(input_ids[i]) for i in range(len(input_ids))])\n",
" attention_mask = input_ids != self.pad_token_id\n",
" \n",
" bz = input_ids.shape[0]\n",
" inputs_embeds = input_ids.clone()\n",
" inputs_embeds[(input_ids == self.pseudo_token_id)] = self.tokenizer.unk_token_id\n",
" inputs_embeds = self.embeddings(inputs_embeds)\n",
"\n",
" blocked_indices = (input_ids == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1] # bz\n",
" replace_embeds = self.prompt_encoder(input_ids.device)\n",
" for bidx in range(bz):\n",
" for i in range(self.spell_length):\n",
" inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :]\n",
" \n",
" return self.back_bone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)\n",
"\n", "\n",
"sentence1_key, sentence2_key = task_to_keys[task]"
" def train_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids).logits\n",
" return {\"loss\": self.loss_fn(pred, labels)}\n",
"\n",
" def evaluate_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids).logits\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": labels}"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight']\n",
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Sentence: hide new secretions from the parental units \n"
"init prompt encoder...\n"
] ]
} }
], ],
"source": [ "source": [
"if sentence2_key is None:\n",
" print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n",
"else:\n",
" print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n",
" print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")"
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n",
"\n",
"model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-4)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7,
"metadata": {},
"execution_count": 6,
"metadata": {
"scrolled": false
},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n",
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n",
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n"
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
] ]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f82d2ccee863492582f94552654482f9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
} }
], ],
"source": [ "source": [
"def preprocess_function(examples):\n",
" if sentence2_key is None:\n",
" return tokenizer(examples[sentence1_key], truncation=True)\n",
" return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n",
"\n",
"encoded_dataset = dataset.map(preprocess_function, batched=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class ClassModel(nn.Module):\n",
" def __init__(self, num_labels, model_checkpoint):\n",
" nn.Module.__init__(self)\n",
" self.num_labels = num_labels\n",
" self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
" num_labels=num_labels)\n",
" self.loss_fn = nn.CrossEntropyLoss()\n",
"\n",
" def forward(self, input_ids, attention_mask):\n",
" return self.back_bone(input_ids, attention_mask)\n",
"\n",
" def train_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids, attention_mask).logits\n",
" return {\"loss\": self.loss_fn(pred, labels)}\n",
"from datasets import load_dataset, load_metric\n",
"\n", "\n",
" def evaluate_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids, attention_mask).logits\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": labels}"
"dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']\n",
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf324902e7b94ea9be709b979b425c96",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/68 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "21eb6203ec6f4592b8cb8530a59eda49",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "05b83c4b1a9f44aea805788e1e52db78",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
} }
], ],
"source": [ "source": [
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n",
"\n",
"model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
"def preprocess_function(examples):\n",
" return model.tokenizer(examples['sentence'], truncation=True)\n",
"\n", "\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-5)"
"encoded_dataset = dataset.map(preprocess_function, batched=True)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -261,7 +309,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -287,7 +335,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -301,7 +349,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13,
"execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -319,54 +367,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# help(model.back_bone.forward)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/html": [ "text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[21:00:11] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
], ],
"text/plain": [
"\n"
]
"text/plain": []
}, },
"metadata": {}, "metadata": {},
"output_type": "display_data" "output_type": "display_data"
@@ -374,12 +383,9 @@
{ {
"data": { "data": {
"text/html": [ "text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
], ],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
"text/plain": []
}, },
"metadata": {}, "metadata": {},
"output_type": "display_data" "output_type": "display_data"
@@ -387,33 +393,32 @@
{ {
"data": { "data": {
"text/html": [ "text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.871875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">279.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n" "</pre>\n"
], ],
"text/plain": [ "text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
"\n"
] ]
}, },
"metadata": {}, "metadata": {},
"output_type": "display_data" "output_type": "display_data"
},
}
],
"source": [
"trainer.run(num_eval_batch_per_dl=10)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{ {
"data": { "data": {
"text/html": [ "text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
], ],
"text/plain": [
"\n"
]
"text/plain": []
}, },
"metadata": {}, "metadata": {},
"output_type": "display_data" "output_type": "display_data"
@@ -421,439 +426,26 @@
{ {
"data": { "data": {
"text/html": [ "text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
], ],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
"text/plain": []
}, },
"metadata": {}, "metadata": {},
"output_type": "display_data" "output_type": "display_data"
}, },
{ {
"data": { "data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.878125</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">281.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [ "text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
"{'acc#acc': 0.565367, 'total#acc': 872.0, 'correct#acc': 493.0}"
] ]
}, },
"execution_count": 13,
"metadata": {}, "metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.871875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">279.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.903125</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">289.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.871875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">279.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.890625</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">285.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">280.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">284.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">284.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.890625</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">285.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
"trainer.run(num_eval_batch_per_dl=10)"
"trainer.evaluator.run()"
] ]
}, },
{ {
@@ -881,6 +473,15 @@
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.13" "version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
} }
}, },
"nbformat": 4, "nbformat": 4,

