143 lines
5.4 KiB
143 lines
5.4 KiB
import numpy as np
import matplotlib.pyplot as plt
import random
from typing import List, Optional, Tuple, Dict
from itertools import combinations
class Node:
    """Represents a node in the graph.

    Attributes:
        node_id: Integer identifier of the node.
    """

    def __init__(self, node_id: int) -> None:
        """Store the identifier of this node."""
        self.node_id = node_id

    def __repr__(self) -> str:
        """Return a debug representation such as ``Node(3)``."""
        return f"Node({self.node_id})"
class Graph:
    """Undirected graph with integer-labelled nodes and seeded random edges.

    Attributes:
        seed: Seed used to pick the edge set reproducibly.
        num_nodes: Number of nodes; node ids are ``0 .. num_nodes - 1``.
        edge_ratio: Requested number of edges per node.
        nodes: One ``Node`` object per node id.
        edges: List of ``(a, b)`` tuples with ``a < b``.
    """

    def __init__(self, seed: int, num_nodes: int, edge_ratio: Optional[float] = 1.5):
        """Create the nodes and generate the edge set.

        Args:
            seed: Seed for the edge selection.
            num_nodes: Number of nodes to create.
            edge_ratio: Edges per node; ``None`` falls back to 1.5.
        """
        self.seed = seed
        self.num_nodes = num_nodes
        # The signature allows None; treat it as "use the default ratio"
        # instead of failing later on ``None * num_nodes``.
        self.edge_ratio = 1.5 if edge_ratio is None else edge_ratio
        self.nodes = [Node(i) for i in range(num_nodes)]
        self.edges: List[Tuple[int, int]] = []
        self._generate_graph()

    def _generate_graph(self) -> None:
        """Fill ``self.edges`` with a seeded random sample of node pairs.

        The number of edges is ``int(edge_ratio * num_nodes)``, clamped to the
        range ``0 .. number of possible pairs``.
        """
        # List of all possible edges
        possible_edges = list(combinations(range(self.num_nodes), 2))
        wanted = int(self.edge_ratio * self.num_nodes)
        num_edges = max(0, min(wanted, len(possible_edges)))
        # A private generator keeps the result reproducible for a given seed
        # without touching the state of the global ``random`` module.
        rng = random.Random(self.seed)
        self.edges = rng.sample(possible_edges, num_edges)

    def get_neighbors(self, node_id: int) -> List[int]:
        """Returns a list of neighbors for a given node.

        Neighbors are listed in the order their edges appear in ``self.edges``.
        """
        neighbors = []
        for first, second in self.edges:
            if first == node_id:
                neighbors.append(second)
            elif second == node_id:
                neighbors.append(first)
        return neighbors

    def __repr__(self) -> str:
        """Return a summary with the node and edge counts."""
        return f"Graph(num_nodes={self.num_nodes}, num_edges={len(self.edges)})"
class Cluster:
    """Cluster graph nodes by random-walk visit counts and flag critical nodes.

    Every node is described by how often seeded random walks started from it
    visit each node of the graph. k-means on those rows gives ``clusters``.
    A node is listed in ``critical_nodes`` when deleting all of its edges
    changes how the remaining nodes are partitioned.

    The graph object only needs ``seed``, ``num_nodes``, ``edges`` and
    ``get_neighbors``.

    Attributes:
        graph: The graph being clustered.
        num_clusters: Number of k-means clusters.
        num_walks: Random walks started from every node.
        walk_length: Maximum number of nodes in one walk.
        clusters: Array with one cluster label per node id.
        critical_nodes: List of node ids found to be critical.
    """

    def __init__(self, graph: "Graph", num_clusters: int, num_walks: int = 100, walk_length: int = 10):
        """Cluster ``graph`` and determine its critical nodes."""
        self.graph = graph
        self.num_clusters = num_clusters
        self.num_walks = num_walks
        self.walk_length = walk_length
        self.clusters = self._cluster_nodes()
        self.critical_nodes = self._find_critical_nodes()

    @staticmethod
    def _build_adjacency(num_nodes: int, edges) -> Dict[int, List[int]]:
        """Map each node id to its neighbors, in edge-list order."""
        adjacency: Dict[int, List[int]] = {node: [] for node in range(num_nodes)}
        for first, second in edges:
            adjacency[first].append(second)
            adjacency[second].append(first)
        return adjacency

    @staticmethod
    def _canonical_labels(labels) -> List[int]:
        """Renumber labels by order of first appearance.

        k-means label ids are arbitrary, so two labelings describe the same
        partition exactly when their canonical forms are equal.
        """
        mapping: Dict[int, int] = {}
        return [mapping.setdefault(int(label), len(mapping)) for label in labels]

    def _random_walk(
        self,
        start_node: int,
        walk_length: int,
        adjacency: Optional[Dict[int, List[int]]] = None,
        rng: Optional[random.Random] = None,
    ) -> List[int]:
        """Walk from ``start_node`` along uniformly chosen neighbors.

        Args:
            start_node: First node of the walk.
            walk_length: Maximum number of nodes in the walk.
            adjacency: Optional neighbor map; ``self.graph.get_neighbors`` is
                used when omitted.
            rng: Optional random generator; the global ``random`` module is
                used when omitted.

        Returns:
            The visited node ids. The walk is shorter than ``walk_length`` if
            it reaches a node without neighbors.
        """
        chooser = random if rng is None else rng
        walk = [start_node]
        for _ in range(walk_length - 1):
            current = walk[-1]
            if adjacency is None:
                neighbors = self.graph.get_neighbors(current)
            else:
                neighbors = adjacency[current]
            if not neighbors:
                # Dead end: an isolated node cannot be left.
                break
            walk.append(chooser.choice(neighbors))
        return walk

    def _walk_counts(self, num_nodes: int, edges, seed) -> np.ndarray:
        """Count random-walk visits for the graph given by ``edges``.

        Entry ``[i, j]`` is the number of times walks started at ``i`` visited
        ``j`` (the start itself counts as a visit).
        """
        # Build the neighbor map once instead of scanning the edge list on
        # every step of every walk.
        adjacency = self._build_adjacency(num_nodes, edges)
        # Seeded generator: the same graph always yields the same matrix.
        rng = random.Random(seed)
        counts = np.zeros((num_nodes, num_nodes), dtype=int)
        for node in range(num_nodes):
            for _ in range(self.num_walks):
                for visited in self._random_walk(node, self.walk_length, adjacency, rng):
                    counts[node, visited] += 1
        return counts

    def _labels_for(self, num_nodes: int, edges, seed) -> np.ndarray:
        """Run k-means on the visit-count rows of the graph given by ``edges``."""
        from scipy.cluster.vq import kmeans2

        if num_nodes == 0:
            return np.zeros(0, dtype=int)
        features = self._walk_counts(num_nodes, edges, seed).astype(float)
        _, labels = kmeans2(features, self.num_clusters, minit="++", seed=seed)
        return labels

    def _random_walks_matrix(self) -> np.ndarray:
        """Return the visit-count matrix of ``self.graph``."""
        return self._walk_counts(self.graph.num_nodes, self.graph.edges, self.graph.seed)

    def _cluster_nodes(self) -> np.ndarray:
        """Return one cluster label per node of ``self.graph``."""
        return self._labels_for(self.graph.num_nodes, self.graph.edges, self.graph.seed)

    def _find_critical_nodes(self) -> List[int]:
        """Return the nodes whose removal changes the clustering of the rest.

        For each node, all of its edges are dropped and the graph is
        re-clustered. Labels are computed directly rather than by constructing
        another ``Cluster``, whose constructor would call this method again
        and recurse without bound. The removed node itself is excluded from
        the comparison, and partitions are compared up to label renaming.
        """
        num_nodes = self.graph.num_nodes
        baseline = np.asarray(self.clusters)
        node_ids = np.arange(num_nodes)
        critical_nodes = []
        for node in range(num_nodes):
            remaining_edges = [edge for edge in self.graph.edges if node not in edge]
            perturbed = self._labels_for(num_nodes, remaining_edges, self.graph.seed)
            others = node_ids != node
            before = self._canonical_labels(baseline[others])
            after = self._canonical_labels(perturbed[others])
            if before != after:
                critical_nodes.append(node)
        return critical_nodes

    def draw_graph_with_clusters(self):
        """Plot nodes coloured by cluster, edges, and critical nodes in black."""
        pos = self._spring_layout()
        plt.figure(figsize=(10, 7))
        unique_clusters = np.unique(self.clusters)
        colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_clusters)))
        for cluster, color in zip(unique_clusters, colors):
            members = [node for node, cluster_id in enumerate(self.clusters) if cluster_id == cluster]
            plt.scatter(
                [pos[node][0] for node in members],
                [pos[node][1] for node in members],
                color=color,
                s=100,
                label=f'Cluster {cluster}',
            )
        for first, second in self.graph.edges:
            xs = [pos[first][0], pos[second][0]]
            ys = [pos[first][1], pos[second][1]]
            plt.plot(xs, ys, color='gray', alpha=0.5)
        critical = list(self.critical_nodes)
        plt.scatter(
            [pos[node][0] for node in critical],
            [pos[node][1] for node in critical],
            color='black',
            s=200,
            label='Critical Nodes',
        )
        for node, (x, y) in pos.items():
            plt.text(x, y, str(node), fontsize=12, ha='center', va='center', color='white')
        plt.title("Graph Clustering and Critical Nodes")
        # The scatter calls set labels; without a legend they are never shown.
        plt.legend()
        plt.show()

    def _spring_layout(self) -> Dict[int, Tuple[float, float]]:
        """Return a position in ``[-1, 1] x [-1, 1]`` for every node.

        NOTE: despite the name, positions are drawn uniformly at random; no
        force-directed relaxation is performed. The graph seed is used so the
        same graph is always laid out the same way.
        """
        rng = random.Random(self.graph.seed)
        return {
            node: (rng.uniform(-1, 1), rng.uniform(-1, 1))
            for node in range(self.graph.num_nodes)
        }
# Example usage. Guarded so that importing this module does not run the
# clustering as a side effect.
if __name__ == "__main__":
    seed = 42
    num_nodes = 30
    edge_ratio = 2  # optional; Graph defaults to 1.5 when omitted
    graph = Graph(seed, num_nodes, edge_ratio)
    cluster = Cluster(graph, num_clusters=3)
    print("Klasteryzacja:", cluster.clusters)
    print("Krytyczne wierzchołki:", cluster.critical_nodes)