matma-proj-11/clustering/kmeans.py

39 lines
1.3 KiB
Python
Raw Permalink Normal View History

2024-06-05 01:38:52 +02:00
from typing import List, Optional

import numpy as np

from cluster import Cluster
class KMeans:
    """K-means clustering model (centroid initialisation only, so far).

    ``fit`` seeds ``n_clusters`` random centroids inside the bounding box of
    the data. The iterative assignment/update loop is not implemented yet,
    so ``labels`` is still ``None`` after fitting.
    """

    def __init__(self, n_clusters: int, random_state: int = 0) -> None:
        """Store the configuration; no computation happens here.

        Args:
            n_clusters: Number of clusters (centroids) to create; must be
                at least 1.
            random_state: Seed for the NumPy random generator that places
                the initial centroids, so repeated fits are reproducible.

        Raises:
            ValueError: If ``n_clusters`` is smaller than 1.
        """
        if n_clusters < 1:
            raise ValueError(
                f"n_clusters must be at least 1, got {n_clusters}"
            )
        self.n_clusters = n_clusters
        self.random_state = random_state
        # One Cluster per centroid; (re)built on every call to fit().
        self.clusters: List[Cluster] = []
        # Reserved for per-sample cluster indices; nothing assigns it yet.
        self.labels: Optional[np.ndarray] = None

    def fit(self, X: np.ndarray) -> None:
        """Initialise the clusters from the data matrix ``X``.

        Args:
            X: Data of shape ``(n_samples, n_features)``. Anything that
                ``np.asarray`` accepts (e.g. a nested list) is allowed.

        Raises:
            ValueError: If ``X`` is not two-dimensional or holds no values.
        """
        X = np.asarray(X)
        # Fail early with a clear message instead of an IndexError on
        # X.shape[1] or NumPy's "zero-size array" reduction error.
        if X.ndim != 2 or X.size == 0:
            raise ValueError(
                "X must be a non-empty 2-D array of shape "
                f"(n_samples, n_features), got shape {X.shape}"
            )
        self.__generate_clusters(X)
        # TODO: alternate the assignment step (nearest centroid per sample)
        # and the update step (centroid = mean of its samples) until the
        # labels stop changing, then store them in self.labels.

    def __generate_clusters(self, X: np.ndarray) -> None:
        """Replace ``self.clusters`` with ``n_clusters`` randomly seeded ones.

        Each centroid is drawn uniformly from the axis-aligned bounding box
        of ``X``. The bounds are taken per feature (column) rather than as
        one global min/max, so a centroid stays inside the observed range of
        every feature even when the features have very different scales.

        Args:
            X: Non-empty 2-D data array of shape ``(n_samples, n_features)``.
        """
        generator = np.random.default_rng(self.random_state)
        lower, upper = X.min(axis=0), X.max(axis=0)
        # low/high broadcast element-wise, giving one coordinate per feature.
        self.clusters = [
            Cluster(generator.uniform(lower, upper))
            for _ in range(self.n_clusters)
        ]
# TODO: Delete code below
# Manual smoke test: fit two clusters on a tiny 4x2 data set and print the
# initial centroids. Guarded so that importing this module has no side
# effects; it only runs when the file is executed directly.
if __name__ == "__main__":
    kmeans = KMeans(2)
    kmeans.fit(np.array([[1, 2], [3, 4], [5, 6], [7, 8]]))
    for cluster in kmeans.clusters:
        print(cluster.centroid)