Last active
December 5, 2020 09:20
-
-
Save UrosOgrizovic/d7c9a472c579673f3021f21db8f99580 to your computer and use it in GitHub Desktop.
Find the shortest path in a graph using Bellman-Ford.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
find shortest path in graph via Bellman-Ford | |
Uros Ogrizovic | |
""" | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def get_neighbors_of_node(node, adjacency_matrix):
    """
    List the nodes that ``node`` has an outgoing edge to.

    A 0 entry in the adjacency matrix means the two nodes are not
    connected; any other value (including a negative one) is the weight
    of the edge ``node -> i``, so every non-zero column counts.

    :param node: index of the row to inspect
    :param adjacency_matrix: the graph, as a list of rows of edge weights
    :return: neighbors, the column indices i with adjacency_matrix[node][i] != 0
    """
    row = adjacency_matrix[node]
    return [col for col, weight in enumerate(row) if weight != 0]
def get_whose_neighbor(node, adjacency_matrix):
    """
    List the nodes that have an edge pointing into ``node`` (its predecessors).

    Looking at column ``node`` instead of row ``node`` is what makes this
    work for directed graphs. A 0 entry means "not connected"; any other
    value is an edge weight.

    :param node: index of the column to inspect
    :param adjacency_matrix: the graph, as a list of rows of edge weights
    :return: neighbor_of, the row indices i with adjacency_matrix[i][node] != 0
    """
    return [
        row_index
        for row_index, row in enumerate(adjacency_matrix)
        if row[node] != 0
    ]
def get_node_values(adjacency_matrix, terminal_nodes):
    """
    Compute, for every node, minus the cost of its cheapest path to a terminal node.

    Uses Bellman-Ford relaxation on the reversed edges. A 0 entry in the
    adjacency matrix means "no edge"; any other value is the weight of the
    directed edge ``row -> column``.

    Terminal nodes are pinned at 0: a path ends at the first terminal node it
    reaches, which matches how ``find_shortest_path`` treats terminals. Nodes
    from which no terminal node is reachable keep the value ``-np.inf``.

    :param adjacency_matrix: the graph, represented as an adjacency matrix
    :param terminal_nodes: indices of terminal nodes
    :return: node_values, a list where entry i is -(cheapest cost from i to a terminal)
    """
    num_nodes = len(adjacency_matrix)
    terminals = set(terminal_nodes)

    node_values = [-np.inf] * num_nodes
    for node in terminals:
        node_values[node] = 0

    # Collect the directed edges once: (source, target, weight).
    edges = [
        (source, target, weight)
        for source, row in enumerate(adjacency_matrix)
        for target, weight in enumerate(row)
        if weight != 0
    ]

    # A cheapest path has at most |V| - 1 edges, so |V| - 1 passes suffice.
    # Repeating the passes propagates improvements found late in a pass.
    for _ in range(max(num_nodes - 1, 0)):
        changed = False
        for source, target, weight in edges:
            if source in terminals:
                # paths stop at the first terminal, so terminals stay at 0
                continue
            # Relax using the edge weight; the original comparison ignored it.
            candidate = node_values[target] - weight
            if candidate > node_values[source]:
                node_values[source] = candidate
                changed = True
        if not changed:
            # converged: further passes cannot change anything
            break
    return node_values
def check_if_terminal_nodes_in_list(lst, terminal_nodes):
    """
    Return the first terminal node (in ``terminal_nodes`` order) that occurs in ``lst``.

    :param lst: a sequence of node indices, e.g. a path
    :param terminal_nodes: indices of terminal nodes
    :return: the matching terminal node index, or -1 if none of them is in lst
    """
    return next((terminal for terminal in terminal_nodes if terminal in lst), -1)
def get_path_price(path, adjacency_matrix):
    """
    Sum the edge weights along ``path``.

    Each consecutive pair (a, b) in the path contributes adjacency_matrix[a][b].
    A path with fewer than two nodes has no edges and costs 0.

    :param path: sequence of node indices
    :param adjacency_matrix: the graph, as a list of rows of edge weights
    :return: path_sum, the total weight of the path
    """
    hops = zip(path, path[1:])
    return sum(adjacency_matrix[here][there] for here, there in hops)
def stringify_path(path):
    """
    Render a path as dash-separated node indices, e.g. [0, 1, 2] -> "0-1-2".

    Anything that is not exactly a list (such as the bare start node returned
    when it is already terminal, or a placeholder string) is returned unchanged.

    :param path: list of node indices, or any other value
    :return: str_path, the dash-joined string, or ``path`` itself if it is not a list
    """
    if type(path) is not list:  # trivial case: nothing to join
        return path
    return "-".join(str(node) for node in path)
def find_shortest_path(start_node, adjacency_matrix, terminal_nodes):
    """
    Find the cheapest path from ``start_node`` to any terminal node with Bellman-Ford.

    A 0 entry in ``adjacency_matrix`` means "no edge"; any other value is the
    weight of that directed edge, and negative weights are allowed. A path
    stops at the first terminal node it reaches. At most
    ``len(adjacency_matrix) - 1`` edges are considered, as in classic
    Bellman-Ford. Ties are broken by the order of ``terminal_nodes``.

    :param start_node: index of starting node
    :param adjacency_matrix: the graph, represented as an adjacency matrix
    :param terminal_nodes: indices of terminal nodes
    :return: ``(path, price)``, where ``path`` is the list of node indices of the
        cheapest path and ``price`` is its total weight. Returns
        ``(start_node, 0)`` if the start node is itself terminal, and
        ``('None', 'inf')`` if no terminal node is reachable.
    :raises ValueError: if ``start_node`` or any terminal node index is negative
    """
    if start_node < 0:
        raise ValueError("Failed to find shortest path due to negative value for start node")
    if start_node in terminal_nodes:
        return start_node, 0
    for node in terminal_nodes:
        if node < 0:
            raise ValueError("Failed to find shortest path due to negative values for terminal nodes")

    # Use the parameter, not a module-level global, so this works when imported.
    num_nodes = len(adjacency_matrix)
    terminals = set(terminal_nodes)

    # cost[v] / best_path[v]: the cheapest known walk from start_node to v.
    cost = [np.inf] * num_nodes
    best_path = [None] * num_nodes
    cost[start_node] = 0
    best_path[start_node] = [start_node]

    # Round k extends the walks from round k - 1 by one edge, so after
    # |V| - 1 rounds every walk of up to |V| - 1 edges has been considered.
    for _ in range(num_nodes - 1):
        next_cost = cost.copy()
        next_path = best_path.copy()
        changed = False
        for source in range(num_nodes):
            # Unreachable nodes have nothing to relax; terminal nodes end a path.
            if cost[source] == np.inf or source in terminals:
                continue
            for target, weight in enumerate(adjacency_matrix[source]):
                if weight != 0 and cost[source] + weight < next_cost[target]:
                    next_cost[target] = cost[source] + weight
                    # new list, so earlier rounds' paths are never mutated
                    next_path[target] = best_path[source] + [target]
                    changed = True
        cost, best_path = next_cost, next_path
        if not changed:
            # converged early: further rounds would change nothing
            break

    # Out-of-range terminal indices can never be reached, so skip them.
    reached = [t for t in terminal_nodes if t < num_nodes and cost[t] != np.inf]
    if not reached:
        return 'None', 'inf'
    best_terminal = min(reached, key=lambda t: cost[t])
    return best_path[best_terminal], cost[best_terminal]
if __name__ == "__main__":
    # Demo graphs: row i holds the weights of the edges leaving node i;
    # a 0 entry means "no edge". Each picture lives next to the script.

    # Nodes A, B, C, D, E -- see "undirected_graph.png"
    undirected_graph = [
        [0, 1, 1, 0, 0],
        [1, 0, 0, 1, 1],
        [1, 0, 0, 1, 0],
        [0, 1, 1, 0, 1],
        [0, 1, 0, 1, 0],
    ]

    # Nodes A, B, C, D, E -- see "directed_graph_weighted.png"
    directed_graph_weighted = [
        [0, 4, 2, 0, 0],
        [0, 0, 0, -2, 2],
        [0, 0, 0, 2, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0],
    ]

    # Nodes A, B, C, D, E -- see "directed_graph_1.png"
    directed_graph_1 = [
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0],
    ]

    # Nodes A, B, C, D, E, T -- see "directed_graph_2.png"
    directed_graph_2 = [
        [0, 1, 1, 0, 0, 1],
        [0, 0, 0, 1, 0, 0],
        [0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0, 0],
    ]

    # Nodes X, T, B, G, D, C, E, M -- see "directed_graph_3.png"
    directed_graph_3 = [
        [0, 1, 1, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1, 0, 0, 1],
        [0, 0, 0, 0, 0, 1, 0, 0],
    ]

    # Choose the graph to run on; B and the last node are the terminals.
    adj_matrix = directed_graph_1
    terminal_nodes = [1, len(adj_matrix) - 1]
    start_node = 2

    computed_values = get_node_values(adj_matrix, terminal_nodes)
    print("Node values:", computed_values)

    shortest_path, price = find_shortest_path(start_node, adj_matrix, terminal_nodes)
    print("Shortest path:", stringify_path(shortest_path), "| Price =", price)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment