Browse Source

[ENH] change kmeans to fast_pytorch_kmeans

tags/v0.3.2
GeneLiuXe 2 years ago
parent
commit
29bd9dcc5c
2 changed files with 11 additions and 9 deletions
  1. +10
    -9
      learnware/specification/regular/table/rkme.py
  2. +1
    -0
      setup.py

+ 10
- 9
learnware/specification/regular/table/rkme.py View File

@@ -1,17 +1,15 @@
from __future__ import annotations from __future__ import annotations


import os import os
import copy
import torch import torch
import json import json
import codecs import codecs
import random
import scipy
import numpy as np import numpy as np
from qpsolvers import solve_qp, Problem, solve_problem from qpsolvers import solve_qp, Problem, solve_problem
from collections import Counter from collections import Counter
from typing import Tuple, Any, List, Union, Dict
import scipy
from sklearn.cluster import MiniBatchKMeans
from typing import Any, Union
from fast_pytorch_kmeans import KMeans


from ..base import RegularStatSpecification from ..base import RegularStatSpecification
from ....logger import get_module_logger from ....logger import get_module_logger
@@ -140,11 +138,14 @@ class RKMETableSpecification(RegularStatSpecification):
K : int K : int
Size of the construced reduced set. Size of the construced reduced set.
""" """
X = X.astype("float32")
kmeans = MiniBatchKMeans(n_clusters=K, max_iter=100, verbose=False, n_init="auto")
if isinstance(X, np.ndarray):
X = X.astype("float32")
X = torch.from_numpy(X)
X = X.to(self._device)
kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0)
kmeans.fit(X) kmeans.fit(X)
center = torch.from_numpy(kmeans.cluster_centers_).double()
self.z = center
self.z = kmeans.centroids.double()


def _update_beta(self, X: Any, nonnegative_beta: bool = True): def _update_beta(self, X: Any, nonnegative_beta: bool = True):
"""Fix Z and update beta using its closed-form solution. """Fix Z and update beta using its closed-form solution.


+ 1
- 0
setup.py View File

@@ -70,6 +70,7 @@ REQUIRED = [
"portalocker>=2.0.0", "portalocker>=2.0.0",
"qpsolvers[clarabel]>=4.0.1", "qpsolvers[clarabel]>=4.0.1",
"geatpy>=2.7.0;python_version<'3.11'", "geatpy>=2.7.0;python_version<'3.11'",
"fast_pytorch_kmeans>=0.2.0.1",
] ]


here = os.path.abspath(os.path.dirname(__file__)) here = os.path.abspath(os.path.dirname(__file__))


Loading…
Cancel
Save