|
|
|
@@ -1,17 +1,15 @@ |
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
import os |
|
|
|
import copy |
|
|
|
import torch |
|
|
|
import json |
|
|
|
import codecs |
|
|
|
import random |
|
|
|
import scipy |
|
|
|
import numpy as np |
|
|
|
from qpsolvers import solve_qp, Problem, solve_problem |
|
|
|
from collections import Counter |
|
|
|
from typing import Tuple, Any, List, Union, Dict |
|
|
|
import scipy |
|
|
|
from sklearn.cluster import MiniBatchKMeans |
|
|
|
from typing import Any, Union |
|
|
|
from fast_pytorch_kmeans import KMeans |
|
|
|
|
|
|
|
from ..base import RegularStatSpecification |
|
|
|
from ....logger import get_module_logger |
|
|
|
@@ -140,11 +138,14 @@ class RKMETableSpecification(RegularStatSpecification): |
|
|
|
K : int |
|
|
|
Size of the construced reduced set. |
|
|
|
""" |
|
|
|
X = X.astype("float32") |
|
|
|
kmeans = MiniBatchKMeans(n_clusters=K, max_iter=100, verbose=False, n_init="auto") |
|
|
|
if isinstance(X, np.ndarray): |
|
|
|
X = X.astype("float32") |
|
|
|
X = torch.from_numpy(X) |
|
|
|
|
|
|
|
X = X.to(self._device) |
|
|
|
kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0) |
|
|
|
kmeans.fit(X) |
|
|
|
center = torch.from_numpy(kmeans.cluster_centers_).double() |
|
|
|
self.z = center |
|
|
|
self.z = kmeans.centroids.double() |
|
|
|
|
|
|
|
def _update_beta(self, X: Any, nonnegative_beta: bool = True): |
|
|
|
"""Fix Z and update beta using its closed-form solution. |
|
|
|
|