# Source code for stark_qa.tools.graph

from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch_geometric.utils.num_nodes import maybe_num_nodes


def k_hop_subgraph(
    node_idx: Union[int, List[int], Tensor],
    num_hops: int,
    edge_index: Tensor,
    relabel_nodes: bool = False,
    num_nodes: Optional[int] = None,
    flow: str = 'source_to_target',
    directed: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Extracts the k-hop subgraph around a given node or a list of nodes.

    Args:
        node_idx (Union[int, List[int], Tensor]): The central node or a list of
            central nodes (the seeds). Nested lists and 0-d tensors are flattened.
        num_hops (int): The number of hops to expand from the seeds.
        edge_index (Tensor): The edge indices of the graph, shape ``(2, num_edges)``.
        relabel_nodes (bool, optional): If True, the returned edge indices are
            relabeled to the contiguous range ``0 .. len(subset) - 1``.
            Defaults to False.
        num_nodes (Optional[int], optional): The number of nodes in the graph;
            inferred from ``edge_index`` when None. Defaults to None.
        flow (str, optional): Direction in which hops are expanded.
            ``'source_to_target'`` collects the sources of edges whose target is in
            the current frontier, ``'target_to_source'`` collects the targets of
            edges whose source is in the frontier, and ``'bidirectional'`` does
            both. Defaults to 'source_to_target'.
        directed (bool, optional): If True, only the edges actually followed
            during the expansion (in any hop) are kept. If False, every edge whose
            two endpoints both lie in the extracted node set is kept.
            Defaults to False.

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tensor]:
            - ``subset``: sorted unique ids of the nodes in the subgraph;
            - ``edge_index``: the kept edges (relabeled if ``relabel_nodes``);
            - ``inv``: position of each seed node inside ``subset``;
            - ``edge_mask``: boolean mask of length ``num_edges`` over the input
              edges marking the kept ones.

    Raises:
        ValueError: If ``flow`` is not one of the supported values.
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # An ``assert`` would be stripped under ``python -O`` and an invalid value
    # would then silently fall through to the bidirectional branch.
    if flow not in ('source_to_target', 'target_to_source', 'bidirectional'):
        raise ValueError(f"Invalid flow direction: {flow!r}")

    num_edges = edge_index.size(1)

    # ``row`` is the endpoint that must be in the current frontier for an edge to
    # be followed; ``col`` is the endpoint that joins the next frontier.
    if flow == 'target_to_source':
        row, col = edge_index
    elif flow == 'source_to_target':
        col, row = edge_index
    else:
        # First ``num_edges`` entries are the original edges, the remaining
        # ``num_edges`` entries are their reversals.
        col, row = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)

    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
    # Union of the edges followed in any hop; zero-initialised so that
    # ``num_hops == 0`` yields no edges in directed mode.
    traversed_mask = torch.zeros_like(edge_mask)

    if isinstance(node_idx, (int, list, tuple)):
        # Explicit dtype: ``torch.tensor([])`` would otherwise be float32, which
        # cannot be used as an index.
        node_idx = torch.tensor(node_idx, dtype=torch.long, device=row.device).flatten()
    else:
        # ``flatten`` lets a 0-d seed tensor through ``torch.cat`` below.
        node_idx = node_idx.to(row.device).flatten()

    subsets = [node_idx]
    for _ in range(num_hops):
        node_mask.fill_(False)
        node_mask[subsets[-1]] = True
        # Select the edges whose ``row`` endpoint lies in the current frontier.
        torch.index_select(node_mask, 0, row, out=edge_mask)
        traversed_mask |= edge_mask
        subsets.append(col[edge_mask])

    subset, inv = torch.cat(subsets).unique(return_inverse=True)
    # The seeds come first in the concatenation, so the first entries of ``inv``
    # are their positions inside ``subset``.
    inv = inv[:node_idx.numel()]

    node_mask.fill_(False)
    node_mask[subset] = True

    if flow == 'bidirectional':
        # Back to the original (non-duplicated) edges; an original edge counts as
        # traversed if it was followed in either direction.
        col, row = edge_index
        traversed_mask = traversed_mask[:num_edges] | traversed_mask[num_edges:]

    if directed:
        edge_mask = traversed_mask
    else:
        edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        edge_index = relabel_graph(subset, edge_index, num_nodes)

    return subset, edge_index, inv, edge_mask
def relabel_graph(subset: Tensor, edge_index: Tensor, num_nodes: int) -> Tensor:
    """
    Relabels the nodes in the graph to a contiguous range.

    Each node id listed in ``subset`` is mapped to its position in ``subset``;
    any other id occurring in ``edge_index`` is mapped to ``-1``.

    Args:
        subset (Tensor): The subset of nodes; its order defines the new labels.
        edge_index (Tensor): The edge indices of the graph (first dimension of
            size 2).
        num_nodes (int): The number of nodes in the graph; must be larger than
            every id in ``subset`` and ``edge_index``.

    Returns:
        Tensor: The relabeled edge indices, with the same shape and dtype as
        ``edge_index``.
    """
    # Unpacking also rejects an ``edge_index`` whose first dimension is not 2.
    first_row, _ = edge_index
    lookup = first_row.new_full((num_nodes, ), -1)
    positions = torch.arange(subset.size(0), device=first_row.device)
    lookup[subset] = positions
    return lookup[edge_index]