| @@ -15,25 +15,11 @@ from typing import Tuple, Any, List, Union, Dict | |||||
| import scipy | import scipy | ||||
| from sklearn.cluster import MiniBatchKMeans | from sklearn.cluster import MiniBatchKMeans | ||||
| # try: | |||||
| # import faiss | |||||
| # ver = faiss.__version__ | |||||
| # _FAISS_INSTALLED = ver >= "1.7.1" | |||||
| # except ImportError: | |||||
| # _FAISS_INSTALLED = False | |||||
| from ..base import RegularStatsSpecification | from ..base import RegularStatsSpecification | ||||
| from ....logger import get_module_logger | from ....logger import get_module_logger | ||||
| logger = get_module_logger("rkme") | logger = get_module_logger("rkme") | ||||
| # if not _FAISS_INSTALLED: | |||||
| # logger.warning( | |||||
| # "Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first" | |||||
| # ) | |||||
| class RKMETableSpecification(RegularStatsSpecification): | class RKMETableSpecification(RegularStatsSpecification): | ||||
| """Reduced Kernel Mean Embedding (RKME) Specification""" | """Reduced Kernel Mean Embedding (RKME) Specification""" | ||||
| @@ -128,7 +114,7 @@ class RKMETableSpecification(RegularStatsSpecification): | |||||
| self.beta = torch.from_numpy(self.beta).double().to(self.device) | self.beta = torch.from_numpy(self.beta).double().to(self.device) | ||||
| return | return | ||||
| # Initialize Z by clustering, utiliing kmeans or faiss to speed up the process. | |||||
| # Initialize Z by clustering, utiliing kmeans to speed up the process. | |||||
| self._init_z_by_kmeans(X, K) | self._init_z_by_kmeans(X, K) | ||||
| self._update_beta(X, nonnegative_beta) | self._update_beta(X, nonnegative_beta) | ||||
| @@ -140,23 +126,6 @@ class RKMETableSpecification(RegularStatsSpecification): | |||||
| # Reshape to original dimensions | # Reshape to original dimensions | ||||
| self.z = self.z.reshape(Z_shape) | self.z = self.z.reshape(Z_shape) | ||||
| # def _init_z_by_faiss(self, X: Union[np.ndarray, torch.tensor], K: int): | |||||
| # """Intialize Z by faiss clustering. | |||||
| # Parameters | |||||
| # ---------- | |||||
| # X : np.ndarray or torch.tensor | |||||
| # Raw data in np.ndarray format or torch.tensor format. | |||||
| # K : int | |||||
| # Size of the construced reduced set. | |||||
| # """ | |||||
| # X = X.astype("float32") | |||||
| # numDim = X.shape[1] | |||||
| # kmeans = faiss.Kmeans(numDim, K, niter=100, verbose=False) | |||||
| # kmeans.train(X) | |||||
| # center = torch.from_numpy(kmeans.centroids).double() | |||||
| # self.z = center | |||||
| def _init_z_by_kmeans(self, X: Union[np.ndarray, torch.tensor], K: int): | def _init_z_by_kmeans(self, X: Union[np.ndarray, torch.tensor], K: int): | ||||
| """Intialize Z by kmeans clustering. | """Intialize Z by kmeans clustering. | ||||
| @@ -590,25 +559,8 @@ def rkme_solve_qp(K: np.ndarray, C: np.ndarray): | |||||
| A = np.array(np.ones((1, n))) | A = np.array(np.ones((1, n))) | ||||
| A = scipy.sparse.csc_matrix(A) | A = scipy.sparse.csc_matrix(A) | ||||
| b = np.array(np.ones((1, 1))) | b = np.array(np.ones((1, 1))) | ||||
| # sol = solve_qp(P, q, G, h, A, b, solver="clarabel") # Requires the sum of x to be 1 | |||||
| # sol = solver_qp(P, q, G, h, solver="clarabel") # Otherwise | |||||
| problem = Problem(P, q, G, h, A, b) | problem = Problem(P, q, G, h, A, b) | ||||
| solution = solve_problem(problem, solver="clarabel") | solution = solve_problem(problem, solver="clarabel") | ||||
| w = solution.x | w = solution.x | ||||
| w = torch.from_numpy(w).reshape(-1) | w = torch.from_numpy(w).reshape(-1) | ||||
| return w, solution.obj | return w, solution.obj | ||||
| # from cvxopt import solvers, matrix | |||||
| # n = K.shape[0] | |||||
| # P = matrix(K) | |||||
| # q = matrix(-C) | |||||
| # G = matrix(-np.eye(n)) | |||||
| # h = matrix(np.zeros((n, 1))) | |||||
| # A = matrix(np.ones((1, n))) | |||||
| # b = matrix(np.ones((1, 1))) | |||||
| # solvers.options["show_progress"] = False | |||||
| # sol = solvers.qp(P, q, G, h, A, b) # Requires the sum of x to be 1 | |||||
| # # sol = solvers.qp(P, q, G, h) # Otherwise | |||||
| # w = np.array(sol["x"]) | |||||
| # w = torch.from_numpy(w).reshape(-1) | |||||
| # return w, sol["primal objective"] | |||||