import random
from collections import deque
from heapq import heappop, heappush
from itertools import count
from typing import Any, Callable, Dict, Set

from ..graph import Graph
from ..partition import Partition
from .bounds import SelfConfiguringLowerBound
# frm: TODO: Performance: Think about the efficiency of the routines in this module. Almost all
# of these involve traversing the entire graph, and I fear that callers
# might make multiple calls.
#
# Possible solutions are to 1) speed up these routines somehow and 2) cache
# results so that at least we don't do the traversals over and over.
# frm: TODO: Refactoring: Rethink what this module is actually responsible for.
#
# It seems like a grab bag for lots of different things - used in different places.
#
# What got me to write this comment was looking at the signature for def contiguous()
# which operates on a partition, but lots of other routines here operate on graphs or
# other things. So, what is going on?
#
# Peter replied to this comment in a pull request:
#
# So anything that is prefixed with an underscore in here should be a helper
# function and not a part of the public API. It looks like, other than
# is_connected_bfs (which should probably be marked "private" with an
# underscore) everything here is acting like an updater.
#
def _are_reachable(graph: Graph, start_node: Any, avoid: Callable, targets: Any) -> bool:
"""
A modified version of NetworkX's function
`networkx.algorithms.shortest_paths.weighted._dijkstra_multisource()`
This function checks if the targets are reachable from the start_node node
while avoiding edges based on the avoid condition function.
:param graph: Graph
:type graph: Graph
:param start_node: The starting node
:type start_node: int
:param avoid: The function that determines if an edge should be avoided.
It should take in three parameters: the start node, the end node, and
the edges to avoid. It should return True if the edge should be avoided,
False otherwise.
# frm: TODO: Documentation: Fix the comment above about the "avoid" function parameter.
# It may have once been accurate, but the original code below
# passed parameters to it of (node_id, neighbor_node_id, edge_data_dict)
# from NetworkX.Graph._succ So, "the edges to avoid" above is wrong.
# This whole issue is moot, however, since the only routine
# that is used as an avoid function ignores the third parameter.
# Or rather it used to avoid the third parameter, but it has
# been updated to only take two parameters, and the code below
# has been modified to use Graph.neighbors() instead of _succ
# because 1) we can't use NX and 2) because we don't need the
# edge data dictionary anyways...
#
:type avoid: Callable
:param targets: The target nodes that we would like to reach
:type targets: Any
:returns: True if all of the targets are reachable from the start_node node
under the avoid condition, False otherwise.
:rtype: bool
"""
push = heappush
pop = heappop
node_distances = {} # dictionary of final distances
seen = {}
# fringe is heapq with 3-tuples (distance,c,node)
# use the count c to avoid comparing nodes (may not be able to)
c = count()
fringe = []
seen[start_node] = 0
push(fringe, (0, next(c), start_node))
# frm: Original Code:
#
# while not all(t in seen for t in targets) and fringe:
# (d, _, v) = pop(fringe)
# if v in dist:
# continue # already searched this node.
# dist[v] = d
# for u, e in G_succ[v].items():
# if avoid(v, u, e):
# continue
#
# vu_dist = dist[v] + 1
# if u not in seen or vu_dist < seen[u]:
# seen[u] = vu_dist
# push(fringe, (vu_dist, next(c), u))
#
# return all(t in seen for t in targets)
#
# While we have not yet seen all of our targets and while there is
# still some fringe...
while not all(tgt in seen for tgt in targets) and fringe:
(distance, _, node_id) = pop(fringe)
if node_id in node_distances:
continue # already searched this node.
node_distances[node_id] = distance
for neighbor in graph.neighbors(node_id):
if avoid(node_id, neighbor):
continue
neighbor_distance = node_distances[node_id] + 1
if neighbor not in seen or neighbor_distance < seen[neighbor]:
seen[neighbor] = neighbor_distance
push(fringe, (neighbor_distance, next(c), neighbor))
# frm: TODO: Refactoring: Simplify this code. It computes distances and counts but
# never uses them. These must be relics of code copied
# from somewhere else where it had more uses...
return all(tgt in seen for tgt in targets)
[docs]
def single_flip_contiguous(partition: Partition) -> bool:
    """
    Check that flipping nodes out of their old parts did not disconnect
    those parts.

    If the partition has no flips or no parent, this defers to
    :func:`contiguous`. Otherwise, for every flipped node, it collects the
    node's neighbors that are still assigned to the node's old part (the part
    it had in the parent partition) and verifies that all of them can reach
    one another without crossing into a different part.

    The check assumes that each old part formed a connected subgraph in the
    parent partition. Under that assumption, the old part stays connected
    after a node is removed exactly when the removed node's former neighbors
    are still connected to each other.

    :param partition: The proposed next :class:`~gerrychain.partition.Partition`
    :type partition: Partition

    :returns: True if every flipped node passes the check. False if some
        flipped node has no neighbor left in its old part, or if its remaining
        old-part neighbors are not all reachable from one another.
    :rtype: bool
    """
    previous = partition.parent
    flipped = partition.flips

    if not (flipped and previous):
        # Nothing to compare against, so run the general check instead.
        return contiguous(partition)

    graph = partition.graph
    assignment = partition.assignment

    def _crosses_parts(start_node: Any, end_node: Any):
        """
        Tell the traversal whether an edge joins two different parts.

        :param start_node: The start node of the edge.
        :type start_node: Any
        :param end_node: The end node of the edge.
        :type end_node: Any

        :returns: True if the two endpoints are assigned to different parts
            (so the edge must be avoided), False otherwise.
        :rtype: bool
        """
        return assignment.mapping[start_node] != assignment.mapping[end_node]

    for changed_node in flipped:
        former_part = previous.assignment.mapping[changed_node]

        # Neighbors of the flipped node that still belong to its old part.
        remaining_neighbors = []
        for node in graph.neighbors(changed_node):
            if assignment.mapping[node] == former_part:
                remaining_neighbors.append(node)

        # With no neighbor left in the old part, that part has (under the
        # assumptions above) vanished. An empty district is treated as
        # disconnected.
        if not remaining_neighbors:
            return False

        seed = random.choice(remaining_neighbors)

        # Every remaining old-part neighbor must be reachable from the seed
        # without crossing a part (district) boundary.
        if not _are_reachable(graph, seed, _crosses_parts, remaining_neighbors):
            return False

    # The former neighbors of every flipped node are still connected.
    return True
def _affected_parts(partition: Partition) -> Set[int]:
"""
Checks which partitions were affected by the change of nodes.
:param partition: The proposed next :class:`~gerrychain.partition.Partition`
:type partition: Partition
:returns: The set of IDs of all parts that gained or lost a node
when compared to the parent partition.
:rtype: Set[int]
"""
flips = partition.flips
parent = partition.parent
if flips is None:
return partition.parts
if parent is None:
return set(flips.values())
affected = set()
for node, part in flips.items():
affected.add(part)
affected.add(parent.assignment.mapping[node])
return affected
[docs]
def contiguous(partition: Partition) -> bool:
    """
    Check that every affected part of a partition is connected.

    Only the parts reported by :func:`_affected_parts` are examined. The
    check stops at the first part whose subgraph is not connected.

    :param partition: The proposed next :class:`~gerrychain.partition.Partition`
    :type partition: Partition

    :returns: Whether the partition is contiguous
    :rtype: bool
    """
    for part in _affected_parts(partition):
        if not is_connected_bfs(partition.subgraphs[part]):
            return False
    return True
[docs]
def contiguous_bfs(partition: Partition) -> bool:
    """
    Check that a partition's parts are connected as graphs.

    This is a thin wrapper that delegates to :func:`contiguous`.

    :param partition: Instance of Partition
    :type partition: Partition

    :returns: Whether the parts of this partition are connected
    :rtype: bool
    """
    # frm: TODO: Refactoring: Figure out why this routine, contiguous_bfs(), exists.
    #
    # It is mentioned in __init__.py, so it may be used externally by legacy
    # code. It used to build a dict-of-lists for each affected part and run
    # _bfs() on it, which is conceptually the same as contiguous(). All of the
    # tests pass with the delegation below, so this function is assumed to be
    # kept for legacy purposes only. If someone can verify that it is not
    # needed by legacy callers, it can be deleted.
    result = contiguous(partition)
    return result
[docs]
def number_of_contiguous_parts(partition: Partition) -> int:
    """
    Count the parts of a partition whose subgraph is connected.

    :param partition: Instance of Partition; contains connected components.
    :type partition: Partition

    :returns: Number of contiguous parts in the partition.
    :rtype: int
    """
    connected_count = 0
    for part in partition.assignment.parts:
        if is_connected_bfs(partition.subgraphs[part]):
            connected_count += 1
    return connected_count
# Module-level instance of SelfConfiguringLowerBound built around the
# number_of_contiguous_parts function defined above.
# The instance, no_more_discontiguous, is configured to maintain a lower bound on the number of
# contiguous parts in a partition. It can be called like a function because the class
# SelfConfiguringLowerBound implements the __call__ magic method.
no_more_discontiguous = SelfConfiguringLowerBound(number_of_contiguous_parts)
[docs]
def contiguous_components(partition: Partition) -> Dict[int, list]:
    """
    Split the subgraph of every part of the partition into its connected
    components.

    :param partition: Instance of Partition; contains connected components.
    :type partition: Partition

    :returns: dictionary mapping each part ID to a list holding the connected
        subgraphs of that part of the partition
    :rtype: dict
    """
    # frm: TODO: Documentation: Migration Guide: NX vs RX issues here:
    #
    # Building subgraphs may renumber node_ids. The code is not wrong (with RX
    # there is really no other option), but legacy callers may not expect that
    # the returned subgraphs are (I think) three node translations away from
    # the node_ids of the original NX-Graph object:
    #
    #   1) From NX to RX when the partition was created
    #   2) From the top-level RX graph to the partition's subgraph for each
    #      part (district)
    #   3) From each part's subgraph to the subgraphs of its connected
    #      components
    #
    return {
        part: part_subgraph.subgraphs_for_connected_components()
        for part, part_subgraph in partition.subgraphs.items()
    }
def _bfs(graph: Dict[int, list]) -> bool:
"""
Performs a breadth-first search on the provided graph and returns True or
False depending on whether the graph is connected.
:param graph: Dict-of-lists; an adjacency matrix.
:type graph: Dict[int, list]
:returns: is this graph connected?
:rtype: bool
"""
q = [next(iter(graph))]
visited = set()
num_nodes = len(graph)
# Check if the district has a single vertex. If it does, then simply return
# `True`, as it's trivially connected.
if num_nodes <= 1:
return True
# bfs!
while len(q) > 0:
current = q.pop(0)
neighbors = graph[current]
for neighbor in neighbors:
if neighbor not in visited:
visited.add(neighbor)
q += [neighbor]
return num_nodes == len(visited)
# frm: TODO: Testing: Verify that is_connected_bfs() works - add a test or two...
# frm: TODO: Refactoring: Move this code into graph.py. It is all about the Graph...
# frm: TODO: Documentation: This code was obtained from the web - probably could be optimized...
# This code replaced calls on nx.is_connected()
[docs]
def is_connected_bfs(graph: Graph) -> bool:
    """
    Check whether a graph is connected using a breadth-first search.

    The search starts from a randomly chosen node and follows
    ``graph.neighbors()``. The graph is connected when the search visits
    every node listed in ``graph.node_indices``.

    :param graph: The graph to check. Only ``graph.node_indices`` and
        ``graph.neighbors(node)`` are used.
    :type graph: Graph

    :returns: True if every node is reachable from the start node, False
        otherwise. A falsy graph, or a graph without any nodes, is treated as
        connected.
    :rtype: bool
    """
    if not graph:
        return True

    nodes = list(graph.node_indices)
    if not nodes:
        # random.choice() cannot pick from an empty list. A graph without
        # nodes is treated the same way as a falsy graph.
        return True

    start_node = random.choice(nodes)
    visited = {start_node}
    # A deque gives O(1) pops from the front of the queue.
    queue = deque([start_node])

    while queue:
        current_node = queue.popleft()
        for neighbor in graph.neighbors(current_node):
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)

    return len(visited) == len(nodes)