"""
Updaters that compute spanning tree statistics.
"""
import math
import numpy
import networkx
from typing import Dict
def _num_spanning_trees_in_district(partition, district: int) -> int:
"""
Given a district ID, returns the number of spanning trees in the
subgraph of self corresponding to the district.
Uses Kirchoff's theorem to compute the number of spanning trees.
:param partition: :class:`gerrychain.Partition`
:type partition: :class:`gerrychain.Partition`
:param district: A district label (part) in the partition.
:type district: int
:returns: The number of spanning trees in the subgraph of the
partition corresponding to district
:rtype: int
"""
graph = partition.subgraphs[district]
laplacian = networkx.laplacian_matrix(graph)
L = numpy.delete(numpy.delete(laplacian.todense(), 0, 0), 1, 1)
return math.exp(numpy.linalg.slogdet(L)[1])
[docs]def num_spanning_trees(partition) -> Dict[int, int]:
"""
:returns: The number of spanning trees in each part (district) of a partition.
:rtype: Dict[int, int]
"""
return {
part: _num_spanning_trees_in_district(partition, part)
for part in partition.parts
}