# Source code for stark_qa.retrieval.dataset

import ast
import copy
import os.path as osp
from typing import Union

import pandas as pd
import torch

from stark_qa.tools.download_hf import download_hf_folder


# Hugging Face Hub location of the STaRK QA data: the dataset repository id
# and the folder inside it under which each dataset name has its own sub-folder.
STARK_QA_DATASET = dict(
    repo="snap-stanford/stark",
    folder="qa",
)

class STaRKDataset:
    """Question-answer data of one STaRK benchmark dataset.

    Reads the QA table (``stark_qa.csv``, or ``stark_qa_human_generated_eval.csv``
    when ``human_generated_eval`` is set) from a folder fetched from the Hugging
    Face Hub and exposes it as an indexable sequence of
    ``(query, query_id, answer_ids, meta_info)`` tuples.
    """

    def __init__(self, name: str, root: Union[str, None] = None, human_generated_eval: bool = False):
        """
        Initialize the STaRK dataset.

        Args:
            name (str): Name of the dataset.
            root (Union[str, None]): Root directory to store the dataset. If None,
                default HF cache paths will be used.
            human_generated_eval (bool): Whether to use human-generated evaluation data.
        """
        self.name = name
        self.root = root
        self.dataset_root = osp.join(self.root, name) if self.root is not None else None
        # _download() overwrites dataset_root with the path returned by download_hf_folder.
        self._download()
        self.split_dir = osp.join(self.dataset_root, 'split')
        self.query_dir = osp.join(self.dataset_root, 'stark_qa')
        self.human_generated_eval = human_generated_eval
        self.qa_csv_path = osp.join(
            self.query_dir,
            'stark_qa_human_generated_eval.csv' if human_generated_eval else 'stark_qa.csv'
        )
        self.data = pd.read_csv(self.qa_csv_path)
        # Query ids in ascending order; positional index i of this dataset maps to indices[i].
        self.indices = sorted(self.data['id'].tolist())
        self.split_indices = self.get_idx_split()

    def __len__(self) -> int:
        """
        Return the number of queries in the dataset.

        Returns:
            int: Number of queries.
        """
        return len(self.indices)

    def __getitem__(self, idx: int) -> tuple:
        """
        Get the query, id, answer ids, and meta information for a given index.

        Args:
            idx (int): Position in ``self.indices`` (list-style indexing, so
                negative values count from the end).

        Returns:
            tuple: (query, query id, answer ids, meta information); meta
            information is currently always None.

        Raises:
            IndexError: If ``idx`` is out of range or the id has no row in ``self.data``.
            ValueError: If the stored ``answer_ids`` text is not a plain Python literal.
        """
        q_id = self.indices[idx]
        row = self._get_row(q_id)
        query = row['query']
        # answer_ids is read from the CSV as text holding a Python literal
        # (presumably a list such as "[1, 2]"). ast.literal_eval parses literals
        # only; the previous eval() would execute any expression found in the
        # downloaded file.
        answer_ids = ast.literal_eval(row['answer_ids'])
        meta_info = None  # Replace with actual meta information if available
        return query, q_id, answer_ids, meta_info

    def _download(self):
        """
        Download the dataset from the Hugging Face repository.
        """
        self.dataset_root = download_hf_folder(
            STARK_QA_DATASET["repo"],
            osp.join(STARK_QA_DATASET["folder"], self.name),
            repo_type="dataset",
            save_as_folder=self.dataset_root,
        )

    def _get_row(self, q_id: int) -> pd.Series:
        """Return the first row of ``self.data`` whose 'id' equals ``q_id``.

        Raises IndexError when no row matches.
        """
        return self.data[self.data['id'] == q_id].iloc[0]

    def _read_split_ids(self, split: str) -> list:
        """Return the query ids listed in ``<split_dir>/<split>.index``.

        The file lists one integer id per line; blank lines and an empty file
        are tolerated and contribute no ids.
        """
        indices_file = osp.join(self.split_dir, f'{split}.index')
        with open(indices_file, 'r') as f:
            return [int(token) for token in f.read().split()]

    def get_idx_split(self, test_ratio: float = 1.0) -> dict:
        """
        Return the indices of train/val/test split in a dictionary.

        The values are positions into ``self.indices`` (i.e. valid arguments
        for ``__getitem__``), not raw query ids.

        Args:
            test_ratio (float): Ratio of test data to include (leading prefix
                of the test split); ignored for human-generated evaluation data.

        Returns:
            dict: Dictionary with split indices for train, val, and test sets,
            or a single 'human_generated_eval' entry covering every query.

        Raises:
            ValueError: If a split file lists an id that is not in the QA table.
        """
        if self.human_generated_eval:
            # Positions 0..n-1, matching the positional semantics of the branch
            # below and of __getitem__. Equal to the query ids when they are 0..n-1.
            return {'human_generated_eval': torch.arange(len(self.indices), dtype=torch.long)}

        # Build id -> first position once, so each lookup is O(1) rather than a
        # linear list.index() scan per id.
        position = {}
        for pos, q_id in enumerate(self.indices):
            position.setdefault(q_id, pos)

        split_idx = {}
        for split in ['train', 'val', 'test']:
            query_ids = self._read_split_ids(split)
            try:
                positions = [position[q_id] for q_id in query_ids]
            except KeyError as err:
                # Same exception type and message that list.index() produces.
                raise ValueError(f'{err.args[0]} is not in list') from None
            split_idx[split] = torch.LongTensor(positions)
        if test_ratio < 1.0:
            split_idx['test'] = split_idx['test'][:int(len(split_idx['test']) * test_ratio)]
        return split_idx

    def get_query_by_qid(self, q_id: int) -> str:
        """
        Return the query by query id.

        Args:
            q_id (int): Query id.

        Returns:
            str: Query string.

        Raises:
            IndexError: If no row has this id.
        """
        return self._get_row(q_id)['query']

    def get_subset(self, split: str):
        """
        Return a subset of the dataset.

        Args:
            split (str): Split type ('train', 'val', 'test').

        Returns:
            STaRKDataset: Deep copy of this dataset whose ``indices`` are the
            query ids listed in the split file.
        """
        assert split in ['train', 'val', 'test'], "Invalid split specified."
        # Read before copying so a missing/corrupt split file fails without the deep copy.
        query_ids = self._read_split_ids(split)
        subset = copy.deepcopy(self)
        subset.indices = query_ids
        # NOTE(review): subset.split_indices is copied unchanged and still holds
        # positions into the parent's indices — confirm no caller relies on it.
        return subset