diff --git a/graph_new.py b/graph_new.py new file mode 100644 index 0000000..57f058e --- /dev/null +++ b/graph_new.py @@ -0,0 +1,142 @@ +import numpy as np +import matplotlib.pyplot as plt +import random +from typing import List, Optional, Tuple, Dict +from itertools import combinations + +class Node: + """Represents a node in the graph.""" + def __init__(self, node_id: int): + self.node_id = node_id + + def __repr__(self): + return f"Node({self.node_id})" + +class Graph: + """Represents a graph with nodes and edges.""" + def __init__(self, seed: int, num_nodes: int, edge_ratio: Optional[float] = 1.5): + self.seed = seed + self.num_nodes = num_nodes + self.edge_ratio = edge_ratio + self.nodes = [Node(i) for i in range(num_nodes)] + self.edges = [] + self._generate_graph() + + def _generate_graph(self): + random.seed(self.seed) + num_edges = int(self.edge_ratio * self.num_nodes) + + # List of all possible edges + possible_edges = list(combinations(range(self.num_nodes), 2)) + random.shuffle(possible_edges) + + self.edges = possible_edges[:num_edges] + + def get_neighbors(self, node_id: int) -> List[int]: + """Returns a list of neighbors for a given node.""" + neighbors = [] + for edge in self.edges: + if edge[0] == node_id: + neighbors.append(edge[1]) + elif edge[1] == node_id: + neighbors.append(edge[0]) + return neighbors + + def __repr__(self): + return f"Graph(num_nodes={self.num_nodes}, num_edges={len(self.edges)})" + +class Cluster: + """Represents clusters in a graph.""" + def __init__(self, graph: Graph, num_clusters: int, num_walks: int = 100, walk_length: int = 10): + self.graph = graph + self.num_clusters = num_clusters + self.num_walks = num_walks + self.walk_length = walk_length + self.clusters = self._cluster_nodes() + self.critical_nodes = self._find_critical_nodes() + + def _random_walk(self, start_node: int, walk_length: int) -> List[int]: + walk = [start_node] + for _ in range(walk_length - 1): + neighbors = self.graph.get_neighbors(walk[-1]) + if not neighbors: + break + walk.append(random.choice(neighbors)) + return walk + + def _random_walks_matrix(self) -> np.ndarray: + visit_counts = np.zeros((self.graph.num_nodes, self.graph.num_nodes), dtype=int) + + for node in range(self.graph.num_nodes): + for _ in range(self.num_walks): + walk = self._random_walk(node, self.walk_length) + for neighbor in walk: + visit_counts[node][neighbor] += 1 + + return visit_counts + + def _cluster_nodes(self) -> List[int]: + visit_matrix = self._random_walks_matrix() + kmeans = KMeans(n_clusters=self.num_clusters, random_state=self.graph.seed) + kmeans.fit(visit_matrix) + return kmeans.labels_ + + def _find_critical_nodes(self) -> List[int]: + visit_matrix = self._random_walks_matrix() + initial_clusters = self._cluster_nodes() + critical_nodes = [] + + for node in range(self.graph.num_nodes): + temp_edges = [edge for edge in self.graph.edges if node not in edge] + temp_graph = Graph(self.graph.seed, self.graph.num_nodes) + temp_graph.edges = temp_edges + + temp_cluster = Cluster(temp_graph, self.num_clusters, self.num_walks, self.walk_length) + new_clusters = temp_cluster.clusters + + if not np.array_equal(initial_clusters, new_clusters): + critical_nodes.append(node) + + return critical_nodes + + def draw_graph_with_clusters(self): + pos = self._spring_layout() + plt.figure(figsize=(10, 7)) + + unique_clusters = np.unique(self.clusters) + colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_clusters))) + + for cluster, color in zip(unique_clusters, colors): + nodes_in_cluster = [node for node, cluster_id in enumerate(self.clusters) if cluster_id == cluster] + plt.scatter([pos[node][0] for node in nodes_in_cluster], [pos[node][1] for node in nodes_in_cluster], color=color, s=100, label=f'Cluster {cluster}') + + for edge in self.graph.edges: + x = [pos[edge[0]][0], pos[edge[1]][0]] + y = [pos[edge[0]][1], pos[edge[1]][1]] + plt.plot(x, y, color='gray', alpha=0.5) + + critical_pos = {node: pos[node] for node in self.critical_nodes} + plt.scatter([critical_pos[node][0] for node in critical_pos], [critical_pos[node][1] for node in critical_pos], color='black', s=200, label='Critical Nodes') + + for node, (x, y) in pos.items(): + plt.text(x, y, str(node), fontsize=12, ha='center', va='center', color='white') + + plt.title("Graph Clustering and Critical Nodes") + plt.legend() + plt.show() + + def _spring_layout(self) -> Dict[int, Tuple[float, float]]: + """Calculates the spring layout for the graph.""" + pos = {i: (random.uniform(-1, 1), random.uniform(-1, 1)) for i in range(self.graph.num_nodes)} + return pos + +# Przykładowe użycie +seed = 42 +num_nodes = 30 +edge_ratio = 2 # opcjonalny + +graph = Graph(seed, num_nodes, edge_ratio) +cluster = Cluster(graph, num_clusters=3) +print("Klasteryzacja:", cluster.clusters) +print("Krytyczne wierzchołki:", cluster.critical_nodes) +cluster.draw_graph_with_clusters()