import collections
from enum import Enum
from typing import Callable, Dict
CountyInfo = collections.namedtuple("CountyInfo", ["split", "nodes", "contains"])
"""
Record describing how one county relates to the districts of a partition.

:param split: The county's split status, a member of the
    :class:`.CountySplit` enum.
:type split: CountySplit
:param nodes: The nodes that are contained in the county.
:type nodes: List
:param contains: The assignment IDs found among the county's nodes.
:type contains: Set
"""
class CountySplit(Enum):
    """
    Split status of a county within a partition.

    :cvar NOT_SPLIT: All of the county's nodes carry a single assignment ID.
    :cvar NEW_SPLIT: The county spans more than one assignment ID and its
        status in the parent partition was not ``OLD_SPLIT``.
    :cvar OLD_SPLIT: The county spans more than one assignment ID and was
        either already split in a partition without a parent, or was
        ``OLD_SPLIT`` in the parent partition.
    """

    # The numeric values are part of the public enum; keep them stable.
    NOT_SPLIT = 0
    NEW_SPLIT = 1
    OLD_SPLIT = 2
def county_splits(partition_name: str, county_field_name: str) -> Callable:
    """
    Build an updater that tracks county splits.

    :param partition_name: Name under which the :class:`.Partition` instance
        stores this updater's result. It is forwarded to
        :func:`compute_county_splits` as ``partition_field``.
    :type partition_name: str
    :param county_field_name: Name of county ID field on the graph.
    :type county_field_name: str

    :returns: A function that takes a partition and returns a dictionary
        keyed on the county ID. Each value is a :class:`CountyInfo` tuple of
        the form ``(split, nodes, contains)``: ``split`` is a
        :class:`.CountySplit` member, ``nodes`` is a list of node IDs, and
        ``contains`` is the set of assignment IDs found in the county.
    :rtype: Callable
    """

    def _get_county_splits(partition):
        # Both names are captured from the enclosing call, so the returned
        # updater only needs the partition itself.
        return compute_county_splits(
            partition,
            county_field=county_field_name,
            partition_field=partition_name,
        )

    return _get_county_splits
def compute_county_splits(
    partition, county_field: str, partition_field: str
) -> Dict[str, CountyInfo]:
    """
    Track nodes in counties and information about their splitting.

    :param partition: The partition object to compute county splits for.
    :type partition: :class:`~gerrychain.partition.Partition`
    :param county_field: Name of county ID field on the graph.
    :type county_field: str
    :param partition_field: Key used to read the previously computed county
        data from the parent partition (``partition.parent[partition_field]``).
        Not used when the partition has no parent.
    :type partition_field: str

    :returns: A dict mapping each county ID to a :class:`CountyInfo`. When
        the partition has no parent, the dict is built from the graph and
        only the ``OLD_SPLIT`` and ``NOT_SPLIT`` statuses are used. Otherwise
        the parent's entries are re-evaluated against the current assignment.
    :rtype: Dict[str, CountyInfo]
    """
    parent = partition.parent

    if not parent:
        # No history to compare against: group the graph's nodes by county
        # and collect the assignment IDs seen in each one.
        graph = partition.graph
        grouped = {}
        for node in graph.node_indices:
            county = graph.lookup(node, county_field)
            if county not in grouped:
                grouped[county] = ([], set())
            members, districts = grouped[county]
            members.append(node)
            districts.add(partition.assignment.mapping[node])

        # A county that already spans several assignment IDs in a parentless
        # partition is reported as OLD_SPLIT.
        return {
            county: CountyInfo(
                (
                    CountySplit.OLD_SPLIT
                    if len(districts) > 1
                    else CountySplit.NOT_SPLIT
                ),
                members,
                districts,
            )
            for county, (members, districts) in grouped.items()
        }

    # With a parent, the county -> nodes grouping is reused from the parent's
    # data (the same list objects are shared); only the set of assignment IDs
    # and the status are recomputed.
    updated = {}
    for county, previous in parent[partition_field].items():
        districts = {partition.assignment.mapping[node] for node in previous.nodes}

        if len(districts) <= 1:
            status = CountySplit.NOT_SPLIT
        elif previous.split == CountySplit.OLD_SPLIT:
            status = CountySplit.OLD_SPLIT
        else:
            # NOTE(review): a county that was NEW_SPLIT in the parent and is
            # still split is reported as NEW_SPLIT again; only OLD_SPLIT is
            # carried forward. Confirm this is the intended meaning.
            status = CountySplit.NEW_SPLIT

        updated[county] = CountyInfo(status, previous.nodes, districts)

    return updated
def tally_region_splits(reg_attr_lst):
    """
    Build a naive updater that tallies splits for each region attribute in
    ``reg_attr_lst``.

    :param reg_attr_lst: A list of region attribute names to tally splits for.
    :type reg_attr_lst: List[str]

    :returns: A function that takes a partition and returns a dictionary
        mapping each attribute name in ``reg_attr_lst`` to the value that
        :func:`total_reg_splits` computes for it on that partition. The
        returned function raises ``ValueError`` when the partition has no
        ``"cut_edges"`` updater.
    :rtype: Callable
    """

    def _get_splits(partition):
        # total_reg_splits reads partition["cut_edges"], so that updater has
        # to be attached before anything can be tallied.
        if "cut_edges" in partition.updaters:
            return {
                attr: total_reg_splits(partition, attr) for attr in reg_attr_lst
            }
        raise ValueError("The cut_edges updater must be attached to the partition")

    return _get_splits
def total_reg_splits(partition, reg_attr):
    """
    Count how many distinct values of ``reg_attr`` are split in the partition.

    A region value counts as split when at least one edge in
    ``partition["cut_edges"]`` joins two nodes that share that value but
    carry different assignments.

    :param partition: The partition to inspect. Its ``"cut_edges"`` updater
        must be available.
    :param reg_attr: Name of the node attribute that identifies the region.

    :returns: The number of distinct region values that are split.
    :rtype: int
    """
    nodes = partition.graph.nodes

    # Every region value in the graph starts out unsplit. Reading the
    # attribute from all nodes also means a node without it raises KeyError.
    is_split = dict.fromkeys((nodes[node][reg_attr] for node in nodes), False)

    for node1, node2 in partition["cut_edges"]:
        crosses_districts = partition.assignment[node1] != partition.assignment[node2]
        if crosses_districts and nodes[node1][reg_attr] == nodes[node2][reg_attr]:
            is_split[nodes[node1][reg_attr]] = True

    # Summing the flags counts the regions marked as split.
    return sum(is_split.values())