Browse Source

[ENH] change kmeans to fast_pytorch_kmeans

tags/v0.3.2
GeneLiuXe 2 years ago
parent
commit
29bd9dcc5c
2 changed files with 11 additions and 9 deletions
  1. +10
    -9
      learnware/specification/regular/table/rkme.py
  2. +1
    -0
      setup.py

+ 10
- 9
learnware/specification/regular/table/rkme.py View File

@@ -1,17 +1,15 @@
from __future__ import annotations

import os
import copy
import torch
import json
import codecs
import random
import scipy
import numpy as np
from qpsolvers import solve_qp, Problem, solve_problem
from collections import Counter
from typing import Tuple, Any, List, Union, Dict
import scipy
from sklearn.cluster import MiniBatchKMeans
from typing import Any, Union
from fast_pytorch_kmeans import KMeans

from ..base import RegularStatSpecification
from ....logger import get_module_logger
@@ -140,11 +138,14 @@ class RKMETableSpecification(RegularStatSpecification):
K : int
Size of the construced reduced set.
"""
X = X.astype("float32")
kmeans = MiniBatchKMeans(n_clusters=K, max_iter=100, verbose=False, n_init="auto")
if isinstance(X, np.ndarray):
X = X.astype("float32")
X = torch.from_numpy(X)
X = X.to(self._device)
kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0)
kmeans.fit(X)
center = torch.from_numpy(kmeans.cluster_centers_).double()
self.z = center
self.z = kmeans.centroids.double()

def _update_beta(self, X: Any, nonnegative_beta: bool = True):
"""Fix Z and update beta using its closed-form solution.


+ 1
- 0
setup.py View File

@@ -70,6 +70,7 @@ REQUIRED = [
"portalocker>=2.0.0",
"qpsolvers[clarabel]>=4.0.1",
"geatpy>=2.7.0;python_version<'3.11'",
"fast_pytorch_kmeans>=0.2.0.1",
]

here = os.path.abspath(os.path.dirname(__file__))


Loading…
Cancel
Save