import sys
import warnings
from collections import defaultdict
from collections.abc import Callable
import geopandas as gpd
import networkx as nx
import numpy as np
from loguru import logger
from sklearn.neighbors import KDTree
from tqdm.auto import tqdm
from .utils import (
_create_k_col,
_find_scaling_factor,
_get_id_col_array,
_set_node_attributes,
)
def _get_nearest_nodes(
    gdf: gpd.GeoDataFrame,
    k: int | float | str,
    *,
    max_degree: int,
    query_factor: int = 2,
    constraint: Callable | None = None,
    random_state: int | None = None,
    verbose: bool = False,
) -> tuple[dict[int, list[int]], np.ndarray]:
    """
    Find the nearest neighbors for each node

    Nodes are visited in positional order. Each node greedily takes its
    closest candidates (nearest first, via a KDTree over the node positions)
    until it has its expected number of neighbors, skipping candidates that
    would violate ``max_degree`` or ``constraint``. Every undirected pair is
    recorded under only one of its two endpoints.

    Args:
        gdf (gpd.GeoDataFrame): GeoDataFrame containing nodes. Point geometries
            are used directly; any other geometry type is represented by its centroid.
        k (int | float | str): number of nearest neighbors to connect initially
            If a number, it determines the number of nearest neighbors to connect initially, also the average degree of the network.
            If a string, it determines the column name containing expected degree centrality for each node, when connecting to neighbors initially.
            Half of the column value (floored) is the number of neighbors the node itself looks for.
        max_degree (int): maximum degree centrality allowed
        query_factor (int): factor of k to query neighbors initially to handle constraint filtering
        constraint (Callable | None): constraint function to filter out invalid neighbors, default is None
            Example: constraint=lambda u, v: u.household != v.household
            This will ensure that nodes from the same household are not connected.
        random_state (int | None): random seed forwarded to ``_create_k_col`` when ``k`` is a number, default is None.
        verbose (bool): whether to show detailed progress messages, default is False

    Returns:
        dict[int, list[int]]: positional index of a node mapped to the positional
            indices of the neighbors it selected (possibly an empty list)
        np.ndarray: array of degree centrality for each node

    Raises:
        ValueError: if ``k`` is neither a number nor a string
    """
    logger.debug("Building KDTree for efficient nearest neighbor search")
    n_nodes = len(gdf)
    degree_centrality_array = np.zeros(n_nodes)
    if isinstance(k, str):
        k_col = gdf[k].values // 2
    elif isinstance(k, int | float):
        k_col = _create_k_col(k, n_nodes, random_state=random_state)
    else:
        raise ValueError("k must be an integer, a float, or a string")
    # KDTree.query requires an integer neighbour count; a float-typed degree
    # column would otherwise propagate float query sizes. Clipping also makes
    # a separate "expected degree > max_degree" check unnecessary.
    k_col = np.clip(k_col, 0, max_degree).astype(int)

    # Points are used as-is; every other geometry type (Polygon, MultiPolygon,
    # LineString, mixed columns, ...) is represented by its centroid.
    geometry = gdf.geometry
    if (geometry.geom_type == "Point").all():
        positions = np.column_stack([geometry.x.values, geometry.y.values])
    else:
        centroids = geometry.centroid
        positions = np.column_stack([centroids.x.values, centroids.y.values])
    kdtree = KDTree(positions, metric="euclidean")

    nearest_neighbors = defaultdict(list)
    desc = (
        f"Finding nearest neighbors based on column {k}"
        if isinstance(k, str)
        else f"Finding {k} nearest neighbors"
    )
    max_query_k = n_nodes - 1  # every node except the query node itself

    for this_node_idx in tqdm(range(n_nodes), desc=desc, disable=not verbose):
        expected_num_neighbors = int(k_col[this_node_idx])
        if expected_num_neighbors > n_nodes:
            logger.error(
                f"Node {this_node_idx} has expected degree {expected_num_neighbors} > number of nodes {len(gdf)}. Skipping node."
            )
            continue

        neighbors_set = set()  # Track neighbors in a set to avoid duplicates
        # Initially query more neighbors to account for filtering; at least one
        # so that a non-positive query_factor cannot stall the search.
        query_k = min(max(expected_num_neighbors * query_factor, 1), max_query_k)
        saturated = False  # this node reached max_degree before finding enough neighbors

        while len(neighbors_set) < expected_num_neighbors:
            # +1 because the query result includes the node itself
            _, idxs = kdtree.query([positions[this_node_idx]], k=query_k + 1)
            for new_node_idx in idxs[0]:
                if new_node_idx == this_node_idx or new_node_idx in neighbors_set:
                    continue
                # This node is full: no later candidate can be accepted either.
                if degree_centrality_array[this_node_idx] >= max_degree:
                    saturated = True
                    break
                # Leave the candidate room for its own expected neighbors.
                if (
                    degree_centrality_array[new_node_idx] + k_col[new_node_idx]
                    >= max_degree
                ):
                    continue
                # Avoid double counting a pair the candidate already recorded.
                # NOTE: indexing the defaultdict inserts an empty entry for the
                # candidate, which fixes the dict's key (iteration) order; kept
                # so callers iterating the result see the same order as before.
                if this_node_idx in nearest_neighbors[new_node_idx]:
                    continue
                # Apply constraint if provided
                if constraint is not None and not constraint(
                    gdf.iloc[this_node_idx], gdf.iloc[new_node_idx]
                ):
                    continue
                neighbors_set.add(new_node_idx)
                degree_centrality_array[this_node_idx] += 1
                degree_centrality_array[new_node_idx] += 1
                if len(neighbors_set) >= expected_num_neighbors:
                    break

            if len(neighbors_set) >= expected_num_neighbors:
                break
            # Out of capacity or out of candidates: report the shortfall.
            if saturated or query_k >= max_query_k:
                warnings.warn(
                    f"Node at index {this_node_idx} has only {len(neighbors_set)} neighbors out of expected {expected_num_neighbors}. "
                    f"Consider reducing the expected degree for this node:\n{gdf.iloc[this_node_idx].to_frame().T}",
                    UserWarning,
                    stacklevel=2,
                )
                break
            # Grow the query geometrically (at least doubling, and at least by
            # query_factor). Each pass re-scans the earlier candidates, so a
            # constant step is quadratic in the worst case. Earlier rejections
            # cannot flip between passes (assuming a deterministic constraint),
            # so the selected neighbors are unchanged.
            query_k = min(max(2 * query_k, query_k + query_factor), max_query_k)

        nearest_neighbors[this_node_idx] = list(neighbors_set)
    return nearest_neighbors, degree_centrality_array
def geo_watts_strogatz_network(
    gdf,
    k: int | float | str,
    p: float,
    *,
    a=3,
    scaling_factor: float | None = None,
    max_degree=150,
    id_col: str | None = None,
    query_factor: int = 2,
    node_attributes: bool | str | list[str] = True,
    constraint: Callable | None = None,
    random_state: int | None = None,
    verbose: bool = False,
) -> nx.Graph:
    r"""Construct a geo watts-strogatz network using the Geospatial Watts-Strogatz model

    The Geospatial Watts-Strogatz model is a variant of the Watts-Strogatz model that incorporates spatial considerations.
    First, the model connects each node to its k nearest neighbors.
    Then, it rewires each edge with probability p. When an edge is rewired, it is removed and a new edge is added to a random node.
    The probability of being rewired to a new node is determined by the distance between the nodes:

    .. math::
        p(\textrm{distance}|a, \textrm{min\_dist}) = \textrm{min}\left(1, \left(\frac{\textrm{distance}}{\textrm{min\_dist}}\right) ^ {-a}\right)

    where :math:`min\_dist` is the minimum distance between nodes, and :math:`a` is the distance decay exponent parameter, default is 3.
    The minimum distance is a threshold, below which nodes are connected with probability 1, if an edge is chosen to be rewired.
    It is 1/20 of the bounding box diagonal by default. Users can set the scaling factor directly if needed, which is the inverse of the minimum distance.

    Every row of ``gdf`` becomes a node of the graph, including nodes that end up without any edge.
    Point geometries are used directly; any other geometry type is represented by its centroid.

    Args:
        gdf (gpd.GeoDataFrame): GeoDataFrame containing nodes
        k (int | float | str): number of nearest neighbors to connect initially
            If a number, it determines the number of nearest neighbors to connect initially, also the average degree of the network.
            If a string, it determines the column name containing expected degree centrality for each node, when connecting to neighbors initially.
        p (float): probability of rewiring an edge

    Keyword Args:
        a (int): distance decay exponent parameter, default is 3
        scaling_factor (float): scaling factor is the inverse of the minimum distance between nodes, default is None.
            The minimum distance is a threshold, below which nodes are connected with probability 1,
            if an edge is chosen to be rewired.
            If None, the scaling factor will be calculated based on the bounding box of the GeoDataFrame.
        max_degree (int): maximum degree centrality allowed, default is 150
        id_col (str): column name containing unique IDs, default is None.
            If "index", the index of the GeoDataFrame will be used as the unique ID.
            If a column name, the values in the column will be used as the unique ID.
            If None, the positional index of the node will be used as the unique ID.
        query_factor (int): factor of k to query neighbors initially to handle constraint filtering, default is 2
        node_attributes (bool | str | list[str]): node attributes to save in the graph, default is True.
            If True, all attributes will be saved as node attributes.
            If False, only the position of the nodes will be saved as a `pos` attribute.
            If a string or a list of strings, the attributes will be saved as node attributes.
        constraint (Callable | None): constraint function to filter out invalid neighbors, default is None
            Example: constraint=lambda u, v: u.household != v.household
            This will ensure that nodes from the same household are not connected.
        random_state (int | None): random seed for reproducibility, default is None.
        verbose (bool): whether to show detailed progress messages, default is False

    Returns:
        nx.Graph: a geo watts-strogatz network graph with average degree k, maximum degree max_degree.
            Each edge carries its Euclidean ``length``.

    Raises:
        ValueError: if ``k`` is not a positive number or a string, or ``p`` is outside [0, 1]
    """
    # NOTE(review): ``logger.remove()`` drops every loguru sink, including any
    # configured by the calling application. Kept for backward compatibility.
    logger.remove()
    logger.add(sys.stderr, level="DEBUG" if verbose else "WARNING")
    logger.debug(
        f"Building geo watts-strogatz network with k={k}, p={p}, a={a}, scaling_factor={scaling_factor}, max_degree={max_degree}"
    )
    if gdf.crs and gdf.crs.is_geographic:
        warnings.warn(
            "Geometry is in a geographic CRS. "
            "Results from distance calculations are likely incorrect. "
            "Use 'GeoDataFrame.to_crs()' to re-project geometries to a "
            "projected CRS before this operation.\n",
            UserWarning,
            stacklevel=2,
        )
    if gdf.crs is None:
        warnings.warn(
            "Input GeoDataFrame has no CRS; storing crs=None. Downstream exports will produce GeoDataFrames with an undefined coordinate reference system.",
            UserWarning,
            stacklevel=2,
        )

    # Validate k: a column name, or a strictly positive number.
    if isinstance(k, int | float):
        if k <= 0:
            raise ValueError("k must be greater than 0")
    elif not isinstance(k, str):
        raise ValueError("k must be an integer, a float, or a string")
    if not 0 <= p <= 1:
        raise ValueError("p must be between 0 and 1")

    id_col_array = _get_id_col_array(gdf, id_col)
    nearest_neighbors, degree_centrality_array = _get_nearest_nodes(
        gdf,
        k,
        max_degree=max_degree,
        query_factor=query_factor,
        constraint=constraint,
        random_state=random_state,
        verbose=verbose,
    )

    n_nodes = len(gdf)
    # Graph node ID for each positional index. Indexing a ``range`` returns a
    # plain int even for NumPy integer indices, so positional IDs are uniform.
    node_ids = id_col_array if id_col else range(n_nodes)

    graph = nx.Graph()
    graph.graph["crs"] = gdf.crs
    if id_col == "index":
        graph.graph["id_col"] = "index"
        graph.graph["index_name"] = gdf.index.name
    else:
        graph.graph["id_col"] = id_col
    # Add every node up front so that nodes left without edges are still present.
    graph.add_nodes_from(node_ids[idx] for idx in range(n_nodes))

    # Points are used as-is; any other geometry type is represented by its centroid.
    geometry = gdf.geometry
    if (geometry.geom_type == "Point").all():
        pos_x_array, pos_y_array = geometry.x.values, geometry.y.values
    else:
        centroids = geometry.centroid
        pos_x_array, pos_y_array = centroids.x.values, centroids.y.values

    def edge_length(i, j):
        """Euclidean distance between the nodes at positional indices i and j."""
        dx = float(pos_x_array[i] - pos_x_array[j])
        dy = float(pos_y_array[i] - pos_y_array[j])
        return (dx**2 + dy**2) ** 0.5

    for this_node_idx, neighbor_idxs in tqdm(
        nearest_neighbors.items(),
        desc="Creating initial network from nearest neighbors",
        disable=not verbose,
    ):
        for neighboring_node_idx in neighbor_idxs:
            graph.add_edge(
                node_ids[this_node_idx],
                node_ids[neighboring_node_idx],
                length=edge_length(this_node_idx, neighboring_node_idx),
            )

    if scaling_factor is None:
        scaling_factor = _find_scaling_factor(gdf)
    # Below this distance a sampled rewiring target is accepted with probability 1.
    min_distance = 1 / scaling_factor
    np_rng = np.random.default_rng(seed=random_state)

    def is_valid_target(this_idx, candidate_idx):
        """Whether an edge this_idx -> candidate_idx may be created by rewiring.

        Rejects self-loops, existing edges, full candidates and constraint violations.
        """
        return not (
            candidate_idx == this_idx
            or graph.has_edge(node_ids[this_idx], node_ids[candidate_idx])
            or degree_centrality_array[candidate_idx] >= max_degree
            or (
                constraint is not None
                and not constraint(gdf.iloc[this_idx], gdf.iloc[candidate_idx])
            )
        )

    def draw_valid_target(this_idx):
        """Sample random nodes until one is a valid target.

        Returns None once every node has been sampled and none was valid.
        Every sampled candidate, including the one that completes the set, is
        validity-checked before giving up.
        """
        candidate_idx = np_rng.integers(0, n_nodes)
        checked_nodes = {candidate_idx}
        while not is_valid_target(this_idx, candidate_idx):
            if len(checked_nodes) == n_nodes:
                return None
            candidate_idx = np_rng.integers(0, n_nodes)
            checked_nodes.add(candidate_idx)
        return candidate_idx

    # Walk nodes and their initial neighbors in stored order; each initial edge
    # is rewired with probability p. No self loops or multiple edges allowed.
    rewire_count = 0
    for this_node_idx in tqdm(
        nearest_neighbors,
        desc="Rewiring edges in geo watts-strogatz network",
        disable=not verbose,
    ):
        this_node_graph_id = node_ids[this_node_idx]
        for neighboring_node_idx in nearest_neighbors[this_node_idx]:
            if np_rng.random() >= p:
                continue
            while True:
                random_node_idx = draw_valid_target(this_node_idx)
                if random_node_idx is None:
                    warnings.warn(
                        f"Node {this_node_graph_id} has exhausted all possible rewiring options. Skipping. "
                        f"Consider reducing the constraints for this node:\n{gdf.iloc[this_node_idx].to_frame().T}",
                        UserWarning,
                        stacklevel=2,
                    )
                    break
                distance = edge_length(this_node_idx, random_node_idx)
                # Accept with probability 1 below the minimum distance, else
                # with probability (distance / min_dist) ** (-a).
                if distance < min_distance:
                    q = 1
                else:
                    q = (distance * scaling_factor) ** (-a)
                if np_rng.random() < q:
                    graph.remove_edge(
                        this_node_graph_id, node_ids[neighboring_node_idx]
                    )
                    graph.add_edge(
                        this_node_graph_id, node_ids[random_node_idx], length=distance
                    )
                    degree_centrality_array[neighboring_node_idx] -= 1
                    degree_centrality_array[random_node_idx] += 1
                    rewire_count += 1
                    break

    _set_node_attributes(graph, gdf, id_col, node_attributes)
    total_edges = graph.number_of_edges()
    # Guard against graphs without edges (e.g. every expected degree is 0).
    rewired_pct = rewire_count / total_edges * 100 if total_edges else 0.0
    logger.debug(
        f"Rewire Count: {rewire_count:,} edges out of {total_edges:,}. {rewired_pct:.2f}% of edges rewired"
    )
    return graph