Browse Source

[MNT] remove redundant codes

tags/v0.3.2
nju-xy 2 years ago
parent
commit
bc98d0a5dc
1 changed files with 1 additions and 49 deletions
  1. +1
    -49
      learnware/specification/regular/table/rkme.py

+ 1
- 49
learnware/specification/regular/table/rkme.py View File

@@ -15,25 +15,11 @@ from typing import Tuple, Any, List, Union, Dict
import scipy
from sklearn.cluster import MiniBatchKMeans

# try:
# import faiss

# ver = faiss.__version__
# _FAISS_INSTALLED = ver >= "1.7.1"
# except ImportError:
# _FAISS_INSTALLED = False

from ..base import RegularStatsSpecification
from ....logger import get_module_logger

logger = get_module_logger("rkme")

# if not _FAISS_INSTALLED:
# logger.warning(
# "Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first"
# )


class RKMETableSpecification(RegularStatsSpecification):
"""Reduced Kernel Mean Embedding (RKME) Specification"""

@@ -128,7 +114,7 @@ class RKMETableSpecification(RegularStatsSpecification):
self.beta = torch.from_numpy(self.beta).double().to(self.device)
return

# Initialize Z by clustering, utiliing kmeans or faiss to speed up the process.
# Initialize Z by clustering, utiliing kmeans to speed up the process.
self._init_z_by_kmeans(X, K)
self._update_beta(X, nonnegative_beta)

@@ -140,23 +126,6 @@ class RKMETableSpecification(RegularStatsSpecification):
# Reshape to original dimensions
self.z = self.z.reshape(Z_shape)

# def _init_z_by_faiss(self, X: Union[np.ndarray, torch.tensor], K: int):
# """Intialize Z by faiss clustering.

# Parameters
# ----------
# X : np.ndarray or torch.tensor
# Raw data in np.ndarray format or torch.tensor format.
# K : int
# Size of the construced reduced set.
# """
# X = X.astype("float32")
# numDim = X.shape[1]
# kmeans = faiss.Kmeans(numDim, K, niter=100, verbose=False)
# kmeans.train(X)
# center = torch.from_numpy(kmeans.centroids).double()
# self.z = center

def _init_z_by_kmeans(self, X: Union[np.ndarray, torch.tensor], K: int):
"""Intialize Z by kmeans clustering.

@@ -590,25 +559,8 @@ def rkme_solve_qp(K: np.ndarray, C: np.ndarray):
A = np.array(np.ones((1, n)))
A = scipy.sparse.csc_matrix(A)
b = np.array(np.ones((1, 1)))
# sol = solve_qp(P, q, G, h, A, b, solver="clarabel") # Requires the sum of x to be 1
# sol = solver_qp(P, q, G, h, solver="clarabel") # Otherwise
problem = Problem(P, q, G, h, A, b)
solution = solve_problem(problem, solver="clarabel")
w = solution.x
w = torch.from_numpy(w).reshape(-1)
return w, solution.obj

# from cvxopt import solvers, matrix
# n = K.shape[0]
# P = matrix(K)
# q = matrix(-C)
# G = matrix(-np.eye(n))
# h = matrix(np.zeros((n, 1)))
# A = matrix(np.ones((1, n)))
# b = matrix(np.ones((1, 1)))
# solvers.options["show_progress"] = False
# sol = solvers.qp(P, q, G, h, A, b) # Requires the sum of x to be 1
# # sol = solvers.qp(P, q, G, h) # Otherwise
# w = np.array(sol["x"])
# w = torch.from_numpy(w).reshape(-1)
# return w, sol["primal objective"]

Loading…
Cancel
Save