|
- import os
- import joblib
- import numpy as np
- from learnware.model import BaseModel
- from .model import ConvModel
- import torch
-
-
- class Model(BaseModel):
- def __init__(self):
- super().__init__(input_shape=(3, 32, 32), output_shape=(10,))
- dir_path = os.path.dirname(os.path.abspath(__file__))
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- self.model = ConvModel(channel=3, n_random_features=10).to(self.device)
- self.model.load_state_dict(torch.load(os.path.join(dir_path, "conv_model.pth")))
- self.model.eval()
-
- def fit(self, X: np.ndarray, y: np.ndarray):
- pass
-
- def predict(self, X: np.ndarray) -> np.ndarray:
- X = torch.Tensor(X).to(self.device)
- return self.model(X)
-
- def finetune(self, X: np.ndarray, y: np.ndarray):
- pass
|