# File listing metadata: 143 lines, 5.4 KiB, Python.
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
import random
|
||
|
from typing import List, Optional, Tuple, Dict
|
||
|
from itertools import combinations
|
||
|
|
||
|
class Node:
    """A single graph vertex, identified by an integer id."""

    def __init__(self, node_id: int):
        # The id is the only state a node carries.
        self.node_id = node_id

    def __repr__(self):
        """Return a debug representation of the form ``Node(<id>)``."""
        return "Node({})".format(self.node_id)
|
||
|
|
||
|
class Graph:
    """Random undirected graph over the node ids ``0 .. num_nodes - 1``.

    Edges are stored as ``(a, b)`` tuples with ``a < b``. They are drawn by
    shuffling every possible node pair and keeping the first
    ``int(edge_ratio * num_nodes)`` of them, so the edge count is capped at
    the number of distinct pairs.

    Attributes:
        seed: Value used to seed the module-level ``random`` generator.
        num_nodes: Number of nodes in the graph.
        edge_ratio: Requested number of edges per node.
        nodes: One ``Node`` per node id.
        edges: List of ``(a, b)`` node-id pairs.
    """

    def __init__(self, seed: int, num_nodes: int, edge_ratio: Optional[float] = 1.5):
        """Create the nodes and generate the random edge set.

        Args:
            seed: Seed for the random edge selection.
            num_nodes: Number of nodes to create.
            edge_ratio: Edges per node; ``None`` selects the default of 1.5.

        Raises:
            ValueError: If ``edge_ratio`` is negative.
        """
        self.seed = seed
        self.num_nodes = num_nodes
        # The signature allows None, so map it to the default instead of
        # failing later on ``None * num_nodes``.
        self.edge_ratio = 1.5 if edge_ratio is None else edge_ratio
        self.nodes = [Node(i) for i in range(num_nodes)]
        self.edges = []
        self._generate_graph()

    def _generate_graph(self):
        """Populate ``self.edges`` with a seeded random selection of node pairs.

        Raises:
            ValueError: If ``self.edge_ratio`` is negative.
        """
        if self.edge_ratio < 0:
            # A negative count would turn the slice below into "all but the
            # last k pairs" instead of "no pairs".
            raise ValueError(
                f"edge_ratio must be non-negative, got {self.edge_ratio!r}"
            )

        # NOTE: this seeds the module-level generator on purpose; code that
        # uses ``random`` after building a graph gets a reproducible stream.
        random.seed(self.seed)
        num_edges = max(0, int(self.edge_ratio * self.num_nodes))

        # Every unordered pair of distinct nodes, in random order.
        possible_edges = list(combinations(range(self.num_nodes), 2))
        random.shuffle(possible_edges)

        # Slicing caps the request at the number of pairs that exist.
        self.edges = possible_edges[:num_edges]

    def get_neighbors(self, node_id: int) -> List[int]:
        """Return the ids adjacent to ``node_id``, in edge-list order."""
        return [
            second if first == node_id else first
            for first, second in self.edges
            if node_id in (first, second)
        ]

    def __repr__(self):
        """Return a summary with the node count and the edge count."""
        return f"Graph(num_nodes={self.num_nodes}, num_edges={len(self.edges)})"
|
||
|
|
||
|
class Cluster:
    """Cluster graph nodes by random-walk visit counts and find critical nodes.

    Each node is described by how often fixed-length random walks started at
    that node visit every node of the graph. The resulting rows are grouped
    with a small seeded k-means (implemented here with numpy). A node is
    reported as critical when removing all of its edges changes how the
    *other* nodes are grouped.

    Random walks and the k-means initialisation are both seeded from
    ``graph.seed``, so repeated runs on the same graph give the same result
    and the baseline/removal clusterings are compared like for like.

    Attributes:
        graph: Graph being clustered; ``num_nodes``, ``edges`` and ``seed``
            are read from it.
        num_clusters: Requested number of clusters.
        num_walks: Random walks started from every node.
        walk_length: Maximum number of nodes in one walk.
        clusters: numpy array holding one cluster label per node.
        critical_nodes: Node ids whose removal changes the clustering.
    """

    def __init__(self, graph: "Graph", num_clusters: int, num_walks: int = 100, walk_length: int = 10):
        """Cluster ``graph`` and determine its critical nodes.

        Raises:
            ValueError: If ``num_clusters`` is smaller than 1.
        """
        self.graph = graph
        self.num_clusters = num_clusters
        self.num_walks = num_walks
        self.walk_length = walk_length
        self.clusters = self._cluster_nodes()
        self.critical_nodes = self._find_critical_nodes()

    def _adjacency(self, edges) -> List[List[int]]:
        """Build per-node neighbour lists from ``edges`` (one pass over the edges)."""
        adjacency = [[] for _ in range(self.graph.num_nodes)]
        for first, second in edges:
            adjacency[first].append(second)
            if first != second:
                adjacency[second].append(first)
        return adjacency

    def _random_walk(self, start_node: int, walk_length: int, adjacency=None, rng=None) -> List[int]:
        """Return the nodes visited by one random walk from ``start_node``.

        Args:
            start_node: Node the walk starts at (always the first entry).
            walk_length: Maximum number of nodes in the walk.
            adjacency: Optional neighbour lists indexed by node id; when
                omitted, ``self.graph.get_neighbors`` is queried instead.
            rng: Optional ``random.Random`` instance; when omitted the
                module-level generator is used.

        Returns:
            The visited node ids; shorter than ``walk_length`` if the walk
            reaches a node without neighbours.
        """
        choose = random.choice if rng is None else rng.choice
        walk = [start_node]
        for _ in range(walk_length - 1):
            if adjacency is None:
                neighbors = self.graph.get_neighbors(walk[-1])
            else:
                neighbors = adjacency[walk[-1]]
            if not neighbors:
                break
            walk.append(choose(neighbors))
        return walk

    def _random_walks_matrix(self, edges=None) -> np.ndarray:
        """Count random-walk visits between all pairs of nodes.

        Args:
            edges: Edge list to walk on; defaults to ``self.graph.edges``.

        Returns:
            Integer array where entry ``[i, j]`` is the number of times walks
            started at node ``i`` visited node ``j``.
        """
        if edges is None:
            edges = self.graph.edges
        adjacency = self._adjacency(edges)
        # A private generator seeded per call keeps every matrix reproducible
        # and independent of whatever else consumed the global stream.
        rng = random.Random(self.graph.seed)

        size = self.graph.num_nodes
        visit_counts = np.zeros((size, size), dtype=int)
        for node in range(size):
            for _ in range(self.num_walks):
                for visited in self._random_walk(node, self.walk_length, adjacency, rng):
                    visit_counts[node, visited] += 1
        return visit_counts

    def _kmeans_labels(self, data: np.ndarray, max_iter: int = 100) -> np.ndarray:
        """Assign each row of ``data`` to a cluster with Lloyd's k-means.

        Initial centres are distinct rows drawn with a generator seeded from
        ``self.graph.seed``. The cluster count is capped at the number of
        rows. A cluster that loses all members keeps its previous centre.

        Raises:
            ValueError: If ``self.num_clusters`` is smaller than 1.
        """
        if self.num_clusters < 1:
            raise ValueError(
                f"num_clusters must be at least 1, got {self.num_clusters!r}"
            )
        points = np.asarray(data, dtype=float)
        count = points.shape[0]
        if count == 0:
            return np.zeros(0, dtype=int)

        k = min(self.num_clusters, count)
        rng = np.random.default_rng(self.graph.seed)
        centers = points[rng.choice(count, size=k, replace=False)]

        labels = np.full(count, -1)
        for _ in range(max_iter):
            # Squared distance from every point to every centre.
            distances = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
            new_labels = distances.argmin(axis=1)
            if np.array_equal(new_labels, labels):
                break
            labels = new_labels
            for index in range(k):
                members = points[labels == index]
                if len(members):
                    centers[index] = members.mean(axis=0)
        return labels

    def _cluster_nodes(self, edges=None) -> np.ndarray:
        """Return one cluster label per node for the given edge list.

        Args:
            edges: Edge list to cluster on; defaults to ``self.graph.edges``.
        """
        return self._kmeans_labels(self._random_walks_matrix(edges))

    @staticmethod
    def _canonical_labels(labels) -> List[int]:
        """Renumber labels by first appearance so equal partitions compare equal."""
        mapping = {}
        return [mapping.setdefault(int(label), len(mapping)) for label in labels]

    def _find_critical_nodes(self) -> List[int]:
        """Return the nodes whose removal changes the clustering of the others.

        For every node, its incident edges are dropped and the graph is
        clustered again. The removed node itself is ignored in the
        comparison, and partitions are compared up to relabelling because
        k-means label numbers are arbitrary.
        """
        size = self.graph.num_nodes
        baseline = self._cluster_nodes()
        critical_nodes = []

        for node in range(size):
            # Clustering works on a plain edge list, so no temporary graph
            # (and no nested Cluster, which would recurse forever) is needed.
            remaining_edges = [edge for edge in self.graph.edges if node not in edge]
            labels = self._cluster_nodes(remaining_edges)

            others = np.arange(size) != node
            before = self._canonical_labels(baseline[others])
            after = self._canonical_labels(labels[others])
            if before != after:
                critical_nodes.append(node)

        return critical_nodes

    def draw_graph_with_clusters(self):
        """Plot the graph with nodes coloured by cluster and critical nodes in black."""
        pos = self._spring_layout()
        plt.figure(figsize=(10, 7))

        unique_clusters = np.unique(self.clusters)
        colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_clusters)))

        for cluster, color in zip(unique_clusters, colors):
            nodes_in_cluster = [
                node for node, cluster_id in enumerate(self.clusters) if cluster_id == cluster
            ]
            plt.scatter(
                [pos[node][0] for node in nodes_in_cluster],
                [pos[node][1] for node in nodes_in_cluster],
                color=color,
                s=100,
                label=f'Cluster {cluster}',
            )

        for first, second in self.graph.edges:
            plt.plot(
                [pos[first][0], pos[second][0]],
                [pos[first][1], pos[second][1]],
                color='gray',
                alpha=0.5,
            )

        # Critical nodes are drawn last and larger so they cover the cluster marker.
        plt.scatter(
            [pos[node][0] for node in self.critical_nodes],
            [pos[node][1] for node in self.critical_nodes],
            color='black',
            s=200,
            label='Critical Nodes',
        )

        for node, (x, y) in pos.items():
            plt.text(x, y, str(node), fontsize=12, ha='center', va='center', color='white')

        plt.title("Graph Clustering and Critical Nodes")
        plt.legend()
        plt.show()

    def _spring_layout(self) -> Dict[int, Tuple[float, float]]:
        """Return a 2-D position for every node.

        NOTE: despite the name, no spring simulation is performed; each
        coordinate is drawn uniformly from [-1, 1] using the module-level
        ``random`` generator.
        """
        return {
            node: (random.uniform(-1, 1), random.uniform(-1, 1))
            for node in range(self.graph.num_nodes)
        }
|
||
|
|
||
|
# Example usage: build a seeded random graph, cluster it and plot the result.
seed = 42
num_nodes = 30
edge_ratio = 2  # optional; Graph falls back to its own default when omitted

graph = Graph(seed=seed, num_nodes=num_nodes, edge_ratio=edge_ratio)
cluster = Cluster(graph=graph, num_clusters=3)

# Report the per-node cluster labels and the detected critical nodes.
print("Klasteryzacja:", cluster.clusters)
print("Krytyczne wierzchołki:", cluster.critical_nodes)

# Opens a matplotlib window with the coloured graph.
cluster.draw_graph_with_clusters()
|