add cluster

This commit is contained in:
mszmyd 2024-06-04 20:39:28 +02:00
parent 901303c4d7
commit b06d1771db

141
src/cluster_class.py Normal file
View File

@ -0,0 +1,141 @@
import os
import matplotlib.pyplot as plt
import numpy as np
from math import sqrt
from random import randrange, random
# Directory the script changes into before running; the data files are then
# opened with paths relative to it.
WORKDIR = './lab1'
class Cluster:
    """A k-means cluster: a centroid, a display colour and assigned points.

    Attributes:
        centroid: Coordinates of the cluster centre, one value per dimension.
        color: Colour passed to matplotlib when the cluster is drawn.
        points: Points currently assigned to the cluster.
    """

    # Coordinate indices plotted on ax1, ax2 and ax3 for data with more than
    # two dimensions (three coordinates per 3D axes).
    _AXIS_DIMENSIONS = ((0, 1, 2), (3, 4, 5), (6, 7, 8))

    def __init__(self, centroid, color):
        self.centroid = centroid
        self.color = color
        self.points = []

    def distance(self, point):
        """Return the Euclidean distance between the centroid and *point*.

        Only the first ``len(self.centroid)`` coordinates of *point* are read.
        """
        total = 0
        for dim, coordinate in enumerate(self.centroid):
            total += (coordinate - point[dim]) ** 2
        return sqrt(total)

    def append(self, point):
        """Assign *point* to this cluster."""
        self.points.append(point)

    def draw(self, with_centroids=True, ax1=None, ax2=None, ax3=None):
        """Plot the cluster's points and, optionally, its centroid.

        Data with at most two coordinates is drawn on the current pyplot
        axes. Data with more coordinates is spread over three axes:
        coordinates 0-2 on *ax1*, 3-5 on *ax2* and 6-8 on *ax3*; the centroid
        is drawn with the same split so it lines up with the points.

        Args:
            with_centroids: Also draw the centroid (black-edged marker).
            ax1, ax2, ax3: Axes objects used for data with more than two
                coordinates; ignored otherwise.

        Raises:
            ValueError: If data with more than two coordinates has to be
                drawn and any of *ax1*, *ax2*, *ax3* is missing.
        """
        # An empty cluster has no points to inspect, so fall back to the
        # centroid to find out how many coordinates the data has.
        if self.points:
            dimensions = len(self.points[0])
        else:
            dimensions = len(self.centroid)

        if dimensions > 2:
            if not self.points and not with_centroids:
                return  # nothing to draw
            axes = (ax1, ax2, ax3)
            if any(ax is None for ax in axes):
                raise ValueError(
                    'ax1, ax2 and ax3 are required to draw data with more '
                    'than 2 dimensions')
            columns = list(zip(*self.points))
            for ax, (x, y, z) in zip(axes, self._AXIS_DIMENSIONS):
                if self.points:
                    # ``color=`` (not ``c=``) so a bare RGB list is never
                    # mistaken for per-point values to be colour-mapped.
                    ax.scatter(columns[x], columns[y], columns[z],
                               color=self.color)
                if with_centroids:
                    ax.scatter(
                        self.centroid[x], self.centroid[y], self.centroid[z],
                        color=self.color, edgecolors='k')
            return

        if self.points:
            plt.scatter(*zip(*self.points), color=self.color)
        if with_centroids:
            plt.scatter(
                self.centroid[0], self.centroid[1],
                color=self.color, edgecolors='k')

    def clear(self):
        """Drop every point assigned to this cluster."""
        self.points = []

    def calc_new_centroid(self):
        """Move the centroid to the mean of the assigned points.

        Returns:
            True if the centroid moved; the assigned points are then
            discarded so they can be reassigned. False if the cluster has no
            points or the centroid already equals their mean, in which case
            the points are kept.
        """
        if not self.points:
            return False
        new_centroid = np.average(np.array(self.points), axis=0).tolist()
        if new_centroid == self.centroid:
            return False
        self.clear()
        self.centroid = new_centroid
        return True
def generate_clusters(number_of_clusters: int, number_of_dimensions: int, random_range=0, with_centroids: bool = True):
    """Create *number_of_clusters* empty clusters with random colours.

    Args:
        number_of_clusters: How many clusters to create.
        number_of_dimensions: Number of coordinates in each centroid.
        random_range: Exclusive upper bound passed to ``randrange`` for each
            centroid coordinate; only used when *with_centroids* is true.
        with_centroids: When true, every centroid coordinate is a random
            integer below *random_range*; otherwise every coordinate is the
            placeholder ``-1``.

    Returns:
        A list of ``Cluster`` objects, each coloured with three ``random()``
        values.
    """
    def initial_centroid():
        # Built before the colour so the order of random draws is unchanged.
        if with_centroids:
            return [randrange(random_range)
                    for _ in range(number_of_dimensions)]
        return [-1 for _ in range(number_of_dimensions)]

    return [
        Cluster(initial_centroid(), [random(), random(), random()])
        for _ in range(number_of_clusters)
    ]
def k_mean(points, clusters):
    """Run k-means in place until no centroid moves.

    Each pass assigns every point to the cluster whose centroid is nearest
    (the first one wins a tie) and then asks every cluster to recompute its
    centroid. When the function returns, each cluster holds the points of
    the final assignment.

    Args:
        points: Sequence of points (sequences of coordinates).
        clusters: Objects providing ``distance(point)``, ``append(point)``,
            ``clear()`` and ``calc_new_centroid()``.

    Raises:
        ValueError: If *points* is non-empty but *clusters* is empty.
    """
    while True:
        # Start every pass from empty clusters. A cluster whose centroid did
        # not move keeps its points, so without this reset it would collect
        # every point again on the next pass and skew its mean.
        for cluster in clusters:
            cluster.clear()

        for point in points:
            # min() has no distance cap, unlike a fixed "best so far" seed.
            nearest = min(clusters, key=lambda c: c.distance(point))
            nearest.append(point)

        # Build the full list first: every centroid must be updated, so the
        # any() below must not short-circuit the updates.
        moved = [cluster.calc_new_centroid() for cluster in clusters]
        if not any(moved):
            break
def load_file(filepath: str):
    """Read whitespace-separated integer points from a text file.

    Each non-blank line becomes one point (a list of ints). Blank lines,
    including trailing ones, are skipped so they do not turn into empty
    points.

    Args:
        filepath: Path of the file to read.

    Returns:
        A list of points in file order.

    Raises:
        ValueError: If a value on a line is not an integer.
    """
    points = []
    with open(filepath, 'r') as f:
        for line in f:
            values = line.split()
            if values:
                points.append([int(value) for value in values])
    return points
def draw_raw_points(points):
    """Show a scatter plot of unclustered points.

    Only points with at most two coordinates are plotted; input with more
    coordinates is not drawn. An empty list is accepted and plots nothing.

    Args:
        points: List of points (sequences of coordinates).
    """
    # Check for an empty list before looking at the first point.
    if points and len(points[0]) <= 2:
        plt.scatter(*zip(*points))
    plt.show()
def main(dimension: str):
    """Cluster one of the two bundled data sets with k-means and plot it.

    Args:
        dimension: ``'2d'`` (15 clusters on ``./s1.txt``) or ``'9d'``
            (2 clusters on ``./breast.txt``); case-insensitive.

    Raises:
        ValueError: If *dimension* is neither ``'2d'`` nor ``'9d'``.
    """
    mode = dimension.lower()
    if mode == '2d':
        number_of_clusters = 15
        dim = 2
        random_range = int(1e6)
        points = load_file("./s1.txt")
    elif mode == '9d':
        number_of_clusters = 2
        dim = 9
        random_range = int(10)
        points = load_file("./breast.txt")
    else:
        raise ValueError('dimension must be 2d or 9d')

    clusters = generate_clusters(number_of_clusters, dim, random_range)
    k_mean(points, clusters)

    if mode == '9d':
        # One 3D figure per group of three coordinates.
        ax1 = plt.figure(1).add_subplot(projection='3d')
        ax2 = plt.figure(2).add_subplot(projection='3d')
        ax3 = plt.figure(3).add_subplot(projection='3d')
        for cluster in clusters:
            # The axes must be passed by keyword: the first positional
            # parameter of draw() is with_centroids, so positional axes
            # would be shifted by one and ax3 would stay None.
            cluster.draw(ax1=ax1, ax2=ax2, ax3=ax3)
    else:
        for cluster in clusters:
            cluster.draw()
    plt.show()
if __name__ == '__main__':
    # The data files are opened with relative paths, so switch to the
    # working directory before running either demo.
    os.chdir(WORKDIR)
    for dimension in ('2d', '9d'):
        main(dimension)