|
|
@@ -193,7 +193,7 @@ |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"data_instances = [torch.randn(1, 45, 45).to(device) for _ in range(32)]\n", |
|
|
|
"data_instances = [torch.randn(1, 45, 45) for _ in range(32)]\n", |
|
|
|
"pred_idx = base_model.predict(X=data_instances)\n", |
|
|
|
"print(\n", |
|
|
|
" f\"Predicted class index for a batch of 32 instances: \"\n", |
|
|
@@ -440,7 +440,7 @@ |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.8.18" |
|
|
|
"version": "3.8.13" |
|
|
|
}, |
|
|
|
"orig_nbformat": 4, |
|
|
|
"vscode": { |
|
|
|