Created
July 22, 2024 15:38
-
-
Save cohnt/13f958f34b08c6a9e793e956d6c2d390 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from sklearn.neighbors import kneighbors_graph | |
from collections import deque | |
def adj_mat_to_adj_list(adj_mat):
    """Convert an adjacency matrix into an adjacency list.

    Entry k of the returned list holds, in increasing order, the column
    indices of the nonzero entries of row k of adj_mat.
    """
    adj_list = []
    for row_idx in range(adj_mat.shape[0]):
        neighbor_idx = np.nonzero(adj_mat[row_idx] != 0)[0]
        adj_list.append(list(neighbor_idx))
    return adj_list
def edge_set_to_adj_mat(edge_set, n_points):
    """Build a symmetric 0/1 adjacency matrix from a collection of edges.

    edge_set is any iterable of undirected edges. Each edge is a 2-element
    collection of vertex indices, e.g. frozenset({i, j}) or (i, j). A
    1-element edge, which is what frozenset({v, v}) collapses to, is treated
    as a self-loop on v. n_points is the number of vertices.

    Returns an (n_points, n_points) int array in connectivity form, with both
    [i, j] and [j, i] set to 1 for every edge.

    Raises ValueError for an edge with zero or more than two vertices.
    """
    adj_mat = np.zeros((n_points, n_points), dtype=int)
    for edge in edge_set:
        vertices = tuple(edge)
        if len(vertices) == 2:
            i, j = vertices
        elif len(vertices) == 1:
            # Indexing with a 1-tuple would select (and set) the whole row,
            # so address the diagonal entry explicitly.
            i = j = vertices[0]
        else:
            raise ValueError(f"edge must have 1 or 2 vertices, got {edge!r}")
        # Set both directions so the result is symmetric (undirected).
        adj_mat[i, j] = 1
        adj_mat[j, i] = 1
    return adj_mat
def draw_neighbors(ax, points, adj_mat, **kwargs):
    """Draw every edge of a neighborhood graph onto the axis ``ax``.

    adj_mat is an (n, n) adjacency matrix in connectivity or distance form.
    A dense array or anything with a ``toarray()`` method is accepted. Points
    i and j are joined by a single segment if adj_mat[i, j] > 0 or
    adj_mat[j, i] > 0, so non-symmetric (directed) matrices such as a raw
    k-nearest-neighbor graph are drawn completely.

    points is an (n, 2) or (n, 3) array-like of 2D or 3D points. Only the
    first three coordinates are used, and nothing is drawn for 1D points.
    Extra keyword arguments are forwarded to every ``ax.plot`` call.
    """
    points = np.asarray(points)
    n_coords = min(points.shape[1], 3)
    if n_coords < 2:
        return
    adj = adj_mat.toarray() if hasattr(adj_mat, "toarray") else np.asarray(adj_mat)
    # Union of both directions, then keep the upper triangle (diagonal
    # included) so each undirected edge is drawn exactly once.
    has_edge = (adj > 0) | (adj.T > 0)
    rows, cols = np.nonzero(np.triu(has_edge))
    for i, j in zip(rows, cols):
        line = points[[i, j]]
        ax.plot(*(line[:, k] for k in range(n_coords)), **kwargs)
def draw_cycle(ax, points, cycle, **kwargs):
    """Draw a closed cycle through the given points onto the axis ``ax``.

    cycle is a sequence of point indices. The last index is implicitly
    connected back to the first; repeating the first index at the end is
    harmless. An empty cycle draws nothing.

    points is an (n, 2) or (n, 3) array-like of 2D or 3D points. Only the
    first three coordinates are used, and nothing is drawn for 1D points.
    Extra keyword arguments are forwarded to both ``ax.plot`` calls: one for
    the open path and one for the closing segment.
    """
    points = np.asarray(points)
    idx = list(cycle)
    if not idx:
        return
    n_coords = min(points.shape[1], 3)
    if n_coords < 2:
        return
    line = points[idx]
    ax.plot(*(line[:, k] for k in range(n_coords)), **kwargs)
    # Closing segment from the first vertex to the last.
    ax.plot(*(line[[0, -1], k] for k in range(n_coords)), **kwargs)
# Seed NumPy's global RNG so the sampled data and the random start vertex of
# the cycle search are reproducible from run to run.
seed = 43
np.random.seed(seed=seed)
def _inner_bfs_path(G, T, P, a, b):
    """Return the shortest path from b to a that only uses edges in T.

    This is the inner breadth-first search of the CycleCut cycle finder.

    Parameters
    ----------
    G : list
        Adjacency list of the graph.
    T : set of frozenset
        Edges of the form frozenset({u, v}).
    P : array
        Parent-pointer scratch array. Its entries are overwritten for every
        vertex this search visits.
    a, b : int
        Target and start vertices.

    Returns
    -------
    path : deque
        Vertices from a back to b, or an empty deque if a cannot be reached
        from b through T.
    visit_order : list
        Visited vertices in discovery order, starting with b.
    """
    P[b] = -1  # -1 marks the start of the path (b has no parent)
    frontier = deque([b])
    visited = {b}
    visit_order = [b]
    while frontier:
        c = frontier.popleft()
        for d in G[c]:
            if d in visited or frozenset((c, d)) not in T:
                continue
            P[d] = c
            frontier.append(d)
            visited.add(d)
            visit_order.append(d)
            if d == a:
                # Walk the parent pointers from a back to b.
                path = deque()
                while d != -1:
                    path.append(d)
                    d = P[d]
                return path, visit_order
    return deque(), visit_order


def find_large_atomic_cycle_with_viz(adj_mat, L):
    """Find one large atomic cycle (CycleCut routine) plus visualization state.

    This implements the large-atomic-cycle search from the CycleCut paper for a
    graph stored as an adjacency matrix. An outer breadth-first search (BFS)
    starts at a random vertex (drawn with NumPy's global RNG) and records every
    edge it explores in T. When an edge (a, b) reaches an already-discovered
    vertex b, an inner BFS restricted to edges in T finds the shortest known
    path from b to a. That path plus the edge (a, b) is a cycle, and it is
    returned if it has at least L vertices. Only the connected component
    containing the start vertex is searched.

    Parameters
    ----------
    adj_mat : array
        Square adjacency matrix in connectivity or distance form; any nonzero
        entry is an edge. Vertices should not connect to themselves;
        self-loops are ignored.
    L : int
        Cycle-length hyperparameter: the minimum number of vertices for a
        cycle to count as "large".

    Returns
    -------
    cycle : deque
        Vertex indices of the cycle from a back to b. The closing edge from b
        to a is implicit. Empty if no large atomic cycle is found.
    prev_outer_edges : set of frozenset
        The outer-BFS edge set T just before the closing edge was added.
    outer_edges : set of frozenset
        T including the closing edge.
    inner_edges_list : list of set
        Cumulative sets of vertices visited by the final inner BFS, one per
        step. The first set contains only b.

    If no cycle is found, returns ``(deque(), set(), set(), [set()])``.
    """
    # Convert to an adjacency list for faster neighbor lookup.
    G = adj_mat_to_adj_list(adj_mat)
    if len(G) == 0:
        return deque(), set(), set(), [set()]

    # Vertices are stored as their indices; edges as frozensets of two
    # vertices so that a->b and b->a are the same edge.
    S = set()  # vertices discovered by the outer BFS
    T = set()  # edges explored by the outer BFS so far
    P = np.full(len(G), -1, dtype=int)  # inner-BFS parent pointers; -1 = none

    v0 = np.random.randint(len(G))
    Q = deque([v0])
    S.add(v0)
    while Q:
        a = Q.popleft()
        for b in G[a]:
            if b == a:
                continue  # a self-loop cannot close a cycle
            if b in S:
                # (a, b) closes a cycle: find the shortest known route back.
                Y, visit_order = _inner_bfs_path(G, T, P, a, b)
                if len(Y) >= L:
                    prev_outer_edges = T.copy()
                    T.add(frozenset((a, b)))
                    # Snapshots are built only on success; they are needed
                    # for visualization alone.
                    inner_edges_list = [
                        set(visit_order[:k + 1]) for k in range(len(visit_order))
                    ]
                    return Y, prev_outer_edges, T, inner_edges_list
            else:
                Q.append(b)
                S.add(b)
            # Closing edges of short cycles also go into T, so later inner
            # searches can route through them. A cycle is only reported when
            # no shorter route through explored edges exists.
            T.add(frozenset((a, b)))
    return deque(), set(), set(), [set()]
def draw_and_pause():
    """Redraw the current figure, then block until a key press or mouse click.

    The brief ``plt.pause`` gives the GUI event loop a chance to actually
    render the redraw before ``plt.waitforbuttonpress`` blocks.
    """
    plt.draw()
    plt.pause(1e-3)
    plt.waitforbuttonpress()
# Demo: sample points from the unit square with its central square removed,
# connect them with a symmetric k-nearest-neighbor graph, and step through the
# search for a large atomic cycle. Each stage waits for a key press or mouse
# click before moving on.
data = np.random.random((500, 2))
const = 0.33
# Keep only points within `const` of at least one side of the unit square,
# leaving a square "frame" with a hole in the middle.
near_border = (
    (data[:, 0] < const)
    | (data[:, 0] > (1 - const))
    | (data[:, 1] < const)
    | (data[:, 1] > (1 - const))
)
data = data[near_border]
adj_mat = kneighbors_graph(data, 12).toarray()
# The k-NN relation is not symmetric; take the union so the graph is undirected.
adj_mat = np.maximum(adj_mat, adj_mat.T)
cycle, prev_outer_edges, outer_edges, inner_edges_list = find_large_atomic_cycle_with_viz(adj_mat, 10)

fig, ax = plt.subplots()

# First, draw the data and neighbors
draw_neighbors(ax, data, adj_mat, c="grey", linewidth=0.5, zorder=1)
ax.scatter(data[:, 0], data[:, 1], c="grey", zorder=10)
ax.set_title("Data and k-Nearest-Neighbors")
draw_and_pause()

if len(cycle) == 0:
    # No cycle was found, so there is no search state worth visualizing.
    ax.set_title("No Large Atomic Cycle Found")
    draw_and_pause()
else:
    # Now, show the outer BFS right before the cycle was found
    prev_outer_edges_adj_mat = edge_set_to_adj_mat(prev_outer_edges, data.shape[0])
    draw_neighbors(ax, data, prev_outer_edges_adj_mat, c="blue", linewidth=2.0, zorder=50)
    ax.set_title("Outer BFS (Right Before Finding a Large Atomic Cycle)")
    draw_and_pause()

    # Now, show the outer BFS when it finds the cycle
    new_edges = outer_edges - prev_outer_edges
    new_edges_adj_mat = edge_set_to_adj_mat(new_edges, data.shape[0])
    draw_neighbors(ax, data, new_edges_adj_mat, c="green", linewidth=4.0, zorder=50)
    ax.set_title("Identified a Cycle")
    draw_and_pause()

    # Now, visualize each iteration of the inner BFS: start with its source
    # vertex, then reveal newly reached vertices `every` steps at a time.
    # dtype=int keeps the index arrays valid even when they are empty.
    point_list = np.array(list(inner_edges_list[0]), dtype=int)
    ax.scatter(data[point_list, 0], data[point_list, 1], c="red", zorder=100)
    ax.set_title("Inner BFS (Checking if it's a Large Atomic Cycle)")
    plt.draw()
    plt.pause(0.001)
    every = 10
    for i in range(1, len(inner_edges_list), every):
        point_set = set().union(*inner_edges_list[i:i + every])
        point_list = np.array(list(point_set - inner_edges_list[i - 1]), dtype=int)
        ax.scatter(data[point_list, 0], data[point_list, 1], c="red", zorder=100)
        plt.draw()
        plt.pause(0.001)
    draw_and_pause()

    # Finally, trace the atomic cycle
    draw_cycle(ax, data, cycle, c="red", zorder=200, linewidth=4.0)
    ax.set_title("Large Atomic Cycle Found")
    draw_and_pause()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment