import json
import os
import os.path as osp
import zipfile
from typing import Union

import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from langdetect import detect
from ogb.nodeproppred import NodePropPredDataset
from ogb.utils.url import download_url, extract_zip
from tqdm import tqdm

from stark_qa.skb.knowledge_base import SKB
from stark_qa.tools.download_hf import download_hf_file, download_hf_folder
from stark_qa.tools.io import load_files, save_files
from stark_qa.tools.process_text import compact_text
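# Hugging Face Hub locations of the STaRK-MAG schema, raw text, and processed SKB files.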
DATASET = {
    'repo': 'snap-stanford/stark',
    'metadata': 'skb/mag/schema',
    'raw': 'skb/mag/idx_title_abs.zip',
    'processed': 'skb/mag/processed.zip'
}
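# Original download sources for the ogbn-papers100M text dump and the MAG mapping tables.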
RAW_DATA = {
'ogbn_papers100M': 'https://snap.stanford.edu/ogb/data/misc/ogbn_papers100M/paperinfo.zip',
'mag_mapping': 'https://zenodo.org/records/2628216/files'
}
class MagSKB(SKB):
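    """
    Semi-structured knowledge base (SKB) over the ogbn-MAG academic graph, with
    author, institution, field_of_study, and paper nodes and four relation types.
    """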
test_columns = ['title', 'abstract', 'text']
candidate_types = ['paper']
node_type_dict = {0: 'author', 1: 'institution', 2: 'field_of_study', 3: 'paper'}
edge_type_dict = {
0: 'author___affiliated_with___institution',
1: 'paper___cites___paper',
2: 'paper___has_topic___field_of_study',
3: 'author___writes___paper'
}
node_attr_dict = {
'paper': ['title', 'abstract', 'publication date', 'venue'],
'author': ['name'],
'institution': ['name'],
'field_of_study': ['name']
}
def __init__(self,
root: Union[str, None] = None,
download_processed: bool = True,
**kwargs):
"""
Initialize the MagSKB class.
Args:
root (Union[str, None]): Root directory to store the dataset. If None, default HF cache paths will be used.
download_processed (bool): Whether to download the processed data.
"""
self.root = root
if download_processed:
            if self.root is None or not osp.exists(osp.join(self.root, 'processed', 'node_info.pkl')):
processed_path = hf_hub_download(
DATASET["repo"], DATASET["processed"], repo_type="dataset"
)
if self.root is None:
self.root = osp.dirname(processed_path)
if not osp.exists(osp.join(self.root, 'processed', 'node_info.pkl')):
with zipfile.ZipFile(processed_path, "r") as zip_ref:
zip_ref.extractall(self.root)
print(f"Extracting downloaded processed data to {self.root}")
self.raw_data_dir = osp.join(self.root, 'raw')
self.processed_data_dir = osp.join(self.root, 'processed')
self.graph_data_root = osp.join(self.raw_data_dir, 'ogbn_mag')
self.text_root = osp.join(self.raw_data_dir, 'ogbn_papers100M')
# existing dirs/files
self.schema_dir = osp.join(self.root, 'schema')
if not osp.exists(self.schema_dir):
download_hf_folder(
DATASET["repo"], DATASET["metadata"],
repo_type="dataset", save_as_folder=self.schema_dir
)
self.mag_mapping_dir = osp.join(self.graph_data_root, 'mag_mapping')
self.ogbn_mag_mapping_dir = osp.join(self.graph_data_root, 'mapping')
self.title_path = osp.join(self.text_root, 'paperinfo/idx_title.tsv')
self.abstract_path = osp.join(self.text_root, 'paperinfo/idx_abs.tsv')
# new files
self.mag_metadata_cache_dir = osp.join(self.processed_data_dir, 'mag_cache')
self.paper100M_text_cache_dir = osp.join(self.processed_data_dir, 'paper100M_cache')
self.merged_filtered_path = osp.join(self.paper100M_text_cache_dir, 'idx_title_abs.tsv')
os.makedirs(self.mag_metadata_cache_dir, exist_ok=True)
os.makedirs(self.paper100M_text_cache_dir, exist_ok=True)
if osp.exists(osp.join(self.processed_data_dir, 'node_info.pkl')):
print(f'Loading from {self.processed_data_dir}!')
processed_data = load_files(self.processed_data_dir)
else:
print('Start processing raw data...')
processed_data = self._process_raw()
processed_data.update({
'node_type_dict': self.node_type_dict,
'edge_type_dict': self.edge_type_dict
})
super(MagSKB, self).__init__(**processed_data, **kwargs)
    def load_edge(self, edge_type: str) -> tuple:
"""
Load edge data for the specified edge type.
Args:
edge_type (str): Type of edge to load.
Returns:
            tuple: A LongTensor of shape (3, num_edges) stacking (src, dst, edge_type), and a list of edge counts.
"""
edge_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/edge.csv.gz")
edge_type_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/edge_reltype.csv.gz")
num_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/num-edge-list.csv.gz")
edge = pd.read_csv(edge_dir, names=['src', 'dst'])
edge_t = pd.read_csv(edge_type_dir, names=['type'])
edge_n = pd.read_csv(num_dir, names=['num'])
edge_num = edge_n['num'].tolist()
edge = [edge['src'].tolist(), edge['dst'].tolist(), edge_t['type'].tolist()]
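        # Stack src ids, dst ids, and relation types into one (3, num_edges) structure.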
edge = torch.LongTensor(edge)
return edge, edge_num
    def load_english_paper_text(self,
mag_ids: list,
download_cache: bool = True) -> pd.DataFrame:
"""
Load English text data for the papers.
Args:
mag_ids (list): List of MAG IDs for the papers.
download_cache (bool): Whether to download cached data.
Returns:
DataFrame: DataFrame containing English titles and abstracts.
"""
def is_english(text):
try:
return detect(text) == 'en'
            except Exception:
                return False
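        # Note: langdetect is non-deterministic by default; seeding
        # langdetect.DetectorFactory.seed makes the language filtering reproducible.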
if not osp.exists(self.merged_filtered_path):
if download_cache:
merged_filtered_zip_path = self.merged_filtered_path.replace('tsv', 'zip')
download_hf_file(
DATASET["repo"], DATASET["raw"],
repo_type="dataset", save_as_file=merged_filtered_zip_path
)
extract_zip(merged_filtered_zip_path, osp.dirname(self.merged_filtered_path))
else:
if not osp.exists(self.title_path):
raw_text_path = download_url(RAW_DATA['ogbn_papers100M'], self.text_root)
extract_zip(raw_text_path, self.text_root)
print('Start reading title...')
title = pd.read_csv(self.title_path, sep='\t', header=None)
title.columns = ["mag_id", "title"]
print('Filtering titles in English...')
                # keep only the titles whose mag_id appears in the requested ids
                title = title[title['mag_id'].isin(mag_ids)]
title_en = title[title['title'].apply(is_english)]
print('Start reading abstract...')
abstract = pd.read_csv(self.abstract_path, sep='\t', header=None)
abstract.columns = ["mag_id", "abstract"]
print('Filtering abstracts in English...')
                abstract = abstract[abstract['mag_id'].isin(mag_ids)]
abstract_en = abstract[abstract['abstract'].apply(is_english)]
print('Start merging titles and abstracts...')
                title_abs_en = pd.merge(title_en, abstract_en, how="outer", on="mag_id", sort=True)
title_abs_en.to_csv(self.merged_filtered_path, sep="\t", header=True, index=False)
print('Loading merged and filtered titles and abstracts (English)...')
title_abs_en = pd.read_csv(self.merged_filtered_path, sep='\t')
title_abs_en.columns = ['mag_id', 'title', 'abstract']
print('Done!')
return title_abs_en
    def get_map(self, df):
"""
Create mappings between MAG IDs and internal IDs.
Args:
df (DataFrame): DataFrame containing MAG IDs.
Returns:
tuple: Mappings from MAG IDs to internal IDs and vice versa.
"""
        mag_ids = df['mag_id'].tolist()
        mag2id = {mag: idx for idx, mag in enumerate(mag_ids)}
        id2mag = dict(enumerate(mag_ids))
        return mag2id, id2mag
    def get_doc_info(self,
                     idx: int,
compact: bool = False,
add_rel: bool = True,
n_rel: int = -1) -> str:
"""
Get document information for the specified node.
Args:
idx (int): Index of the node.
compact (bool): Whether to compact the text.
add_rel (bool): Whether to add relation information.
            n_rel (int): Number of relations to add. Defaults to -1, which includes all relations.
Returns:
str: Document information.
"""
node = self[idx]
if node.type == 'author':
doc = f'- author name: {node.DisplayName}\n'
if node.PaperCount != -1:
doc += f'- author paper count: {node.PaperCount}\n'
if node.CitationCount != -1:
doc += f'- author citation count: {node.CitationCount}\n'
doc = doc.replace('-1', 'Unknown')
elif node.type == 'paper':
doc = f' - paper title: {node.title}\n'
doc += ' - abstract: ' + node.abstract.replace('\r', '').rstrip('\n') + '\n'
if str(node.Date) != '-1':
doc += f' - publication date: {node.Date}\n'
if str(node.OriginalVenue) != '-1':
doc += f' - venue: {node.OriginalVenue}\n'
elif str(node.JournalDisplayName) != '-1':
doc += f' - journal: {node.JournalDisplayName}\n'
elif str(node.ConferenceSeriesDisplayName) != '-1':
doc += f' - conference: {node.ConferenceSeriesDisplayName}\n'
elif str(node.ConferenceInstancesDisplayName) != '-1':
doc += f' - conference: {node.ConferenceInstancesDisplayName}\n'
elif node.type == 'field_of_study':
doc = f' - field of study: {node.DisplayName}\n'
if node.PaperCount != -1:
doc += f'- field paper count: {node.PaperCount}\n'
if node.CitationCount != -1:
doc += f'- field citation count: {node.CitationCount}\n'
doc = doc.replace('-1', 'Unknown')
elif node.type == 'institution':
doc = f' - institution: {node.DisplayName}\n'
if node.PaperCount != -1:
doc += f'- institution paper count: {node.PaperCount}\n'
if node.CitationCount != -1:
doc += f'- institution citation count: {node.CitationCount}\n'
doc = doc.replace('-1', 'Unknown')
if add_rel and node.type == 'paper':
doc += self.get_rel_info(idx, n_rel=n_rel)
if compact:
doc = compact_text(doc)
return doc
    def get_rel_info(self,
idx: int,
rel_types: Union[list, None] = None,
n_rel: int = -1) -> str:
"""
Get relation information for the specified node.
Args:
idx (int): Index of the node.
rel_types (Union[list, None]): List of relation types or None if all relation types are included.
            n_rel (int): Number of relations to include. Defaults to -1, which includes all relations.
Returns:
            str: Relation information.
"""
doc = ''
rel_types = self.rel_type_lst() if rel_types is None else rel_types
for edge_t in rel_types:
node_ids = torch.LongTensor(self.get_neighbor_nodes(idx, edge_t)).tolist()
if not node_ids:
continue
node_type = self.node_types[node_ids[0]]
str_edge = edge_t.replace('___', ' ')
doc += f"\n{str_edge}: "
            if n_rel > 0 and edge_t == 'paper___cites___paper':
                # Fancy indexing requires a tensor, so convert back before sampling n_rel ids.
                node_ids = torch.LongTensor(node_ids)[torch.randperm(len(node_ids))[:n_rel]].tolist()
neighbors = []
for i in node_ids:
if self[i].type == 'paper':
neighbors.append(f'\"{self[i].title}\"')
elif self[i].type == 'author':
if str(self[i].DisplayName) != '-1':
institutions = self.get_neighbor_nodes(i, "author___affiliated_with___institution")
for inst in institutions:
assert self[inst].type == 'institution'
str_institutions = [self[j].DisplayName for j in institutions if str(self[j].DisplayName) != '-1']
if str_institutions:
str_institutions = ', '.join(str_institutions)
neighbors.append(f'{self[i].DisplayName} ({str_institutions})')
else:
neighbors.append(f'{self[i].DisplayName}')
else:
if str(self[i].DisplayName) != '-1':
neighbors.append(f'{self[i].DisplayName}')
neighbors = '(' + ', '.join(neighbors) + '),'
doc += neighbors
if doc:
doc = '- relations:\n' + doc
return doc
def _process_raw(self):
"""
Process raw data for the MAG dataset.
Returns:
processed_data (dict): Processed data.
"""
NodePropPredDataset(name='ogbn-mag', root=self.raw_data_dir)
author_data, field_of_study_data, institution_data, paper_data = self.load_meta_data()
paper_text_data = self.load_english_paper_text(paper_data['mag_id'].tolist())
print('Processing graph data...')
author_id_to_mag = {row['id']: row['mag_id'] for _, row in author_data.iterrows()}
institution_id_to_mag = {row['id']: row['mag_id'] for _, row in institution_data.iterrows()}
field_of_study_id_to_mag = {row['id']: row['mag_id'] for _, row in field_of_study_data.iterrows()}
paper_mapping = pd.read_csv(osp.join(self.ogbn_mag_mapping_dir, "paper_entidx2name.csv.gz"), names=['id', 'mag_id'], skiprows=[0])
mag_to_paper_id, paper_id_to_mag = self.get_map(paper_mapping)
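        # Restrict the graph to papers that have English text; the extra
        # torch.unique call below also sorts the mag ids.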
unique_paper_id = paper_text_data['mag_id'].unique()
unique_paper_id = torch.unique(torch.tensor(unique_paper_id))
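        # For each non-institution node type, the paper-incident relation used to
        # select its nodes; institutions (type 1) are handled separately via authors.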
node_type_edge = {
0: 'author___writes___paper',
2: 'paper___has_topic___field_of_study',
3: 'paper___cites___paper'
}
node_type_overlapping_node = {}
node_type_overlapping_edge = {}
        # map mag_id to internal id (fall back to the mag_id itself if unmapped)
unique_paper_id_list = unique_paper_id.tolist()
mapping_list = [mag_to_paper_id.get(k, k) for k in tqdm(unique_paper_id_list)]
unique_paper_id = torch.tensor(mapping_list)
# load edge data
print('Start loading edge data...')
for node_type, paper_rel in node_type_edge.items():
print(node_type, paper_rel)
edge, edge_num = self.load_edge(paper_rel)
# Identify edges connected to target nodes
if node_type == 3:
target_array = unique_paper_id.numpy()
edge_array = edge.numpy()
mask = np.isin(edge_array[0], target_array) & np.isin(edge_array[1], target_array)
valid_edges_array = edge_array[:, mask]
valid_edges_tensor = torch.from_numpy(valid_edges_array)
node_type_overlapping_node[node_type] = unique_paper_id
node_type_overlapping_edge[node_type] = valid_edges_tensor
print(f'{node_type} has {unique_paper_id.shape[0]} nodes left, and {valid_edges_tensor.t().shape[0]} edges left.')
continue
else:
edge = edge.t()
connected_edges_list = []
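                # NOTE: this per-node scan is O(num_papers * num_edges); a vectorized
                # membership test such as torch.isin on the relevant edge column would
                # likely be much faster, but the loop is kept for clarity.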
for target_node in tqdm(unique_paper_id):
# Find the edges connected to the current target node
if node_type == 0:
mask = edge[:, 1] == target_node.item()
current_connected_edges = edge[mask].clone()
elif node_type == 2:
mask = edge[:, 0] == target_node.item()
current_connected_edges = edge[mask].clone()
# Collect the other ends of the connected edges
connected_edges_list.append(current_connected_edges)
del mask
del current_connected_edges
connected_edges = torch.cat(connected_edges_list, dim=0)
if node_type == 0:
other_ends = torch.unique(connected_edges.t()[0])
elif node_type == 2:
other_ends = torch.unique(connected_edges.t()[1])
node_type_overlapping_node[node_type] = other_ends
node_type_overlapping_edge[node_type] = connected_edges.t()
print(f'{node_type} has {other_ends.shape[0]} nodes left, and {connected_edges.shape[0]} edges left.')
        # institutions (type 1) are selected indirectly: keep those affiliated with the remaining authors
edge, edge_num = self.load_edge('author___affiliated_with___institution')
edge = edge.t()
connected_edges_list = []
for target_node in node_type_overlapping_node[0]:
mask = edge[:, 0] == target_node
current_connected_edges = edge[mask].clone()
# Collect the other ends of the connected edges
connected_edges_list.append(current_connected_edges)
connected_edges = torch.cat(connected_edges_list, dim=0)
other_ends = torch.unique(connected_edges.t()[1])
node_type_overlapping_node[1] = other_ends
node_type_overlapping_edge[1] = connected_edges.t()
print(f'1 has {other_ends.shape[0]} nodes left, and {connected_edges.shape[0]} edges left.')
        # the kept nodes are collected in node_type_overlapping_node and the kept edges in node_type_overlapping_edge
tot_n = sum([len(node_type_overlapping_node[i]) for i in range(4)])
# the order of re-indexing is author, institution, field_of_study, paper
domain_mappings = {
0: author_id_to_mag,
1: institution_id_to_mag,
2: field_of_study_id_to_mag,
3: paper_id_to_mag
}
new_domain_mappings = {}
domain_old_to_new = {}
id_to_mag = {}
offset = 0
node_type_overlapping_node_sort = {k: node_type_overlapping_node[k] for k in sorted(node_type_overlapping_node.keys())}
# start to re-index
print('Start re-indexing...')
for i, remain_node in node_type_overlapping_node_sort.items():
            old_to_new_mappings = {key: new_id + offset for new_id, key in enumerate(remain_node.tolist())}
updated_dict = {value: domain_mappings[i][key] for key, value in old_to_new_mappings.items()}
print(f'{i} has {len(updated_dict)} nodes left')
domain_old_to_new[i] = old_to_new_mappings
id_to_mag.update(updated_dict)
new_domain_mappings[i] = updated_dict
offset += len(node_type_overlapping_node[i])
# check last index equals tot_n
assert offset == tot_n
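        # Concatenate the kept edges of all four relation buckets (author-writes-paper,
        # author-affiliated-with-institution, paper-has-topic, paper-cites-paper).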
edges_full = torch.cat([node_type_overlapping_edge[i] for i in range(4)], dim=1)
# re-index edges
        # Raw node ids start from 0 within each node type, so edge endpoints must be re-indexed per type.
d_of_mapping_dict = {
0: [domain_old_to_new[0], domain_old_to_new[3]],
1: [domain_old_to_new[0], domain_old_to_new[1]],
2: [domain_old_to_new[3], domain_old_to_new[2]],
3: [domain_old_to_new[3], domain_old_to_new[3]]
}
for i, remain_edge in tqdm(node_type_overlapping_edge.items()):
edges = remain_edge[:2]
edge_types = remain_edge[2]
new_edges = edges.clone()
dict1 = d_of_mapping_dict[i][0]
dict2 = d_of_mapping_dict[i][1]
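            # NOTE: each (old, new) pair triggers a full scan of the edge tensor; a
            # single lookup table (e.g. pandas Series.map) would scale much better.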
# Update the first dimension using dict1
for old, new in dict1.items():
new_edges[0, edges[0] == old] = new
# Update the second dimension using dict2
for old, new in dict2.items():
new_edges[1, edges[1] == old] = new
final_edges = torch.cat([new_edges, edge_types.unsqueeze(0)], dim=0)
node_type_overlapping_edge[i] = final_edges
edges_final = torch.cat([node_type_overlapping_edge[i] for i in range(4)], dim=1)
assert edges_final.shape == edges_full.shape
edge_index = torch.LongTensor(edges_final[:2])
edge_types = torch.LongTensor(edges_final[2])
# re-index nodes
author_data['new_id'] = author_data['id'].map(domain_old_to_new[0])
author_data.dropna(subset=['new_id'], inplace=True)
author_data['new_id'] = author_data['new_id'].astype(int)
institution_data['new_id'] = institution_data['id'].map(domain_old_to_new[1])
institution_data.dropna(subset=['new_id'], inplace=True)
institution_data['new_id'] = institution_data['new_id'].astype(int)
field_of_study_data['new_id'] = field_of_study_data['id'].map(domain_old_to_new[2])
field_of_study_data.dropna(subset=['new_id'], inplace=True)
field_of_study_data['new_id'] = field_of_study_data['new_id'].astype(int)
paper_data['new_id'] = paper_data['id'].map(domain_old_to_new[3])
paper_data.dropna(subset=['new_id'], inplace=True)
paper_data['new_id'] = paper_data['new_id'].astype(int)
# add text data onto the graph (paper nodes)
merged_df = pd.merge(paper_data, paper_text_data, on='mag_id', how='outer')
merged_df.dropna(subset=['new_id'], inplace=True)
merged_df['new_id'] = merged_df['new_id'].astype(int)
merged_df['mag_id'] = merged_df['mag_id'].astype(int)
merged_df = merged_df.drop_duplicates(subset=['new_id'])
# record node_info into dict
node_frame = {0: author_data, 1: institution_data, 2: field_of_study_data, 3: merged_df}
node_info = {}
node_types = []
for node_type, frame in tqdm(node_frame.items()):
for idx, row in frame.iterrows():
                # convert the DataFrame row to a plain dict keyed by column name
node_info[row['new_id']] = row.to_dict()
node_types.append(node_type)
node_types = torch.tensor(node_types)
if len(node_types) != tot_n:
raise ValueError('node_types length does not match tot_n')
processed_data = {
'node_info': node_info,
'edge_index': edge_index,
'edge_types': edge_types,
'node_types': node_types
}
print('Start saving processed data...')
save_files(save_path=self.processed_data_dir, **processed_data)
return processed_data
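# Minimal usage sketch (illustrative; the exact kwargs forwarded to SKB depend on
# the installed stark_qa version):
#
#     skb = MagSKB(root='./mag_skb', download_processed=True)
#     print(skb.get_doc_info(0, compact=True, add_rel=False))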