diff --git a/clustering/cluster.py b/clustering/cluster.py
index c5b894a..dcc10ab 100644
--- a/clustering/cluster.py
+++ b/clustering/cluster.py
@@ -54,7 +54,7 @@ class Cluster:
         self.centroid = new_centroid
         return True
 
-
+# TODO: Delete code below after moving to kmeans.py
 def generate_clusters(number_of_clusters: int, number_of_dimensions: int, random_range=0) -> List[Cluster]:
     return [Cluster([randrange(random_range) for _ in range(number_of_dimensions)]) for _ in range(number_of_clusters)]
 
@@ -78,7 +78,6 @@ def k_means(points: Tuple[Any], clusters: List[Cluster], max_iter: int = 1e9):
         if not progress:
             break
 
-# TODO: Delete below code
 def load_file(filepath: str):
     points = list()
     with open(filepath, 'r') as f:
diff --git a/clustering/kmeans.py b/clustering/kmeans.py
index e69de29..3ce5183 100644
--- a/clustering/kmeans.py
+++ b/clustering/kmeans.py
@@ -0,0 +1,39 @@
+from typing import List
+
+import numpy as np
+from cluster import Cluster
+
+
+class KMeans:
+    def __init__(self, n_clusters: int, random_state: int = 0) -> None:
+        self.n_clusters = n_clusters
+        self.random_state = random_state
+        self.clusters: List[Cluster] = list()
+        self.labels: np.ndarray = None
+
+    def fit(self, X: np.ndarray) -> None:
+        self.__generate_clusters(X)
+        # self.labels = np.zeros(X.shape[0])
+
+        # while True:
+        #     new_labels = np.array([np.argmin(np.linalg.norm(X - centroid, axis=1)) for centroid in self.centroids])
+        #     if np.array_equal(self.labels, new_labels):
+        #         break
+        #     self.labels = new_labels
+
+        #     for i in range(self.n_clusters):
+        #         self.centroids[i] = np.mean(X[self.labels == i], axis=0)
+
+    def __generate_clusters(self, X: np.ndarray) -> None:
+        generator = np.random.default_rng(self.random_state)
+        min_val, max_val = np.min(X), np.max(X)
+        self.clusters = [Cluster(
+            generator.uniform(min_val, max_val, X.shape[1])
+        ) for _ in range(self.n_clusters)]
+
+
+# TODO: Delete code below
+kmeans = KMeans(2)
+kmeans.fit(np.array([[1, 2], [3, 4], [5, 6], [7, 8]]))
+for cluster in kmeans.clusters:
+    print(cluster.centroid)
\ No newline at end of file