stark_qa
stark_qa.evaluator
- class stark_qa.evaluator.Evaluator(candidate_ids)[source]
Bases: object
- evaluate(pred_dict, answer_ids, metrics=['mrr', 'hit@3', 'recall@20'])[source]
Evaluates the predictions using the specified metrics.
- Parameters:
pred_dict (Dict[int, float]) – Dictionary mapping candidate IDs to predicted scores.
answer_ids (torch.LongTensor) – Ground truth answer IDs.
metrics (List[str]) – A list of metrics to be evaluated, including ‘mrr’, ‘hit@k’, ‘recall@k’, ‘precision@k’, ‘map@k’, ‘ndcg@k’.
- Returns:
Dictionary of evaluation metrics.
- Return type:
Dict[str, float]
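To make the interface concrete, here is a hypothetical pure-Python sketch of what `evaluate` computes for the three default metrics (`mrr`, `hit@3`, `recall@20`). It uses plain ints in place of `torch.LongTensor`, and the function name `evaluate_sketch` is illustrative only, not the library's implementation.

```python
def evaluate_sketch(pred_dict, answer_ids, metrics=('mrr', 'hit@3', 'recall@20')):
    """Illustrative sketch of ranking-metric computation; handles only the
    three default metrics, not precision@k, map@k, or ndcg@k."""
    # Rank candidate IDs by predicted score, highest first.
    ranked = sorted(pred_dict, key=pred_dict.get, reverse=True)
    answers = set(answer_ids)
    results = {}
    for metric in metrics:
        if metric == 'mrr':
            # Reciprocal rank of the first correct answer (0.0 if none found).
            rr = 0.0
            for rank, cid in enumerate(ranked, start=1):
                if cid in answers:
                    rr = 1.0 / rank
                    break
            results['mrr'] = rr
        elif metric.startswith('hit@'):
            k = int(metric.split('@')[1])
            # 1.0 if any correct answer appears in the top k, else 0.0.
            results[metric] = float(any(cid in answers for cid in ranked[:k]))
        elif metric.startswith('recall@'):
            k = int(metric.split('@')[1])
            # Fraction of ground-truth answers retrieved in the top k.
            results[metric] = sum(cid in answers for cid in ranked[:k]) / len(answers)
    return results
```

For example, with `pred_dict = {10: 0.9, 11: 0.5, 12: 0.8}` and `answer_ids = [12]`, candidate 12 is ranked second, so `mrr` is 0.5 and `hit@3` is 1.0.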
stark_qa.load_qa
- stark_qa.load_qa.load_qa(name, root=None, human_generated_eval=False)[source]
Load the QA dataset.
- Parameters:
name (str) – Name of the dataset. One of ‘amazon’, ‘prime’, or ‘mag’.
root (Union[str, None]) – Root directory to store the dataset. If not provided, the default Hugging Face cache path is used.
human_generated_eval (bool) – Whether to use human-generated evaluation data. Default is False.
- Returns:
The loaded STaRK dataset.
- Return type:
STaRKDataset
- Raises:
ValueError – If the dataset name is not registered.
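The name check behind `load_qa` can be sketched as a registry lookup that raises `ValueError` for an unregistered name. The registry contents mirror the documented options; the function and variable names here are hypothetical, not the library's internals.

```python
# Hypothetical registry of QA dataset names, per the documented options.
QA_REGISTRY = {'amazon', 'prime', 'mag'}

def check_qa_name(name):
    """Illustrative sketch: reject dataset names that are not registered."""
    if name not in QA_REGISTRY:
        raise ValueError(
            f"Unknown QA dataset: {name!r}; expected one of {sorted(QA_REGISTRY)}"
        )
    return name
```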
stark_qa.load_skb
- stark_qa.load_skb.load_skb(name, root=None, download_processed=True, **kwargs)[source]
Load the SKB dataset.
- Parameters:
name (str) – Name of the dataset. One of ‘amazon’, ‘prime’, or ‘mag’.
root (Union[str, None]) – Root directory to store the dataset. If None, defaults to the Hugging Face cache path.
download_processed (bool) – Whether to download processed data. Default is True. If False, root must be provided.
**kwargs – Additional keyword arguments for the specific dataset class.
- Returns:
An instance of the specified SKB dataset class.
- Return type:
SKB
- Raises:
ValueError – If the dataset name is not recognized.
AssertionError – If root is not provided when download_processed is False.
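The two documented failure modes of `load_skb` can be sketched as argument validation: an unrecognized name raises `ValueError`, and a missing `root` with `download_processed=False` trips an assertion. The helper name `validate_skb_args` is hypothetical; this is not the library's actual code path.

```python
def validate_skb_args(name, root=None, download_processed=True):
    """Illustrative sketch of load_skb's documented argument checks."""
    # ValueError if the dataset name is not recognized.
    if name not in ('amazon', 'prime', 'mag'):
        raise ValueError(f"Unknown SKB dataset: {name!r}")
    # AssertionError if root is missing when processed data is not downloaded.
    assert download_processed or root is not None, \
        "root must be provided when download_processed is False"
    return name, root
```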