You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example_init.py 801 B

1234567891011121314151617181920212223242526272829
  1. import os
  2. import pickle
  3. import numpy as np
  4. from learnware.model import BaseModel
  5. class Model(BaseModel):
  6. def __init__(self):
  7. super(Model, self).__init__(input_shape=(1,), output_shape=(1,))
  8. dir_path = os.path.dirname(os.path.abspath(__file__))
  9. modelv_path = os.path.join(dir_path, "modelv.pth")
  10. with open(modelv_path, "rb") as f:
  11. self.modelv = pickle.load(f)
  12. modell_path = os.path.join(dir_path, "modell.pth")
  13. with open(modell_path, "rb") as f:
  14. self.modell = pickle.load(f)
  15. def fit(self, X: np.ndarray, y: np.ndarray):
  16. pass
  17. def predict(self, X: np.ndarray) -> np.ndarray:
  18. return self.modell.predict(self.modelv.transform(X))
  19. def finetune(self, X: np.ndarray, y: np.ndarray):
  20. pass