"""K-means clustering with an optional matplotlib visualisation.

The accompanying ``requirements.txt`` pins ``numpy==1.26.4`` and
``matplotlib==3.9.0``.  matplotlib is imported lazily, only by the code that
actually plots, so the clustering logic stays importable without it.
"""
from math import sqrt
from random import randrange
from typing import Any, List, Sequence, Tuple

import numpy as np


class Cluster:
    """A centroid together with the points currently assigned to it."""

    def __init__(self, centroid: Sequence[float]):
        """Create an empty cluster around ``centroid``.

        Args:
            centroid: Initial centroid coordinates; any sequence is accepted
                and stored as a tuple.

        Raises:
            ValueError: If the number of coordinates is neither 2 nor a
                multiple of 3 (the only layouts ``draw`` knows how to plot).
        """
        # Stored as a tuple so the equality test in ``calc_new_centroid``
        # (tuple == tuple) can succeed; a list never compares equal to a tuple.
        self.centroid: Tuple[float, ...] = tuple(centroid)
        self.color = np.random.rand(3,)  # random RGB triple used when plotting
        self.points: List[Sequence[float]] = []
        self.dimension = len(self.centroid)
        if self.dimension != 2 and self.dimension % 3 != 0:
            raise ValueError('dimension must be 2 or a multiple of 3')

    def distance(self, point: Sequence[float]) -> float:
        """Return the Euclidean distance from the centroid to ``point``.

        Only the first ``len(self.centroid)`` coordinates of ``point`` are
        read; a shorter ``point`` raises ``IndexError``.
        """
        squared = sum(
            (coord - point[axis]) ** 2
            for axis, coord in enumerate(self.centroid)
        )
        return sqrt(squared)

    def append(self, point: Sequence[float]) -> None:
        """Assign ``point`` to this cluster."""
        self.points.append(point)

    def draw(self, draw_centroid: bool = True, *args: Any) -> None:
        """Plot the assigned points; does nothing when the cluster is empty.

        Args:
            draw_centroid: For 2-D clusters, also mark the centroid with a
                black-edged dot.  Not used for higher dimensions.
            *args: For clusters whose dimension is a multiple of 3, the axes
                objects to draw on.  The i-th axes receives coordinates
                ``3*i .. 3*i + 2`` of every point; axes beyond the available
                coordinate triples are ignored.
        """
        if not self.points:
            return
        if self.dimension == 2:
            import matplotlib.pyplot as plt  # lazy: needed only for plotting

            # ``color=`` rather than ``c=``: a bare RGB triple passed as ``c``
            # is indistinguishable from three values to be colormapped when
            # the cluster happens to hold exactly three points.
            plt.scatter(*zip(*self.points), color=self.color)
            if draw_centroid:
                plt.scatter(
                    self.centroid[0], self.centroid[1],
                    color=self.color, edgecolors='k')
        elif self.dimension % 3 == 0:
            columns = list(map(list, zip(*self.points)))
            for i, ax in enumerate(args[:len(columns) // 3]):
                ax.scatter(*columns[i * 3: (i + 1) * 3])

    def clear(self) -> None:
        """Drop every point assigned to this cluster."""
        self.points = []

    def calc_new_centroid(self) -> bool:
        """Move the centroid to the mean of the assigned points.

        Returns:
            ``True`` if the centroid moved, in which case the assigned points
            are also cleared.  ``False`` if the cluster is empty or the mean
            equals the current centroid; the points are then left in place.
        """
        if not self.points:
            return False
        mean = np.average(np.array(self.points), axis=0)
        new_centroid = tuple(mean.tolist())
        if new_centroid == self.centroid:
            return False
        self.clear()
        self.centroid = new_centroid
        return True


def generate_clusters(number_of_clusters: int, number_of_dimensions: int, random_range=0) -> List[Cluster]:
    """Create clusters with random integer centroids.

    Each coordinate is drawn with ``randrange(random_range)``, i.e. from
    ``0 .. random_range - 1``.

    Raises:
        ValueError: From ``randrange`` when ``random_range`` is not positive
            (this includes the default of 0) and a coordinate is drawn, or
            from ``Cluster`` when ``number_of_dimensions`` is unsupported.
    """
    return [
        Cluster(tuple(randrange(random_range)
                      for _ in range(number_of_dimensions)))
        for _ in range(number_of_clusters)
    ]


def _assign_points(points: Sequence[Sequence[float]], clusters: List[Cluster]) -> None:
    """Reset every cluster and give each point to its nearest centroid."""
    for cluster in clusters:
        cluster.clear()
    for point in points:
        # ``min`` keeps the first cluster on ties and needs no magic
        # "infinitely far" starting distance.
        nearest = min(clusters, key=lambda c: c.distance(point))
        nearest.append(point)


def k_means(points: Sequence[Sequence[float]], clusters: List[Cluster], max_iter: int = 10 ** 9):
    """Run k-means in place on ``clusters``.

    On return every cluster's ``centroid`` is its final position and its
    ``points`` are exactly the points nearest to that centroid.

    Args:
        points: The data set; it is iterated once per iteration.
        clusters: Clusters to refine; modified in place.
        max_iter: Upper bound on the number of iterations (converted with
            ``int``).

    Raises:
        ValueError: If ``clusters`` is empty.
    """
    if not clusters:
        raise ValueError('at least one cluster is required')
    for _ in range(int(max_iter)):
        # Start each pass from empty clusters: a cluster whose centroid did
        # not move keeps its points in ``calc_new_centroid``, and re-adding
        # on top of them would duplicate points and bias later means.
        _assign_points(points, clusters)
        # A list (not a generator) so every centroid is updated even after
        # the first one reports a change.
        moved = [cluster.calc_new_centroid() for cluster in clusters]
        if not any(moved):
            break
    else:
        # Iteration budget exhausted: clusters that just moved were emptied,
        # so assign once more to leave membership matching the centroids.
        _assign_points(points, clusters)


# TODO: Delete below code
def load_file(filepath: str):
    """Read whitespace-separated integer points, one point per line.

    Blank lines are skipped.

    Returns:
        A list of points, each a list of ``int`` coordinates.

    Raises:
        ValueError: If a token cannot be parsed as an integer.
    """
    points = []
    with open(filepath, 'r') as f:
        for line in f:
            fields = line.split()
            if fields:
                points.append([int(field) for field in fields])
    return points


def main(dimension: str):
    """Cluster one of the two bundled data sets and show the result.

    Args:
        dimension: ``'2d'`` (15 clusters over ``./s1.txt``) or ``'9d'``
            (2 clusters over ``./breast.txt``); case-insensitive.

    Raises:
        ValueError: For any other ``dimension``.
    """
    import matplotlib.pyplot as plt  # lazy: needed only for plotting

    mode = dimension.lower()
    if mode == '2d':
        number_of_clusters = 15
        dim = 2
        random_range = int(1e6)
        points = load_file("./s1.txt")
    elif mode == '9d':
        number_of_clusters = 2
        dim = 9
        random_range = 10
        points = load_file("./breast.txt")
    else:
        raise ValueError('dimension must be 2d or 9d')

    clusters = generate_clusters(number_of_clusters, dim, random_range)

    k_means(points, clusters)

    if mode == '9d':
        # Nine coordinates are shown as three side-by-side 3-D scatter plots.
        fig = plt.figure(figsize=(20, 10))
        ax1 = fig.add_subplot(131, projection='3d')
        ax2 = fig.add_subplot(132, projection='3d')
        ax3 = fig.add_subplot(133, projection='3d')
        for cluster in clusters:
            cluster.draw(False, ax1, ax2, ax3)
    else:
        for cluster in clusters:
            cluster.draw()
    plt.show()


if __name__ == '__main__':
    main('2d')
    # Use main('9d') to run on the 9-dimensional data set instead.