clean code
This commit is contained in:
parent
10c6742b86
commit
3e72d809ec
@ -0,0 +1,2 @@
|
|||||||
|
numpy==1.26.4
|
||||||
|
matplotlib==3.9.0
|
125
src/cluster.py
Normal file
125
src/cluster.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
from math import sqrt
|
||||||
|
from random import randrange
|
||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class Cluster:
    """One k-means cluster: a centroid plus the points currently assigned to it."""

    def __init__(self, centroid: Tuple[Any, ...]):
        """Create an empty cluster around ``centroid``.

        Args:
            centroid: Starting coordinates (any sequence of numbers).  Its
                length fixes the dimension of the cluster.

        Raises:
            ValueError: If the dimension is neither 2 nor a multiple of 3.
        """
        # Normalised to a tuple so that the equality test in
        # calc_new_centroid() compares like with like; a list never
        # compares equal to a tuple, even with identical coordinates.
        self.centroid = tuple(centroid)
        # Random RGB triple used when plotting the cluster in 2-D.
        self.color = np.random.rand(3,)
        # Points assigned to this cluster during the current pass.
        self.points = []
        self.dimension = len(self.centroid)
        if self.dimension != 2 and self.dimension % 3 != 0:
            raise ValueError('dimension must be 2 or a multiple of 3')

    def distance(self, point: Tuple[Any, ...]) -> float:
        """Return the Euclidean distance between the centroid and ``point``.

        Only the first ``len(self.centroid)`` coordinates of ``point`` are
        read; a shorter ``point`` raises ``IndexError``.
        """
        return sqrt(sum((coord - point[dim]) ** 2
                        for dim, coord in enumerate(self.centroid)))

    def append(self, point: Tuple[Any, ...]) -> None:
        """Assign ``point`` to this cluster."""
        self.points.append(point)

    def draw(self, draw_centroid: bool = True, *args) -> None:
        """Scatter-plot the assigned points; do nothing if there are none.

        Args:
            draw_centroid: For 2-D clusters, also plot the centroid with a
                black edge.  Not used for higher dimensions.
            *args: For dimensions that are a multiple of 3, the axes to draw
                on.  The i-th axes receives coordinates ``3*i .. 3*i + 2``
                of every point.
        """
        if not self.points:
            return
        if self.dimension == 2:
            # A bare RGB triple passed as ``c`` is ambiguous: with exactly
            # three points it would be read as values to colormap.  A 2-D
            # array with a single row always means "one colour for all".
            rgb = [self.color]
            plt.scatter(*zip(*self.points), c=rgb)
            if draw_centroid:
                plt.scatter(
                    self.centroid[0], self.centroid[1], c=rgb, edgecolors='k')
        elif self.dimension % 3 == 0:
            # Transpose the points into one list per coordinate.
            columns = [list(column) for column in zip(*self.points)]
            for i, ax in enumerate(args):
                ax.scatter(*columns[i * 3: (i + 1) * 3])

    def clear(self) -> None:
        """Forget every assigned point."""
        self.points = []

    def calc_new_centroid(self) -> bool:
        """Move the centroid to the mean of the assigned points.

        Returns:
            ``True`` if the centroid moved; in that case the assigned points
            are cleared as well.  ``False`` if the cluster has no points or
            the mean equals the current centroid; the points are then kept.
        """
        if not self.points:
            return False
        mean = np.average(np.array(self.points), axis=0)
        new_centroid = tuple(mean.tolist())
        if new_centroid == self.centroid:
            return False
        self.clear()
        self.centroid = new_centroid
        return True
|
||||||
|
|
||||||
|
|
||||||
|
def generate_clusters(number_of_clusters: int, number_of_dimensions: int, random_range=0) -> List["Cluster"]:
    """Create clusters with random integer centroids.

    Args:
        number_of_clusters: How many clusters to create.
        number_of_dimensions: Number of coordinates per centroid.
        random_range: Exclusive upper bound for every coordinate; each one is
            drawn with ``randrange(random_range)``.  The default of 0 is not
            a usable range: ``randrange(0)`` raises ``ValueError`` as soon as
            a coordinate is drawn, so callers must pass a positive value.

    Returns:
        A list of ``number_of_clusters`` new ``Cluster`` objects.
    """
    # Centroids are built as tuples (not lists) to match the type Cluster
    # declares and the tuple it later compares the centroid against.
    return [
        Cluster(tuple(randrange(random_range)
                      for _ in range(number_of_dimensions)))
        for _ in range(number_of_clusters)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def k_means(points: List[Any], clusters: List["Cluster"], max_iter: int = int(1e9)):
    """Run k-means in place on ``clusters`` until no centroid moves.

    Each pass assigns every point to its nearest cluster and then asks every
    cluster to recompute its centroid.  The loop stops after the first pass
    in which no cluster reports a change, or after ``max_iter`` passes.

    Args:
        points: The points to cluster.
        clusters: Objects providing ``distance(point)``, ``append(point)``,
            ``clear()`` and ``calc_new_centroid()``; they are mutated.
        max_iter: Upper bound on the number of passes.

    Raises:
        IndexError: If ``points`` is non-empty but ``clusters`` is empty.
    """
    for _ in range(int(max_iter)):
        # Start every pass from empty clusters.  A cluster whose centroid did
        # not move keeps its points, so without this reset those points would
        # be appended a second time whenever another cluster forced one more
        # pass, skewing later centroids.
        for cluster in clusters:
            cluster.clear()

        for point in points:
            nearest = 0
            # Infinity rather than a fixed bound, so points farther than any
            # arbitrary threshold still go to their true nearest cluster.
            best = float('inf')
            for i, cluster in enumerate(clusters):
                d = cluster.distance(point)
                if d < best:
                    best = d
                    nearest = i
            clusters[nearest].append(point)

        # Every cluster must be updated, so this is deliberately not any():
        # short-circuiting would skip the remaining clusters.
        progress = False
        for cluster in clusters:
            if cluster.calc_new_centroid():
                progress = True
        if not progress:
            break
|
||||||
|
|
||||||
|
# TODO: Delete the code below.
|
||||||
|
def load_file(filepath: str) -> List[List[int]]:
    """Read points from a text file of whitespace-separated integers.

    Each non-blank line becomes one point (a list of ints).  Blank lines are
    skipped; they would otherwise produce empty points with no coordinates.

    Args:
        filepath: Path of the file to read.

    Returns:
        The points in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field is not an integer literal.
    """
    points = []
    with open(filepath, 'r') as f:
        # Iterate the file directly instead of loading every line up front.
        for line in f:
            fields = line.split()
            if not fields:
                continue
            points.append([int(field) for field in fields])
    return points
|
||||||
|
|
||||||
|
|
||||||
|
def main(dimension: str):
    """Cluster one of the two bundled data sets and plot the result.

    Args:
        dimension: ``'2d'`` or ``'9d'`` (case-insensitive).  ``'2d'`` loads
            ``./s1.txt`` into 15 clusters; ``'9d'`` loads ``./breast.txt``
            into 2 clusters.

    Raises:
        ValueError: If ``dimension`` is neither ``'2d'`` nor ``'9d'``.
    """
    # mode -> (number of clusters, dimensions, centroid range, data file)
    presets = {
        '2d': (15, 2, int(1e6), "./s1.txt"),
        '9d': (2, 9, int(10), "./breast.txt"),
    }
    mode = dimension.lower()
    if mode not in presets:
        raise ValueError('dimension must be 2d or 9d')
    number_of_clusters, dim, random_range, data_path = presets[mode]
    points = load_file(data_path)

    clusters = generate_clusters(number_of_clusters, dim, random_range)

    k_means(points, clusters)

    if mode == '9d':
        # Nine coordinates are shown as three side-by-side 3-D plots.
        fig = plt.figure(figsize=(20, 10))
        axes = [fig.add_subplot(position, projection='3d')
                for position in (131, 132, 133)]
        for cluster in clusters:
            cluster.draw(False, *axes)
    else:
        for cluster in clusters:
            cluster.draw()
    plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Script entry point: run the 2-D variant of main().
    main('2d')
    # The 9-D variant is currently disabled.
    # main('9d')
|
@ -1,141 +0,0 @@
|
|||||||
import os
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
from math import sqrt
|
|
||||||
from random import randrange, random
|
|
||||||
|
|
||||||
WORKDIR = './lab1'
|
|
||||||
|
|
||||||
|
|
||||||
class Cluster:
    """A k-means cluster: a centroid, a plot colour and the assigned points."""

    def __init__(self, centroid, color):
        """Store the starting centroid and colour; start with no points."""
        self.centroid = centroid
        self.color = color
        self.points = []

    def distance(self, point):
        """Return the Euclidean distance from the centroid to ``point``."""
        squared = 0
        for axis, coord in enumerate(self.centroid):
            delta = coord - point[axis]
            squared += delta ** 2
        return sqrt(squared)

    def append(self, point):
        """Assign ``point`` to this cluster."""
        self.points.append(point)

    def draw(self, with_centroids=True, ax1=None, ax2=None, ax3=None):
        """Plot the assigned points; do nothing if there are none.

        Points with more than two coordinates are drawn as three 3-D
        scatters (coordinates 0-2 on ``ax1``, 3-5 on ``ax2``, 6-8 on
        ``ax3``).  Otherwise the points are drawn on the current pyplot
        axes, followed by the centroid when ``with_centroids`` is true.
        """
        if len(self.points) == 0:
            return
        if len(self.points[0]) > 2:
            # Transpose the points into one list per coordinate.
            columns = [list(column) for column in zip(*self.points)]
            for axes, first in ((ax1, 0), (ax2, 3), (ax3, 6)):
                axes.scatter(
                    columns[first], columns[first + 1], columns[first + 2])
            return
        plt.scatter(*zip(*self.points), c=self.color)
        if with_centroids:
            plt.scatter(
                self.centroid[0], self.centroid[1], c=self.color, edgecolors='k')

    def clear(self):
        """Forget every assigned point."""
        self.points = []

    def calc_new_centroid(self):
        """Move the centroid to the mean of the assigned points.

        Returns ``True`` (and clears the points) when the centroid changed,
        ``False`` when there are no points or the mean equals the centroid.
        """
        if len(self.points) == 0:
            return False
        mean = np.average(np.array(self.points), axis=0).tolist()
        unchanged = mean == self.centroid
        if unchanged:
            return False
        self.clear()
        self.centroid = mean
        return True
|
|
||||||
|
|
||||||
|
|
||||||
def generate_clusters(number_of_clusters: int, number_of_dimensions: int, random_range=0, with_centroids: bool = True):
    """Build ``number_of_clusters`` clusters, each with a random RGB colour.

    When ``with_centroids`` is true every centroid coordinate is drawn with
    ``randrange(random_range)``; otherwise every coordinate is ``-1``.
    """
    made = []
    for _ in range(number_of_clusters):
        if with_centroids:
            start = [randrange(random_range)
                     for _ in range(number_of_dimensions)]
        else:
            start = [-1] * number_of_dimensions
        # The colour is drawn after the centroid, one channel at a time.
        tint = [random() for _ in range(3)]
        made.append(Cluster(start, tint))
    return made
|
|
||||||
|
|
||||||
|
|
||||||
def k_mean(points, clusters):
    """Run k-means in place on ``clusters`` until no centroid moves.

    Each pass assigns every point to its nearest cluster and then asks every
    cluster to recompute its centroid; the loop ends after the first pass in
    which no cluster reports a change.

    Args:
        points: The points to cluster.
        clusters: Objects providing ``distance(point)``, ``append(point)``,
            ``clear()`` and ``calc_new_centroid()``; they are mutated.

    Raises:
        IndexError: If ``points`` is non-empty but ``clusters`` is empty.
    """
    while True:
        # Start every pass from empty clusters.  A cluster whose centroid did
        # not move keeps its points, so without this reset those points would
        # be appended again whenever another cluster forced one more pass.
        for cluster in clusters:
            cluster.clear()

        for point in points:
            nearest = 0
            # Infinity rather than a fixed bound, so points farther than any
            # arbitrary threshold still go to their true nearest cluster.
            best = float('inf')
            for i, cluster in enumerate(clusters):
                d = cluster.distance(point)
                if d < best:
                    best = d
                    nearest = i
            clusters[nearest].append(point)

        # Every cluster must be updated, so this is deliberately not any():
        # short-circuiting would skip the remaining clusters.
        progress = False
        for cluster in clusters:
            if cluster.calc_new_centroid():
                progress = True
        if not progress:
            break
|
|
||||||
|
|
||||||
|
|
||||||
def load_file(filepath: str) -> list:
    """Read points from a text file of whitespace-separated integers.

    Each non-blank line becomes one point (a list of ints).  Blank lines are
    skipped; they would otherwise produce empty points with no coordinates.

    Args:
        filepath: Path of the file to read.

    Returns:
        The points in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field is not an integer literal.
    """
    points = []
    with open(filepath, 'r') as f:
        # Iterate the file directly instead of loading every line up front.
        for line in f:
            fields = line.split()
            if not fields:
                continue
            points.append([int(field) for field in fields])
    return points
|
|
||||||
|
|
||||||
|
|
||||||
def draw_raw_points(points):
    """Scatter-plot raw, unclustered points when they are two-dimensional.

    Nothing is plotted when the first point has more than two coordinates.
    """
    # Dimensionality is judged from the first point only.
    if len(points[0]) > 2:
        # Plotting for more than two coordinates is not implemented.
        pass
    else:
        plt.scatter(*zip(*points))
        # NOTE(review): indentation was lost in this view; plt.show() is
        # assumed to belong to the 2-D branch - confirm against the file.
        plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def main(dimension: str):
    """Cluster one of the two bundled data sets and plot the result.

    Args:
        dimension: ``'2d'`` or ``'9d'`` (case-insensitive).  ``'2d'`` loads
            ``./s1.txt`` into 15 clusters; ``'9d'`` loads ``./breast.txt``
            into 2 clusters.

    Raises:
        ValueError: If ``dimension`` is neither ``'2d'`` nor ``'9d'``.
    """
    mode = dimension.lower()
    if mode == '2d':
        number_of_clusters = 15
        dim = 2
        random_range = int(1e6)
        points = load_file("./s1.txt")
    elif mode == '9d':
        number_of_clusters = 2
        dim = 9
        random_range = int(10)
        points = load_file("./breast.txt")
    else:
        raise ValueError('dimension must be 2d or 9d')

    clusters = generate_clusters(number_of_clusters, dim, random_range)

    k_mean(points, clusters)

    if mode == '9d':
        # One figure per coordinate triple.
        fig1 = plt.figure(1)
        ax1 = fig1.add_subplot(projection='3d')
        fig2 = plt.figure(2)
        ax2 = fig2.add_subplot(projection='3d')
        fig3 = plt.figure(3)
        ax3 = fig3.add_subplot(projection='3d')
        for cluster in clusters:
            # Pass the axes by keyword: draw()'s first positional parameter
            # is with_centroids, so positional axes would be shifted by one
            # and ax3 would stay None.
            cluster.draw(ax1=ax1, ax2=ax2, ax3=ax3)
    else:
        for cluster in clusters:
            cluster.draw()
    plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Switch to the lab directory first: main() opens its data files through
    # relative paths ("./s1.txt", "./breast.txt").
    os.chdir(WORKDIR)
    # Run both variants, one after the other.
    main('2d')
    main('9d')
|
|
Loading…
Reference in New Issue
Block a user