"""K-means clustering with optional matplotlib visualisation."""
from math import sqrt
from random import randrange
from typing import Any, List, Sequence, Tuple

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:
    # Plotting is optional: the clustering code below only needs numpy.
    plt = None


def _require_pyplot():
    """Return the pyplot module, or raise ImportError if it is unavailable."""
    if plt is None:
        raise ImportError('matplotlib is required for plotting')
    return plt


class Cluster:
    """A k-means cluster: a centroid, the points assigned to it and a colour.

    Attributes:
        centroid: Current centre of the cluster, stored as a tuple.
        color: Random RGB triple used when the cluster is drawn.
        points: Points assigned to the cluster during the current pass.
        dimension: Number of coordinates of the centroid.
    """

    def __init__(self, centroid: Sequence[Any]):
        """Create an empty cluster around ``centroid``.

        Raises:
            ValueError: If the centroid has neither 2 coordinates nor a
                multiple of 3 coordinates.
        """
        # Normalise to a tuple so the equality test in calc_new_centroid
        # yields a plain bool even when a numpy array is passed in.
        self.centroid = tuple(centroid)
        self.color = np.random.rand(3,)
        self.points = []
        self.dimension = len(self.centroid)
        if self.dimension != 2 and self.dimension % 3 != 0:
            raise ValueError('dimension must be 2 or a multiple of 3')

    def distance(self, point: Sequence[Any]) -> float:
        """Return the Euclidean distance from the centroid to ``point``."""
        squared = sum(
            (coord - point[axis]) ** 2
            for axis, coord in enumerate(self.centroid)
        )
        return sqrt(squared)

    def append(self, point: Sequence[Any]) -> None:
        """Assign ``point`` to this cluster."""
        self.points.append(point)

    def draw(self, draw_centroid: bool = True, *args) -> None:
        """Scatter-plot the cluster's points; do nothing if it is empty.

        Args:
            draw_centroid: For 2-D clusters, also draw the centroid with a
                black edge.
            *args: For clusters whose dimension is a multiple of 3, the axes
                to draw on; the i-th axes receives coordinates
                ``3*i .. 3*i + 2`` of every point.

        Raises:
            ImportError: If a 2-D cluster is drawn without matplotlib.
        """
        if not self.points:
            return
        if self.dimension == 2:
            pyplot = _require_pyplot()
            # ``color=`` rather than ``c=``: a bare RGB triple passed as ``c``
            # is colormapped as data when there are exactly 3 points.
            pyplot.scatter(*zip(*self.points), color=self.color)
            if draw_centroid:
                pyplot.scatter(self.centroid[0], self.centroid[1],
                               color=self.color, edgecolors='k')
        elif self.dimension % 3 == 0:
            columns = list(zip(*self.points))
            # Ignore surplus axes: there is one coordinate triple per 3 dims.
            for i, ax in enumerate(args[:self.dimension // 3]):
                ax.scatter(*columns[i * 3:(i + 1) * 3], color=self.color)

    def clear(self) -> None:
        """Drop every point assigned to this cluster."""
        self.points = []

    def calc_new_centroid(self) -> bool:
        """Move the centroid to the mean of the assigned points.

        Returns:
            True if the centroid moved (the assigned points are cleared in
            that case), False if the cluster is empty or the centroid is
            unchanged (the points are kept).
        """
        if not self.points:
            return False
        new_centroid = tuple(np.average(np.array(self.points), axis=0))
        # tuple() guards against a centroid that was reassigned to a list or
        # numpy array from outside the class.
        if new_centroid == tuple(self.centroid):
            return False
        self.clear()
        self.centroid = new_centroid
        return True


# TODO: Delete code below after moving to kmeans.py
def generate_clusters(number_of_clusters: int, number_of_dimensions: int,
                      random_range: int = 0) -> List[Cluster]:
    """Create clusters with random integer centroids.

    Each coordinate is drawn with ``randrange(random_range)``.

    Raises:
        ValueError: If ``random_range`` is not positive (this includes the
            default of 0, so callers must pass an explicit range).
    """
    if random_range <= 0:
        raise ValueError('random_range must be a positive integer')
    return [
        Cluster(tuple(randrange(random_range)
                      for _ in range(number_of_dimensions)))
        for _ in range(number_of_clusters)
    ]


def _assign_points(points: Sequence[Any], clusters: List[Cluster]) -> None:
    """Empty every cluster, then give each point to its nearest cluster."""
    for cluster in clusters:
        cluster.clear()
    for point in points:
        # min() keeps the first of several equally near clusters.
        nearest = min(clusters, key=lambda cluster: cluster.distance(point))
        nearest.append(point)


def k_means(points: Sequence[Any], clusters: List[Cluster],
            max_iter: int = 10 ** 9):
    """Run k-means in place on ``clusters``.

    Repeats assignment and centroid update until no centroid moves or
    ``max_iter`` passes have run. On return every point is assigned to
    exactly one cluster, the one nearest to it under the final centroids.

    Args:
        points: The points to cluster; iterated once per pass.
        clusters: Clusters to update; at least one is required.
        max_iter: Upper bound on the number of passes.

    Raises:
        ValueError: If ``clusters`` is empty.
    """
    if not clusters:
        raise ValueError('at least one cluster is required')
    for _ in range(int(max_iter)):
        # Clearing every cluster first keeps an unmoved cluster from
        # collecting the same points again on the next pass.
        _assign_points(points, clusters)
        # Build the full list: every centroid must be updated, so any() must
        # not short-circuit over a generator here.
        moved = [cluster.calc_new_centroid() for cluster in clusters]
        if not any(moved):
            break
    else:
        # Out of iterations: clusters that just moved were emptied, so
        # reassign the points against the final centroids.
        _assign_points(points, clusters)


def load_file(filepath: str) -> List[List[int]]:
    """Read whitespace-separated integer points, one point per line.

    Blank lines are skipped.
    """
    points = []
    with open(filepath, 'r') as f:
        for line in f:
            fields = line.split()
            if fields:
                points.append([int(field) for field in fields])
    return points


def main(dimension: str):
    """Cluster one of the bundled data sets and show the result.

    Args:
        dimension: ``'2d'`` or ``'9d'`` (case-insensitive).

    Raises:
        ValueError: If ``dimension`` is neither of the two.
    """
    mode = dimension.lower()
    if mode == '2d':
        number_of_clusters = 15
        dim = 2
        random_range = int(1e6)
        filepath = "./s1.txt"
    elif mode == '9d':
        number_of_clusters = 2
        dim = 9
        random_range = int(10)
        filepath = "./breast.txt"
    else:
        raise ValueError('dimension must be 2d or 9d')

    # Fail before the clustering work if the result cannot be shown.
    pyplot = _require_pyplot()

    points = load_file(filepath)
    clusters = generate_clusters(number_of_clusters, dim, random_range)
    k_means(points, clusters)

    if mode == '9d':
        # Three 3-D panels, one per coordinate triple of the 9-D points.
        fig = pyplot.figure(figsize=(20, 10))
        ax1 = fig.add_subplot(131, projection='3d')
        ax2 = fig.add_subplot(132, projection='3d')
        ax3 = fig.add_subplot(133, projection='3d')
        for cluster in clusters:
            cluster.draw(False, ax1, ax2, ax3)
    else:
        for cluster in clusters:
            cluster.draw()
    pyplot.show()


if __name__ == '__main__':
    main('2d')
    # main('9d')