init kmeans
This commit is contained in:
parent
e95992f0b5
commit
b6ef8ba0f7
@ -54,7 +54,7 @@ class Cluster:
|
||||
self.centroid = new_centroid
|
||||
return True
|
||||
|
||||
|
||||
# TODO: Delete code below after moving to kmeans.py
|
||||
def generate_clusters(number_of_clusters: int, number_of_dimensions: int, random_range=0) -> List[Cluster]:
|
||||
return [Cluster([randrange(random_range) for _ in range(number_of_dimensions)]) for _ in range(number_of_clusters)]
|
||||
|
||||
@ -78,7 +78,6 @@ def k_means(points: Tuple[Any], clusters: List[Cluster], max_iter: int = 1e9):
|
||||
if not progress:
|
||||
break
|
||||
|
||||
# TODO: Delete below code
|
||||
def load_file(filepath: str):
|
||||
points = list()
|
||||
with open(filepath, 'r') as f:
|
||||
|
@ -0,0 +1,39 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from cluster import Cluster
|
||||
|
||||
|
||||
class KMeans:
|
||||
def __init__(self, n_clusters: int, random_state: int = 0) -> None:
|
||||
self.n_clusters = n_clusters
|
||||
self.random_state = random_state
|
||||
self.clusters: List[Cluster] = list()
|
||||
self.labels: np.ndarray = None
|
||||
|
||||
def fit(self, X: np.ndarray) -> None:
|
||||
self.__generate_clusters(X)
|
||||
# self.labels = np.zeros(X.shape[0])
|
||||
|
||||
# while True:
|
||||
# new_labels = np.array([np.argmin(np.linalg.norm(X - centroid, axis=1)) for centroid in self.centroids])
|
||||
# if np.array_equal(self.labels, new_labels):
|
||||
# break
|
||||
# self.labels = new_labels
|
||||
|
||||
# for i in range(self.n_clusters):
|
||||
# self.centroids[i] = np.mean(X[self.labels == i], axis=0)
|
||||
|
||||
def __generate_clusters(self, X: np.ndarray) -> None:
|
||||
generator = np.random.default_rng(self.random_state)
|
||||
min_val, max_val = np.min(X), np.max(X)
|
||||
self.clusters = [Cluster(
|
||||
generator.uniform(min_val, max_val, X.shape[1])
|
||||
) for _ in range(self.n_clusters)]
|
||||
|
||||
|
||||
# TODO: Delete code below
|
||||
kmeans = KMeans(2)
|
||||
kmeans.fit(np.array([[1, 2], [3, 4], [5, 6], [7, 8]]))
|
||||
for cluster in kmeans.clusters:
|
||||
print(cluster.centroid)
|
Loading…
Reference in New Issue
Block a user