matma-proj-11/graph_new.py

143 lines
5.4 KiB
Python
Raw Permalink Normal View History

2024-06-04 20:57:08 +02:00
import numpy as np
import matplotlib.pyplot as plt
import random
from typing import List, Optional, Tuple, Dict
from itertools import combinations
class Node:
    """A single graph vertex, identified by an integer id."""

    def __init__(self, node_id: int):
        # Graph creates its nodes with ids 0 .. num_nodes - 1.
        self.node_id = node_id

    def __repr__(self):
        return "Node({})".format(self.node_id)
class Graph:
    """Random simple undirected graph with a seed-reproducible edge set.

    ``edges`` holds ``(u, v)`` tuples with ``u < v``, drawn without
    replacement from all node pairs.
    """

    DEFAULT_EDGE_RATIO = 1.5

    def __init__(self, seed: int, num_nodes: int, edge_ratio: Optional[float] = DEFAULT_EDGE_RATIO):
        """Create ``num_nodes`` nodes and about ``edge_ratio * num_nodes`` edges.

        Args:
            seed: Seed for the global ``random`` module used to pick edges.
            num_nodes: Number of nodes; ids are ``0 .. num_nodes - 1``.
            edge_ratio: Target number of edges per node.  ``None`` means
                ``DEFAULT_EDGE_RATIO``.  The edge count is capped at the
                number of node pairs (complete graph).

        Raises:
            ValueError: If ``edge_ratio`` is negative.
        """
        self.seed = seed
        self.num_nodes = num_nodes
        self.edge_ratio = edge_ratio
        self.nodes = [Node(i) for i in range(num_nodes)]
        self.edges: List[Tuple[int, int]] = []
        self._generate_graph()

    def _generate_graph(self):
        """Replace ``self.edges`` with a random subset of all node pairs."""
        # The signature advertises Optional; without this fallback None would
        # crash below with a TypeError.
        ratio = self.DEFAULT_EDGE_RATIO if self.edge_ratio is None else self.edge_ratio
        if ratio < 0:
            # A negative count would slice from the end and keep an arbitrary
            # number of edges instead of failing.
            raise ValueError(f"edge_ratio must be non-negative, got {ratio}")
        # Reseeds the *global* RNG, so the same seed always yields the same
        # edges and later draws from `random` start from a reproducible state.
        random.seed(self.seed)
        num_edges = int(ratio * self.num_nodes)
        possible_edges = list(combinations(range(self.num_nodes), 2))
        random.shuffle(possible_edges)
        # Slicing caps the count at len(possible_edges).
        self.edges = possible_edges[:num_edges]

    def get_neighbors(self, node_id: int) -> List[int]:
        """Return the nodes adjacent to ``node_id``, in edge-list order.

        Scans every edge (O(E)) and reads ``self.edges`` on each call, so the
        result reflects later reassignments of that attribute.
        """
        neighbors = []
        for u, v in self.edges:
            if u == node_id:
                neighbors.append(v)
            elif v == node_id:
                neighbors.append(u)
        return neighbors

    def __repr__(self):
        return f"Graph(num_nodes={self.num_nodes}, num_edges={len(self.edges)})"
class Cluster:
    """Clusters graph nodes by random-walk visit profiles and finds critical nodes.

    Each node gets a profile: how often every node is visited by ``num_walks``
    random walks of at most ``walk_length`` nodes started from it.  The
    profiles are grouped with k-means into ``num_clusters`` clusters;
    ``self.clusters`` holds one label per node.  A node is *critical* when
    deleting all of its edges changes how the *other* nodes are partitioned.
    """

    def __init__(self, graph: "Graph", num_clusters: int, num_walks: int = 100, walk_length: int = 10):
        """Cluster ``graph`` and compute its critical nodes.

        Args:
            graph: Graph exposing ``seed``, ``num_nodes`` and ``edges``.
            num_clusters: Number of k-means clusters, in ``1 .. num_nodes``.
            num_walks: Random walks started from each node.
            walk_length: Maximum number of nodes per walk (start included).

        Raises:
            ValueError: If ``num_clusters`` is outside ``1 .. graph.num_nodes``.
        """
        self.graph = graph
        self.num_clusters = num_clusters
        self.num_walks = num_walks
        self.walk_length = walk_length
        self.clusters = self._cluster_nodes()
        self.critical_nodes = self._find_critical_nodes()

    def _random_walk(self, start_node: int, walk_length: int,
                     rng: Optional[random.Random] = None,
                     adjacency: Optional[Dict[int, List[int]]] = None) -> List[int]:
        """Return a random walk of at most ``walk_length`` nodes from ``start_node``.

        The walk stops early at a node without neighbours.  ``rng`` defaults to
        the global ``random`` module.  ``adjacency`` (node -> neighbour list)
        defaults to calling ``self.graph.get_neighbors`` at every step.
        """
        chooser = random if rng is None else rng
        walk = [start_node]
        for _ in range(walk_length - 1):
            current = walk[-1]
            if adjacency is None:
                neighbors = self.graph.get_neighbors(current)
            else:
                neighbors = adjacency[current]
            if not neighbors:
                break
            walk.append(chooser.choice(neighbors))
        return walk

    @staticmethod
    def _adjacency(graph) -> Dict[int, List[int]]:
        """Map each node id to its neighbours, in the same order as Graph.get_neighbors."""
        adjacency: Dict[int, List[int]] = {node: [] for node in range(graph.num_nodes)}
        for u, v in graph.edges:
            adjacency[u].append(v)
            if u != v:
                adjacency[v].append(u)
        return adjacency

    def _random_walks_matrix(self, graph: Optional["Graph"] = None) -> np.ndarray:
        """Return the (num_nodes, num_nodes) visit-count matrix for ``graph``.

        Entry ``[i, j]`` counts visits to node ``j`` over all walks started at
        node ``i`` (the start node included).  ``graph`` defaults to
        ``self.graph``.
        """
        graph = self.graph if graph is None else graph
        n = graph.num_nodes
        # A fresh RNG seeded from the graph makes every matrix reproducible.
        # The baseline and each edge-removed variant in _find_critical_nodes
        # are therefore compared under identical randomness.
        rng = random.Random(graph.seed)
        # Built once: get_neighbors rescans all edges on every call.
        adjacency = self._adjacency(graph)
        visit_counts = np.zeros((n, n), dtype=int)
        for node in range(n):
            for _ in range(self.num_walks):
                for visited in self._random_walk(node, self.walk_length, rng, adjacency):
                    visit_counts[node, visited] += 1
        return visit_counts

    @staticmethod
    def _squared_distances(points: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Return the (len(points), len(centers)) matrix of squared Euclidean distances."""
        return ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)

    @staticmethod
    def _kmeans_plus_plus(points: np.ndarray, k: int, rng: np.random.Generator) -> np.ndarray:
        """Pick ``k`` initial centres: the first uniformly, the rest proportional to squared distance."""
        n = len(points)
        centers = [points[rng.integers(n)]]
        for _ in range(1, k):
            d2 = Cluster._squared_distances(points, np.array(centers)).min(axis=1)
            total = d2.sum()
            # Every point already coincides with a centre: fall back to uniform.
            idx = rng.integers(n) if total == 0 else rng.choice(n, p=d2 / total)
            centers.append(points[idx])
        return np.array(centers)

    @staticmethod
    def _kmeans(data: np.ndarray, k: int, seed: int, n_init: int = 10, max_iter: int = 300) -> np.ndarray:
        """Return k-means labels for the rows of ``data`` (deterministic for ``seed``).

        Lloyd's algorithm with k-means++ seeding.  It runs ``n_init`` restarts
        and keeps the labelling with the lowest inertia (sum of squared
        distances to the assigned centres).  An empty cluster keeps its
        previous centre.

        Raises:
            ValueError: If ``k`` is not in ``1 .. len(data)``.
        """
        points = np.asarray(data, dtype=float)
        n = len(points)
        if not 1 <= k <= n:
            raise ValueError(f"num_clusters must be in 1..{n}, got {k}")
        rng = np.random.default_rng(seed)
        best_labels, best_inertia = None, np.inf
        for _ in range(n_init):
            centers = Cluster._kmeans_plus_plus(points, k, rng)
            for _ in range(max_iter):
                labels = Cluster._squared_distances(points, centers).argmin(axis=1)
                new_centers = np.array([
                    points[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
                    for j in range(k)
                ])
                converged = np.allclose(new_centers, centers)
                centers = new_centers
                if converged:
                    break
            d2 = Cluster._squared_distances(points, centers)
            labels = d2.argmin(axis=1)
            inertia = d2[np.arange(n), labels].sum()
            if inertia < best_inertia:
                best_labels, best_inertia = labels, inertia
        return best_labels

    def _cluster_nodes(self, graph: Optional["Graph"] = None) -> np.ndarray:
        """Return one k-means cluster label per node of ``graph`` (default ``self.graph``)."""
        visit_matrix = self._random_walks_matrix(graph)
        return self._kmeans(visit_matrix, self.num_clusters, self.graph.seed)

    @staticmethod
    def _same_partition(a, b) -> bool:
        """Return True if label sequences ``a`` and ``b`` group items identically, ignoring label names."""
        forward: Dict = {}
        backward: Dict = {}
        for x, y in zip(a, b):
            if forward.setdefault(x, y) != y or backward.setdefault(y, x) != x:
                return False
        return True

    def _find_critical_nodes(self) -> List[int]:
        """Return the nodes whose edge removal changes how the remaining nodes are clustered.

        For each node, all of its edges are dropped from a copy of the graph
        and the copy is re-clustered.  The removed node is left out of the
        comparison, because its own profile changes by construction.  The
        comparison also ignores label numbering, since k-means labels are
        arbitrary.
        """
        import copy

        n = self.graph.num_nodes
        baseline = np.asarray(self.clusters)
        critical_nodes = []
        for node in range(n):
            # A shallow copy keeps seed/num_nodes without re-running edge
            # generation (which would also reseed the global RNG).  Only the
            # edge list is replaced.
            reduced = copy.copy(self.graph)
            reduced.edges = [edge for edge in self.graph.edges if node not in edge]
            # Cluster the copy directly.  Building a new Cluster here would make
            # that instance search for critical nodes again, recursing forever.
            labels = self._cluster_nodes(reduced)
            others = [i for i in range(n) if i != node]
            if not self._same_partition(baseline[others], labels[others]):
                critical_nodes.append(node)
        return critical_nodes

    def draw_graph_with_clusters(self):
        """Plot the graph: nodes coloured by cluster, critical nodes ringed in black."""
        pos = self._spring_layout()
        plt.figure(figsize=(10, 7))
        unique_clusters = np.unique(self.clusters)
        colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_clusters)))
        for cluster, color in zip(unique_clusters, colors):
            members = [node for node, cluster_id in enumerate(self.clusters) if cluster_id == cluster]
            plt.scatter([pos[node][0] for node in members], [pos[node][1] for node in members],
                        color=color, s=100, label=f'Cluster {cluster}', zorder=3)
        # Lowest zorder so edges are drawn beneath the node markers.
        for u, v in self.graph.edges:
            plt.plot([pos[u][0], pos[v][0]], [pos[u][1], pos[v][1]], color='gray', alpha=0.5, zorder=1)
        # Larger black markers placed under the cluster markers appear as rings
        # and no longer hide the node's cluster colour.
        plt.scatter([pos[node][0] for node in self.critical_nodes],
                    [pos[node][1] for node in self.critical_nodes],
                    color='black', s=200, label='Critical Nodes', zorder=2)
        for node, (x, y) in pos.items():
            plt.text(x, y, str(node), fontsize=12, ha='center', va='center', color='white', zorder=4)
        plt.title("Graph Clustering and Critical Nodes")
        plt.legend()
        plt.show()

    def _spring_layout(self) -> Dict[int, Tuple[float, float]]:
        """Return a position for every node, drawn uniformly from [-1, 1] x [-1, 1].

        NOTE: despite the name, no spring/force-directed simulation is run.
        Positions are random, taken from the global ``random`` module.
        """
        return {i: (random.uniform(-1, 1), random.uniform(-1, 1)) for i in range(self.graph.num_nodes)}
# Example usage
seed, num_nodes = 42, 30
edge_ratio = 2  # optional: Graph uses its default ratio (1.5) when this argument is omitted
graph = Graph(seed=seed, num_nodes=num_nodes, edge_ratio=edge_ratio)
cluster = Cluster(graph, num_clusters=3)
print("Klasteryzacja:", cluster.clusters)  # "Clustering:"
print("Krytyczne wierzchołki:", cluster.critical_nodes)  # "Critical vertices:"
cluster.draw_graph_with_clusters()