matma-proj-11/graph_graphical.py

87 lines
2.9 KiB
Python
Raw Normal View History

2024-06-04 20:37:52 +02:00
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
def random_walk(graph, start_node, walk_length):
    """Simulate a uniform random walk on ``graph`` starting at ``start_node``.

    At each step the next node is drawn uniformly from the current node's
    neighbours (as listed by ``graph.neighbors``), using NumPy's global
    random state.

    Args:
        graph: Graph object exposing ``neighbors(node)`` (e.g. a networkx graph).
        start_node: Node at which the walk begins; always the first element.
        walk_length: Maximum number of nodes in the walk, including the start.
            Values below 2 yield ``[start_node]``.

    Returns:
        List of visited nodes in order. The walk stops early when it reaches a
        node with no neighbours.
    """
    walk = [start_node]
    for _ in range(walk_length - 1):
        neighbors = list(graph.neighbors(walk[-1]))
        if not neighbors:
            break
        # Draw an index instead of calling np.random.choice(neighbors): choice()
        # converts the list to an array, which fails for tuple nodes (e.g. grid
        # graphs) and turns other node ids into NumPy scalars.
        walk.append(neighbors[np.random.randint(len(neighbors))])
    return walk
def random_walks_matrix(graph, num_walks, walk_length):
    """Build a node-by-node visit-count matrix from repeated random walks.

    For each start node (in ``graph.nodes`` order), ``num_walks`` walks of at
    most ``walk_length`` nodes are simulated with ``random_walk``. Entry
    ``[i, j]`` of the result counts how many times node ``j`` appeared in the
    walks started at node ``i``; the start node is counted once per walk.

    Returns:
        NumPy integer array of shape ``(n, n)`` for ``n`` nodes (an empty
        array when the graph has no nodes).
    """
    order = list(graph.nodes)
    position = {node: idx for idx, node in enumerate(order)}
    rows = []
    for source in order:
        counts = [0] * len(order)
        for _ in range(num_walks):
            for visited in random_walk(graph, source, walk_length):
                counts[position[visited]] += 1
        rows.append(counts)
    return np.array(rows)
def cluster_nodes(visit_matrix, num_clusters):
    """Group the rows of ``visit_matrix`` into ``num_clusters`` clusters with k-means.

    Uses scikit-learn's ``KMeans`` with a fixed ``random_state=0`` for a
    reproducible initialisation.

    Returns:
        Array holding one cluster label per row of ``visit_matrix``.
    """
    model = KMeans(n_clusters=num_clusters, random_state=0).fit(visit_matrix)
    return model.labels_
def _same_partition(labels_a, labels_b):
    """Return True if two label sequences group the items identically.

    Cluster ids from k-means are arbitrary, so two labelings are compared as
    partitions: they match when every label in ``labels_a`` corresponds to
    exactly one label in ``labels_b`` and vice versa.
    """
    if len(labels_a) != len(labels_b):
        return False
    pairs = set(zip(labels_a, labels_b))
    return len(pairs) == len(set(labels_a)) == len(set(labels_b))


def critical_nodes(graph, visit_matrix, num_clusters, num_walks=100, walk_length=10):
    """Return nodes whose removal changes the clustering of the remaining nodes.

    For every node, a copy of ``graph`` without that node is re-embedded with
    fresh random walks and re-clustered. The node is reported as critical if
    the new partition of the surviving nodes differs from their partition in
    the clustering of ``visit_matrix`` (cluster ids themselves are ignored).

    Args:
        graph: networkx graph; rows of ``visit_matrix`` must follow
            ``graph.nodes`` order.
        visit_matrix: Visit-count matrix for ``graph`` (as produced by
            ``random_walks_matrix``).
        num_clusters: Number of k-means clusters.
        num_walks: Walks per start node when re-embedding each reduced graph.
            The default matches the example script's value.
        walk_length: Maximum walk length for the re-embedding. The default
            matches the example script's value.

    Returns:
        List of critical nodes, in ``graph.nodes`` order.

    NOTE: each reduced graph is embedded with new random walks, so sampling
    noise alone can change its partition; results can differ between runs
    unless NumPy's global random state is seeded.
    """
    nodes = list(graph.nodes)
    initial_clusters = cluster_nodes(visit_matrix, num_clusters)
    critical = []
    for removed in nodes:
        temp_graph = graph.copy()
        temp_graph.remove_node(removed)
        temp_visit_matrix = random_walks_matrix(temp_graph, num_walks, walk_length)
        new_clusters = cluster_nodes(temp_visit_matrix, num_clusters)
        # The reduced graph has one node fewer, so compare against the original
        # labels of the surviving nodes only (same order as temp_graph.nodes).
        surviving_labels = [
            label for node, label in zip(nodes, initial_clusters) if node != removed
        ]
        if not _same_partition(surviving_labels, new_clusters):
            critical.append(removed)
    return critical
def draw_graph_with_clusters(graph, clusters, critical_nodes):
    """Draw ``graph`` coloured by cluster, with critical nodes overlaid in black.

    Args:
        graph: networkx graph to draw.
        clusters: Sequence of cluster labels aligned with ``graph.nodes``
            iteration order (one label per node).
        critical_nodes: Iterable of nodes of ``graph`` to highlight.

    Raises:
        ValueError: if ``clusters`` does not have one label per node.

    Opens a matplotlib window via ``plt.show()``; returns nothing.
    """
    if len(clusters) != graph.number_of_nodes():
        raise ValueError(
            f"expected {graph.number_of_nodes()} cluster labels, got {len(clusters)}"
        )
    pos = nx.spring_layout(graph)
    plt.figure(figsize=(10, 7))
    # Pair labels with the actual node ids; enumerate() would wrongly assume
    # the nodes are the integers 0..n-1.
    node_labels = list(zip(graph.nodes, clusters))
    unique_clusters = np.unique(clusters)
    colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_clusters)))
    for cluster, color in zip(unique_clusters, colors):
        nodes_in_cluster = [node for node, cluster_id in node_labels if cluster_id == cluster]
        nx.draw_networkx_nodes(
            graph,
            pos,
            nodelist=nodes_in_cluster,
            node_color=[color],
            node_size=300,
            label=f"Cluster {cluster}",
        )
    nx.draw_networkx_edges(graph, pos, alpha=0.5)
    nx.draw_networkx_nodes(
        graph,
        pos,
        nodelist=list(critical_nodes),
        node_color='black',
        node_size=100,
        label='Critical Nodes',
    )
    labels = {node: node for node in graph.nodes}
    nx.draw_networkx_labels(graph, pos, labels, font_size=8)
    plt.title("Graph Clustering and Critical Nodes")
    plt.legend()
    plt.show()
# Example usage. Guarded so that importing this module does not run the
# (slow) analysis or open a plot window.
if __name__ == "__main__":
    num_walks = 100
    walk_length = 10
    num_clusters = 3
    # For an undirected graph (Zachary's karate club).
    G = nx.karate_club_graph()
    visit_matrix = random_walks_matrix(G, num_walks, walk_length)
    clusters = cluster_nodes(visit_matrix, num_clusters)
    critical = critical_nodes(G, visit_matrix, num_clusters)
    print("Klasteryzacja:", clusters)
    print("Krytyczne wierzchołki:", critical)
    draw_graph_with_clusters(G, clusters, critical)