Source code for votekit.graphs.ballot_graph

from votekit.graphs.base_graph import Graph
from votekit.pref_profile import RankProfile
from typing import Optional, Union
import networkx as nx  # type: ignore
from functools import cache
from typing import Callable
import matplotlib.pyplot as plt


def all_nodes(graph, node):
    """Display predicate that accepts every node.

    Used as the default ``to_display`` callable by ``BallotGraph`` plotting
    helpers. Both arguments are ignored; they exist only so the function
    matches the ``(graph, node) -> bool`` predicate signature.
    """
    keep_node = True
    return keep_node


class BallotGraph(Graph):
    """
    Class to build ballot graphs.

    Args:
        source (Union[RankProfile, int, list]): data to create graph from, either
            ``RankProfile`` object, number of candidates, or list (or tuple) of
            candidates.
        allow_partial (bool, optional): If True, builds graph using all possible
            ballots, If False, only uses total linear ordered ballots. Forced to
            True when ``source`` is a ``RankProfile``. Defaults to True.
        fix_short (bool, optional): If True, auto completes ballots of length
            :math:`n-1` to :math:`n`. Ballots of length less than :math:`n-1` are
            preserved. Only used when ``source`` is a ``RankProfile``.
            Defaults to True.

    Attributes:
        profile (RankProfile): Profile used to create graph, None if not provided.
        candidates (tuple[str]): Tuple of candidates, None if not provided.
        num_cands (int): Number of candidates.
        num_voters (float): Sum of weights of profile if provided, None otherwise.
        allow_partial (bool, optional): If True, builds graph using all possible
            ballots, If False, only uses total linear ordered ballots.
        graph (networkx.Graph): underlying ``networkx`` graph.

    Raises:
        TypeError: If ``source`` is not a ``RankProfile``, an int, or a list/tuple.
        ValueError: If more than 9 candidates are requested.
    """

    def __init__(
        self,
        source: Union[RankProfile, int, list],
        allow_partial: Optional[bool] = True,
        fix_short: Optional[bool] = True,
    ):
        super().__init__(nx.Graph())
        self.profile = None
        self.candidates = None
        # Set up front so ``from_profile`` can safely test it even when the graph
        # was built from a candidate count or a candidate list.
        self.num_voters = None
        self.allow_partial = allow_partial

        if isinstance(source, RankProfile):
            self.profile = source
            self.num_voters = source.total_ballot_wt
            self.num_cands = len(source.candidates)
            # Graphs built from a profile always keep partial ballots.
            self.allow_partial = True
            self.graph = self.build_graph(len(source.candidates))
            self.graph = self.from_profile(source, fix_short=fix_short)
            # Only ballots that landed on a graph node are counted here.
            self.num_voters = sum(self.node_weights.values())
        elif isinstance(source, int):
            self.num_cands = source
            self.graph = self.build_graph(source)
        elif isinstance(source, (list, tuple)):
            self.num_cands = len(source)
            self.graph = self.build_graph(len(source))
            self.candidates = tuple(source)
        else:
            raise TypeError(
                "source must be a RankProfile, an int number of candidates, "
                f"or a list of candidates, not {type(source).__name__}"
            )

        # if no partial ballots allowed, create induced subgraph
        if not self.allow_partial:
            total_ballots = [
                node for node in self.graph.nodes() if len(node) == self.num_cands
            ]
            self.graph = self.graph.subgraph(total_ballots)

        if self.node_weights is None:
            self.node_weights = {ballot: 0 for ballot in self.graph.nodes}

    def _relabel(self, gr: nx.Graph, new_label: int, num_cands: int) -> nx.Graph:
        """
        Relabels the nodes of ``gr`` as ballots that rank ``new_label`` first.

        Every candidate number in a node is shifted up by ``new_label`` and then
        wrapped back into ``1..num_cands``; ``new_label`` is prepended to the result.

        Args:
            gr (networkx.Graph): Graph whose nodes are ballot tuples.
            new_label (int): Candidate number to place first.
            num_cands (int): Number of candidates, used for the wrap-around.

        Returns:
            networkx.Graph: A relabeled copy of ``gr``.
        """
        node_map = {}
        for ballot in list(gr.nodes):
            shifted = [new_label + cand for cand in ballot]
            # wrap values past num_cands back into 1..num_cands (one wrap suffices
            # because both new_label and each entry are at most num_cands)
            shifted = [c - num_cands if c > num_cands else c for c in shifted]
            node_map[ballot] = (new_label, *shifted)
        return nx.relabel_nodes(gr, node_map)

    def build_graph(self, n: int) -> nx.Graph:
        """
        Builds graph of all possible ballots given a number of candidates.

        Ballots of length :math:`n-1` are not separate nodes; they are identified
        with the complete ballot of length :math:`n`.

        Args:
            n (int): Number of candidates in an election.

        Returns:
            networkx.Graph: A ``networkx`` graph whose nodes are ballot tuples,
            each with ``weight=0`` and ``cast=False`` attributes.

        Raises:
            ValueError: If ``n`` is greater than 9.
        """
        if n > 9:
            raise ValueError(
                "Ballot graphs with more than 9 candidates are not supported due to "
                "exponential growth in the number of possible ballots."
            )

        Gc: nx.Graph = nx.Graph()
        # base cases
        if n == 1:
            # The single possible ballot is the bullet vote (1,). It must be a
            # tuple like every other node so len() and profile lookups work.
            Gc.add_nodes_from([(1,)], weight=0, cast=False)
        elif n == 2:
            Gc.add_nodes_from([(1, 2), (2, 1)], weight=0, cast=False)
            Gc.add_edges_from([((1, 2), (2, 1))])
        elif n > 2:
            G_prev = self.build_graph(n - 1)
            for i in range(1, n + 1):
                # add the node for the bullet vote i
                Gc.add_node((i,), weight=0, cast=False)

                # make the subgraph for the ballots where i is ranked first
                G_corner = self._relabel(G_prev, i, n)

                # add the components from that graph to the larger graph
                Gc.add_nodes_from(G_corner.nodes, weight=0, cast=False)
                Gc.add_edges_from(G_corner.edges)

                # connect the bullet vote node to the appropriate vertices
                if n == 3:
                    Gc.add_edges_from([(k, (i,)) for k in G_corner.nodes])
                else:
                    Gc.add_edges_from(
                        [(k, (i,)) for k in G_corner.nodes if len(k) == 2]
                    )

        # connect each ballot to the one with its first two entries swapped
        new_edges = [
            (bal, (bal[1], bal[0]) + bal[2:]) for bal in Gc.nodes if len(bal) >= 2
        ]
        Gc.add_edges_from(new_edges)

        return Gc

    def from_profile(
        self, profile: RankProfile, fix_short: Optional[bool] = True
    ) -> nx.Graph:
        """
        Updates existing graph based on cast ballots from a RankProfile, or
        creates graph based on RankProfile.

        Ballots whose node is not present in the graph are skipped.

        Args:
            profile (RankProfile): ``RankProfile`` assigned to graph.
            fix_short (bool, optional): If True, complete ballots of length
                :math:`n-1`. Defaults to True.

        Returns:
            networkx.Graph: Graph based on ``RankProfile``, 'cast' node attribute
            indicates ballots cast in ``RankProfile``.

        Raises:
            TypeError: If a ballot has no ranking.
            ValueError: If a ballot ranks several candidates in one position.
        """
        if self.profile is None:
            self.profile = profile
        # getattr keeps this safe for instances whose __init__ did not set it
        if getattr(self, "num_voters", None) is None:
            self.num_voters = profile.total_ballot_wt
        self.candidates = tuple(profile.candidates)

        self.cand_num = self._number_cands(self.candidates)
        self.node_weights = {ballot: 0 for ballot in self.graph.nodes}
        all_cand_nums = list(self.cand_num.values())

        for ballot in profile.ballots:
            if ballot.ranking is None:
                raise TypeError("Ballots must have rankings.")

            ballot_node = []
            for position in ballot.ranking:
                if len(position) > 1:
                    raise ValueError(
                        "ballots must be cleaned to resolve ties"
                    )  # still unsure about ties
                ballot_node.extend(self.cand_num[cand] for cand in position)

            if len(ballot_node) == len(self.candidates) - 1 and fix_short:
                ballot_node = self.fix_short_ballot(ballot_node, all_cand_nums)

            node = tuple(ballot_node)
            if node in self.graph.nodes:
                attrs = self.graph.nodes[node]
                attrs["weight"] += ballot.weight
                attrs["cast"] = True
                self.node_weights[node] += ballot.weight

        return self.graph

    def fix_short_ballot(self, ballot: list, candidates: list) -> list:
        """
        Adds missing candidates to a short ballot.

        Args:
            ballot (list): A list of candidates on the ballot.
            candidates (list): A list of all candidates.

        Returns:
            list: A new list with the missing candidates appended to the end of
            the ballot, in the order they appear in ``candidates``.
        """
        present = set(ballot)
        # dict.fromkeys de-duplicates while keeping a deterministic order
        missing = [cand for cand in dict.fromkeys(candidates) if cand not in present]
        return ballot + missing

    def label_cands(self, candidates, to_display: Callable = all_nodes):
        """
        Assigns candidate labels to ballot graph for plotting.

        Args:
            candidates (list): A list of candidates.
            to_display (Callable, optional): A Boolean callable that takes in a
                graph and node, returns True if node should be displayed.
                Defaults to showing all nodes.

        Returns:
            dict: Maps each displayed node to ``"(<cand>, ...): <weight>"``.
        """
        number_to_cand = {
            num: cand for cand, num in self._number_cands(tuple(candidates)).items()
        }

        cand_labels = {}
        for node in self.graph.nodes:
            if not to_display(self.graph, node):
                continue
            names = tuple(number_to_cand[num] for num in node)
            # label the ballot and give the number of votes
            cand_labels[node] = f"{names!s}: {self.graph.nodes[node]['weight']!s}"

        return cand_labels

    def label_weights(self, to_display: Callable = all_nodes):
        """
        Assigns weight labels to ballot graph for plotting.
        Only shows weight if non-zero.

        Args:
            to_display (Callable, optional): A Boolean callable that takes in a
                graph and node, returns True if node should be displayed.
                Defaults to showing all nodes.

        Returns:
            dict: Maps each displayed node to its label string.
        """
        node_labels = {}
        for node in self.graph.nodes:
            if not to_display(self.graph, node):
                continue
            weight = self.graph.nodes[node]["weight"]
            # label the ballot and give the number of votes
            if weight > 0:
                node_labels[node] = f"{node!s}: {weight!s}"
            else:
                node_labels[node] = str(node)

        return node_labels

    def _number_cands(self, cands: tuple) -> dict:
        """
        Assigns numerical marker to candidates, starting at 1 in ``cands`` order.

        Not cached: ``functools.cache`` on a method keeps every instance alive
        and would hand out one shared (mutable) dict to all callers.
        """
        return {cand: idx for idx, cand in enumerate(cands, start=1)}

    def draw(
        self,
        to_display: Callable = all_nodes,
        neighborhoods: Optional[list[tuple]] = None,
        show_cast: Optional[bool] = False,
        labels: Optional[bool] = False,
        scale: float = 1.0,
    ):
        """
        Visualize the graph.

        Args:
            to_display (Callable, optional): A boolean function that takes the
                graph and a node as input, returns True if you want that node
                displayed. Defaults to showing all nodes.
            neighborhoods (list[tuple], optional): A list of neighborhoods to
                display, given as tuple (node, radius). eg. (n,1) gives all nodes
                within one step of n. Overrides ``to_display`` and ``show_cast``
                when non-empty. Defaults to None (no neighborhood filter).
            show_cast (bool, optional): If True, show only nodes with "cast"
                attribute = True. If False, show all nodes. Defaults to False.
            labels (bool, optional): If True, labels nodes with candidate names
                and vote totals. Defaults to False.
            scale (float, optional): How much to scale the base graph by.
                Defaults to 1.0.

        Raises:
            ValueError: If ``labels`` is True but no candidate names are assigned,
                or if no node passes the display filter.
        """

        def cast_nodes(graph, node):
            return graph.nodes[node]["cast"]

        if show_cast:
            to_display = cast_nodes

        if neighborhoods:
            # Gather every node within `radius` steps of some center in one BFS
            # per center, instead of a shortest-path query per (node, center)
            # pair; this also tolerates disconnected graphs.
            nearby = set()
            for center, radius in neighborhoods:
                nearby.update(
                    nx.single_source_shortest_path_length(
                        self.graph, center, cutoff=radius
                    )
                )

            def in_neighborhoods(graph, node):
                return node in nearby

            to_display = in_neighborhoods

        ballots = [node for node in self.graph.nodes if to_display(self.graph, node)]

        if labels:
            if self.candidates is None:
                raise ValueError("no candidate names assigned")
            node_labels = self.label_cands(self.candidates, to_display)
        else:
            node_labels = self.label_weights(to_display)
            # if not labeling the nodes with candidates and graph is drawn from
            # profile, print labeling dictionary
            if self.profile and self.candidates:
                print("The candidates are labeled as follows.")
                cand_dict = self._number_cands(cands=tuple(self.candidates))
                for cand, value in cand_dict.items():
                    print(value, cand)

        if not ballots:
            raise ValueError("no nodes to display")

        subgraph = self.graph.subgraph(ballots)
        pos = nx.spring_layout(subgraph)
        plt.figure(figsize=(8 * scale, 8 * scale))
        nx.draw_networkx(
            subgraph,
            pos=pos,
            node_color="#a6cee3",  # Color here is from districtr
            edge_color="#1f78b4",
            with_labels=True,
            labels=node_labels,
            font_weight="bold",
            node_size=1000 * scale,
            width=scale,
        )

        # handles labels overlapping with margins
        x_values, y_values = zip(*pos.values())
        x_min, x_max = min(x_values), max(x_values)
        y_min, y_max = min(y_values), max(y_values)
        # a single node gives zero spread; use a fixed margin so the axis
        # limits are never identical
        x_margin = (x_max - x_min) * 0.25 or 1.0
        y_margin = (y_max - y_min) * 0.25 or 1.0
        plt.xlim(x_min - x_margin, x_max + x_margin)
        plt.ylim(y_min - y_margin, y_max + y_margin)
        plt.show()