| @@ -15,25 +15,11 @@ from typing import Tuple, Any, List, Union, Dict | |||
| import scipy | |||
| from sklearn.cluster import MiniBatchKMeans | |||
| # try: | |||
| # import faiss | |||
| # ver = faiss.__version__ | |||
| # _FAISS_INSTALLED = ver >= "1.7.1" | |||
| # except ImportError: | |||
| # _FAISS_INSTALLED = False | |||
| from ..base import RegularStatsSpecification | |||
| from ....logger import get_module_logger | |||
| logger = get_module_logger("rkme") | |||
| # if not _FAISS_INSTALLED: | |||
| # logger.warning( | |||
| # "Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first" | |||
| # ) | |||
| class RKMETableSpecification(RegularStatsSpecification): | |||
| """Reduced Kernel Mean Embedding (RKME) Specification""" | |||
| @@ -128,7 +114,7 @@ class RKMETableSpecification(RegularStatsSpecification): | |||
| self.beta = torch.from_numpy(self.beta).double().to(self.device) | |||
| return | |||
| # Initialize Z by clustering, utiliing kmeans or faiss to speed up the process. | |||
| # Initialize Z by clustering, utiliing kmeans to speed up the process. | |||
| self._init_z_by_kmeans(X, K) | |||
| self._update_beta(X, nonnegative_beta) | |||
| @@ -140,23 +126,6 @@ class RKMETableSpecification(RegularStatsSpecification): | |||
| # Reshape to original dimensions | |||
| self.z = self.z.reshape(Z_shape) | |||
| # def _init_z_by_faiss(self, X: Union[np.ndarray, torch.tensor], K: int): | |||
| # """Intialize Z by faiss clustering. | |||
| # Parameters | |||
| # ---------- | |||
| # X : np.ndarray or torch.tensor | |||
| # Raw data in np.ndarray format or torch.tensor format. | |||
| # K : int | |||
| # Size of the construced reduced set. | |||
| # """ | |||
| # X = X.astype("float32") | |||
| # numDim = X.shape[1] | |||
| # kmeans = faiss.Kmeans(numDim, K, niter=100, verbose=False) | |||
| # kmeans.train(X) | |||
| # center = torch.from_numpy(kmeans.centroids).double() | |||
| # self.z = center | |||
| def _init_z_by_kmeans(self, X: Union[np.ndarray, torch.tensor], K: int): | |||
| """Intialize Z by kmeans clustering. | |||
| @@ -590,25 +559,8 @@ def rkme_solve_qp(K: np.ndarray, C: np.ndarray): | |||
| A = np.array(np.ones((1, n))) | |||
| A = scipy.sparse.csc_matrix(A) | |||
| b = np.array(np.ones((1, 1))) | |||
| # sol = solve_qp(P, q, G, h, A, b, solver="clarabel") # Requires the sum of x to be 1 | |||
| # sol = solver_qp(P, q, G, h, solver="clarabel") # Otherwise | |||
| problem = Problem(P, q, G, h, A, b) | |||
| solution = solve_problem(problem, solver="clarabel") | |||
| w = solution.x | |||
| w = torch.from_numpy(w).reshape(-1) | |||
| return w, solution.obj | |||
| # from cvxopt import solvers, matrix | |||
| # n = K.shape[0] | |||
| # P = matrix(K) | |||
| # q = matrix(-C) | |||
| # G = matrix(-np.eye(n)) | |||
| # h = matrix(np.zeros((n, 1))) | |||
| # A = matrix(np.ones((1, n))) | |||
| # b = matrix(np.ones((1, 1))) | |||
| # solvers.options["show_progress"] = False | |||
| # sol = solvers.qp(P, q, G, h, A, b) # Requires the sum of x to be 1 | |||
| # # sol = solvers.qp(P, q, G, h) # Otherwise | |||
| # w = np.array(sol["x"]) | |||
| # w = torch.from_numpy(w).reshape(-1) | |||
| # return w, sol["primal objective"] | |||