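This post is adapted from a C++ Dijkstra template used in ACM competitive programming.
First, the version I originally wrote, which is wrong: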
```python
from collections import defaultdict
from heapq import heappush, heappop


def dijkstra(edges, start_node, end_node):
    graph = defaultdict(dict)
    for src, dst, distance in edges:
        graph[src][dst] = distance
    q = [(0, start_node, None)]
    found_min_dist_nodes = set()
    distances = {start_node: 0}
    back_paths = {}
    while q:
        cost, min_dist_node, src_node = heappop(q)
        # The check below is crucial: it discards redundant entries pushed into
        # the priority queue q (see the note on duplicate pushes further down).
        # It is commented out here on purpose to reproduce the bug.
        # if min_dist_node in found_min_dist_nodes:
        #     continue
        found_min_dist_nodes.add(min_dist_node)
        back_paths[min_dist_node] = src_node
        if min_dist_node == end_node:
            return cost, back_paths
        for neibor_node, distance in graph[min_dist_node].items():
            if neibor_node in found_min_dist_nodes:
                continue
            prev_dist = distances.get(neibor_node, float('inf'))
            new_dist = cost + distance
            if new_dist < prev_dist:
                distances[neibor_node] = new_dist
                # Ideally we would update neibor_node's priority in the queue
                # (decrease-key), but the heap has no native update API, so we
                # push a new entry without removing the old one. This is what
                # causes the same node to appear in the queue more than once.
                heappush(q, (new_dist, neibor_node, min_dist_node))
    return float("inf"), back_paths


def find_path(back_paths, start_node, end_node):
    # Follow the back pointers from end_node to start_node, then reverse.
    ans = [end_node]
    while end_node != start_node:
        end_node = back_paths[end_node]
        ans.append(end_node)
    return ans[::-1]


if __name__ == "__main__":
    edges = [
        ("A", "B", 5),
        ("A", "C", 10),
        ("C", "D", 10),
        ("B", "C", 2)]
    dist, backpaths = dijkstra(edges, "A", "D")
    print("dist: ", dist)
    print(find_path(backpaths, "A", "D"))
```
The graph used in the example above (all edges are directed):
```
A --------(5)--------> B
|                     /
(10)                 / (2)
|                   /
v                  /
C <---------------'
|
(10)
|
v
D
```
By inspection, the shortest distance from A to D is 17 (5 + 2 + 10), along the path A->B->C->D.
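To double-check that claim independently of Dijkstra, here is a small brute-force sketch that enumerates every simple path from A to D in this graph and sorts them by cost. The helper `all_paths` is a hypothetical name introduced only for illustration; exhaustive enumeration is only practical for tiny graphs like this one.

```python
# Brute-force check (illustrative only): enumerate every simple path from
# "A" to "D" in the example's directed graph and pick the cheapest one.
edges = [("A", "B", 5), ("A", "C", 10), ("C", "D", 10), ("B", "C", 2)]

graph = {}
for src, dst, w in edges:
    graph.setdefault(src, {})[dst] = w


def all_paths(node, target, path, cost):
    """Yield (cost, path) for every simple path from node to target."""
    if node == target:
        yield cost, path
        return
    for nxt, w in graph.get(node, {}).items():
        if nxt not in path:  # keep paths simple (no revisits)
            yield from all_paths(nxt, target, path + [nxt], cost + w)


results = sorted(all_paths("A", "D", ["A"], 0))
for cost, path in results:
    print(cost, "->".join(path))
best_cost, best_path = results[0]
```

There are only two routes: A->B->C->D with cost 17 and A->C->D with cost 20, so 17 via A->B->C->D is the answer Dijkstra should reproduce.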
Without the following check:

```python
if min_dist_node in found_min_dist_nodes:
    continue
```

the program prints this distance and path from A to D:

```
dist:  17
['A', 'C', 'D']
```
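The path is wrong! The distance is correct, but the reconstructed path is not.

Why? Because node C gets pushed onto the queue twice, which is easy to see in a debugger. When A is expanded, (10, C, A) is pushed; when B is expanded, it finds the shorter route A->B->C and pushes (7, C, B) without removing the old entry. (7, C, B) pops first, correctly sets back_paths['C'] = 'B', and pushes (17, D, C). The stale entry (10, C, A) still pops before (17, D, C), and without the check it overwrites back_paths['C'] with 'A'. Its relaxation 10 + 10 = 20 does not beat 17, so the distance to D is untouched. That is why the distance is right while the path A->C->D is wrong.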
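To see the duplicate push without a debugger, here is a sketch that inlines the first (buggy) version on the example graph and records every heap pop. The `log` list is an addition for illustration only and is not part of the original code.

```python
from heapq import heappush, heappop

# Instrumented copy of the first (buggy) dijkstra: the "already popped?"
# check is deliberately missing, and every pop is logged so the duplicate
# entry for C becomes visible.
edges = [("A", "B", 5), ("A", "C", 10), ("C", "D", 10), ("B", "C", 2)]
graph = {}
for src, dst, w in edges:
    graph.setdefault(src, {})[dst] = w

q = [(0, "A", None)]
found = set()
distances = {"A": 0}
back_paths = {}
log = []  # every (cost, node, src) tuple in the order it is popped
while q:
    cost, node, src = heappop(q)
    log.append((cost, node, src))
    found.add(node)
    back_paths[node] = src  # a stale pop overwrites this back pointer
    if node == "D":
        break
    for nb, w in graph.get(node, {}).items():
        if nb in found:
            continue
        new_dist = cost + w
        if new_dist < distances.get(nb, float("inf")):
            distances[nb] = new_dist
            heappush(q, (new_dist, nb, node))

for entry in log:
    print(entry)
```

The fourth pop, (10, 'C', 'A'), is the stale entry: it resets back_paths['C'] to 'A', even though D's cost of 17 had already been computed from the correct pop (7, 'C', 'B').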
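So we restore the if check to skip duplicates: once a node has been popped, its shortest distance is final, so any later entry for it is stale and popping it again is pointless: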
```python
from collections import defaultdict
from heapq import heappush, heappop


def dijkstra(edges, start_node, end_node):
    graph = defaultdict(dict)
    for src, dst, distance in edges:
        graph[src][dst] = distance
    q = [(0, start_node, None)]
    found_min_dist_nodes = set()
    distances = {start_node: 0}
    back_paths = {}
    while q:
        cost, min_dist_node, src_node = heappop(q)
        # The check below is crucial: it discards redundant entries pushed into
        # the priority queue q (see the note on duplicate pushes further down).
        # A node that was already popped has its final distance, so any later
        # entry for it is stale.
        if min_dist_node in found_min_dist_nodes:
            continue
        found_min_dist_nodes.add(min_dist_node)
        back_paths[min_dist_node] = src_node
        if min_dist_node == end_node:
            return cost, back_paths
        for neibor_node, distance in graph[min_dist_node].items():
            if neibor_node in found_min_dist_nodes:
                continue
            prev_dist = distances.get(neibor_node, float('inf'))
            new_dist = cost + distance
            if new_dist < prev_dist:
                distances[neibor_node] = new_dist
                # Ideally we would update neibor_node's priority in the queue
                # (decrease-key), but the heap has no native update API, so we
                # push a new entry without removing the old one. This is what
                # causes the same node to appear in the queue more than once.
                heappush(q, (new_dist, neibor_node, min_dist_node))
    return float("inf"), back_paths


def find_path(back_paths, start_node, end_node):
    # Follow the back pointers from end_node to start_node, then reverse.
    ans = [end_node]
    while end_node != start_node:
        end_node = back_paths[end_node]
        ans.append(end_node)
    return ans[::-1]


if __name__ == "__main__":
    edges = [
        ("A", "B", 5),
        ("A", "C", 10),
        ("C", "D", 10),
        ("B", "C", 2)]
    dist, backpaths = dijkstra(edges, "A", "D")
    print("dist: ", dist)
    print(find_path(backpaths, "A", "D"))
```
Output:

```
dist:  17
['A', 'B', 'C', 'D']
```
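Now it is correct!

So in Dijkstra shortest-path code that pushes duplicate entries onto a heap instead of updating them, that if check really has to be there!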
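For comparison, here is a sketch of a common alternative formulation; it is not the template this post is based on. Instead of a `found_min_dist_nodes` set, it skips a popped entry whose cost exceeds the best distance recorded so far, and it records each node's parent at the moment its distance improves rather than at pop time. `dijkstra_alt` and `parents` are hypothetical names, and like the original it assumes non-negative edge weights.

```python
from collections import defaultdict
from heapq import heappush, heappop


def dijkstra_alt(edges, start_node, end_node):
    """Variant sketch: stale heap entries are detected by comparing the popped
    cost with the best known distance, and parents are recorded at relaxation
    time instead of at pop time. Assumes non-negative edge weights."""
    graph = defaultdict(dict)
    for src, dst, distance in edges:
        graph[src][dst] = distance
    distances = {start_node: 0}
    parents = {start_node: None}
    q = [(0, start_node)]
    while q:
        cost, node = heappop(q)
        if cost > distances[node]:
            continue  # stale entry: a shorter path to node was already found
        if node == end_node:
            # Walk the parent pointers back to start_node, then reverse.
            path = []
            while node is not None:
                path.append(node)
                node = parents[node]
            return cost, path[::-1]
        for nb, w in graph[node].items():
            new_dist = cost + w
            if new_dist < distances.get(nb, float("inf")):
                distances[nb] = new_dist
                parents[nb] = node  # parent is updated together with distance
                heappush(q, (new_dist, nb))
    return float("inf"), []


edges = [("A", "B", 5), ("A", "C", 10), ("C", "D", 10), ("B", "C", 2)]
print(dijkstra_alt(edges, "A", "D"))
```

Because the parent is only ever written alongside a strictly better distance, a stale pop such as (10, 'C') cannot corrupt the path; the `cost > distances[node]` comparison plays the same role as the `if min_dist_node in found_min_dist_nodes` check.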