From b06d1771db6f57c2873d47433f2b41906d8cb55f Mon Sep 17 00:00:00 2001 From: mszmyd Date: Tue, 4 Jun 2024 20:39:28 +0200 Subject: [PATCH] add cluster --- src/cluster_class.py | 141 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 src/cluster_class.py diff --git a/src/cluster_class.py b/src/cluster_class.py new file mode 100644 index 0000000..d2af583 --- /dev/null +++ b/src/cluster_class.py @@ -0,0 +1,141 @@ +import os +import matplotlib.pyplot as plt +import numpy as np +from math import sqrt +from random import randrange, random + +WORKDIR = './lab1' + + +class Cluster: + def __init__(self, centroid, color): + self.centroid = centroid + self.color = color + self.points = list() + + def distance(self, point): + sum = 0 + for dim, p in enumerate(self.centroid): + sum += (p - point[dim]) ** 2 + return sqrt(sum) + + def append(self, point): + self.points.append(point) + + def draw(self, with_centroids=True, ax1=None, ax2=None, ax3=None): + if len(self.points) != 0: + if len(self.points[0]) > 2: + unzipped = list(map(list, zip(*self.points))) + ax1.scatter(unzipped[0], unzipped[1], unzipped[2]) + ax2.scatter(unzipped[3], unzipped[4], unzipped[5]) + ax3.scatter(unzipped[6], unzipped[7], unzipped[8]) + else: + plt.scatter(*zip(*self.points), c=self.color) + if with_centroids: + plt.scatter( + self.centroid[0], self.centroid[1], c=self.color, edgecolors='k') + + def clear(self): + self.points = [] + + def calc_new_centroid(self): + if len(self.points) == 0: + return False + arr = np.array(self.points) + new_centroid = np.average(arr, axis=0).tolist() + if new_centroid == self.centroid: + return False + else: + self.clear() + self.centroid = new_centroid + return True + + +def generate_clusters(number_of_clusters: int, number_of_dimensions: int, random_range=0, with_centroids: bool = True): + clusters = list() + if with_centroids: + for _ in range(number_of_clusters): + clusters.append(Cluster([randrange(random_range) + for i in range(number_of_dimensions)], [random(), random(), random()])) + else: + for _ in range(number_of_clusters): + clusters.append(Cluster([-1 + for i in range(number_of_dimensions)], [random(), random(), random()])) + return clusters + + +def k_mean(points, clusters): + while True: + for point in points: + ind = 0 + best = 1e7 + for i, cluster in enumerate(clusters): + d = cluster.distance(point) + if d < best: + best = d + ind = i + clusters[ind].append(point) + progress = False + for cluster in clusters: + res = cluster.calc_new_centroid() + if res: + progress = True + if not progress: + break + + +def load_file(filepath: str): + points = list() + with open(filepath, 'r') as f: + lines = f.readlines() + for line in lines: + point = list(map(int, line.split())) + points.append(point) + return points + + +def draw_raw_points(points): + if len(points[0]) > 2: + pass + else: + plt.scatter(*zip(*points)) + plt.show() + + +def main(dimension: str): + if dimension.lower() == '2d': + number_of_clusters = 15 + dim = 2 + random_range = int(1e6) + points = load_file("./s1.txt") + elif dimension.lower() == '9d': + number_of_clusters = 2 + dim = 9 + random_range = int(10) + points = load_file("./breast.txt") + else: + raise ValueError('dimension must be 2d or 9d') + + clusters = generate_clusters(number_of_clusters, dim, random_range) + + k_mean(points, clusters) + + if dimension.lower() == '9d': + fig1 = plt.figure(1) + ax1 = fig1.add_subplot(projection='3d') + fig2 = plt.figure(2) + ax2 = fig2.add_subplot(projection='3d') + fig3 = plt.figure(3) + ax3 = fig3.add_subplot(projection='3d') + for cluster in clusters: + cluster.draw(ax1, ax2, ax3) + else: + for cluster in clusters: + cluster.draw() + plt.show() + + +if __name__ == '__main__': + os.chdir(WORKDIR) + main('2d') + main('9d')