You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model0.py 485 B

12345678910111213141516
  1. from learnware.model import BaseModel
  2. import numpy as np
  3. import joblib
  4. import os
  5. class MyModel(BaseModel):
  6. def __init__(self):
  7. super(MyModel, self).__init__(input_shape=(20,), output_shape=(1,))
  8. dir_path = os.path.dirname(os.path.abspath(__file__))
  9. model_path = os.path.join(dir_path, "ridge.pkl")
  10. model = joblib.load(model_path)
  11. self.model = model
  12. def predict(self, X: np.ndarray) -> np.ndarray:
  13. return self.model.predict(X)