You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

example_init.py 840 B

1234567891011121314151617181920212223242526
  1. import os
  2. import joblib
  3. import numpy as np
  4. from learnware.model import BaseModel
  5. from .model import ConvModel
  6. import torch
  7. class Model(BaseModel):
  8. def __init__(self):
  9. super().__init__(input_shape=(3, 32, 32), output_shape=(10,))
  10. dir_path = os.path.dirname(os.path.abspath(__file__))
  11. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  12. self.model = ConvModel(channel=3, n_random_features=10).to(self.device)
  13. self.model.load_state_dict(torch.load(os.path.join(dir_path, "conv_model.pth")))
  14. self.model.eval()
  15. def fit(self, X: np.ndarray, y: np.ndarray):
  16. pass
  17. def predict(self, X: np.ndarray) -> np.ndarray:
  18. X = torch.Tensor(X).to(self.device)
  19. return self.model(X)
  20. def finetune(self, X: np.ndarray, y: np.ndarray):
  21. pass