You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

interface.py 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import pandas as pd
  16. import numpy as np
  17. import xgboost
  18. from sklearn.model_selection import train_test_split
  19. from sklearn.metrics import precision_score
  20. os.environ['BACKEND_TYPE'] = 'SKLEARN'
  21. DATACONF = {
  22. "ATTRIBUTES": ["Season", "Cooling startegy_building level"],
  23. "LABEL": "Thermal preference",
  24. }
  25. def feature_process(df: pd.DataFrame):
  26. if "City" in df.columns:
  27. df.drop(["City"], axis=1, inplace=True)
  28. for feature in df.columns:
  29. if feature in ["Season", ]:
  30. continue
  31. df[feature] = df[feature].apply(lambda x: float(x) if x else 0.0)
  32. df['Thermal preference'] = df['Thermal preference'].apply(
  33. lambda x: int(float(x)) if x else 1)
  34. return df
  35. class Estimator:
  36. def __init__(self):
  37. """Model init"""
  38. self.model = xgboost.XGBClassifier(
  39. learning_rate=0.1,
  40. n_estimators=600,
  41. max_depth=2,
  42. min_child_weight=1,
  43. gamma=0,
  44. subsample=0.8,
  45. colsample_bytree=0.8,
  46. objective="multi:softmax",
  47. num_class=3,
  48. nthread=4,
  49. seed=27)
  50. def train(self, train_data, valid_data=None,
  51. save_best=True,
  52. metric_name="mlogloss",
  53. early_stopping_rounds=100
  54. ):
  55. es = [
  56. xgboost.callback.EarlyStopping(
  57. metric_name=metric_name,
  58. rounds=early_stopping_rounds,
  59. save_best=save_best
  60. )
  61. ]
  62. x, y = train_data.x, train_data.y
  63. if valid_data:
  64. x1, y1 = valid_data.x, valid_data.y
  65. else:
  66. x, x1, y, y1 = train_test_split(
  67. x, y, test_size=0.1, random_state=42)
  68. history = self.model.fit(x, y, eval_set=[(x1, y1), ], callbacks=es)
  69. d = {}
  70. for k, v in history.evals_result().items():
  71. for k1, v1, in v.items():
  72. m = np.mean(v1)
  73. if k1 not in d:
  74. d[k1] = []
  75. d[k1].append(m)
  76. for k, v in d.items():
  77. d[k] = np.mean(v)
  78. return d
  79. def predict(self, datas, **kwargs):
  80. """ Model inference """
  81. return self.model.predict(datas)
  82. def predict_proba(self, datas, **kwargs):
  83. return self.model.predict_proba(datas)
  84. def evaluate(self, test_data, **kwargs):
  85. """ Model evaluate """
  86. y_pred = self.predict(test_data.x)
  87. return precision_score(test_data.y, y_pred, average="micro")
  88. def load(self, model_url):
  89. self.model.load_model(model_url)
  90. return self
  91. def save(self, model_path=None):
  92. """
  93. save model as a single pb file from checkpoint
  94. """
  95. return self.model.save_model(model_path)
  96. if __name__ == '__main__':
  97. from sedna.datasources import CSVDataParse
  98. from sedna.common.config import BaseConfig
  99. train_dataset_url = BaseConfig.train_dataset_url
  100. train_data = CSVDataParse(data_type="train", func=feature_process)
  101. train_data.parse(train_dataset_url, label=DATACONF["LABEL"])
  102. test_dataset_url = BaseConfig.test_dataset_url
  103. valid_data = CSVDataParse(data_type="valid", func=feature_process)
  104. valid_data.parse(test_dataset_url, label=DATACONF["LABEL"])
  105. model = Estimator()
  106. print(model.train(train_data))
  107. print(model.evaluate(test_data=valid_data))