stark_qa

stark_qa.evaluator

class stark_qa.evaluator.Evaluator(candidate_ids)[source]

Bases: object

evaluate(pred_dict, answer_ids, metrics=['mrr', 'hit@3', 'recall@20'])[source]

Evaluates the predictions using the specified metrics.

Parameters:
  • pred_dict (Dict[int, float]) – Dictionary mapping candidate IDs to predicted relevance scores.

  • answer_ids (torch.LongTensor) – Ground truth answer IDs.

  • metrics (List[str]) – A list of metrics to be evaluated, including ‘mrr’, ‘hit@k’, ‘recall@k’, ‘precision@k’, ‘map@k’, ‘ndcg@k’.

Returns:

Dictionary mapping each requested metric name to its computed value.

Return type:

Dict[str, float]
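To make the expected shapes concrete, here is a hand-rolled illustration of the kind of ranking metrics `evaluate` returns (MRR, hit@k, recall@k). This is a sketch of the metric definitions, not the library's implementation; `rank_metrics` is a hypothetical helper, and for simplicity it accepts `answer_ids` as either a plain list or a tensor-like object with `.tolist()` (the real API documents a `torch.LongTensor`).

```python
def rank_metrics(pred_dict, answer_ids, k=3):
    """Illustrative MRR / hit@k / recall@k over a score dictionary."""
    # Sort candidate IDs by descending predicted score to obtain the ranking.
    ranked = sorted(pred_dict, key=pred_dict.get, reverse=True)
    # Accept a torch.LongTensor (via .tolist()) or a plain list of IDs.
    ids = answer_ids.tolist() if hasattr(answer_ids, "tolist") else answer_ids
    answers = set(ids)

    # MRR: reciprocal rank of the first relevant candidate (0 if none found).
    mrr = 0.0
    for rank, cid in enumerate(ranked, start=1):
        if cid in answers:
            mrr = 1.0 / rank
            break

    # hit@k: 1 if any relevant candidate appears in the top k.
    hit = float(any(cid in answers for cid in ranked[:k]))
    # recall@k: fraction of relevant candidates retrieved in the top k.
    recall = len(answers.intersection(ranked[:k])) / len(answers)
    return {"mrr": mrr, f"hit@{k}": hit, f"recall@{k}": recall}


scores = {10: 0.9, 11: 0.4, 12: 0.8, 13: 0.1}   # candidate ID -> score
print(rank_metrics(scores, [12, 13], k=3))
# → {'mrr': 0.5, 'hit@3': 1.0, 'recall@3': 0.5}
```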

stark_qa.load_qa

stark_qa.load_qa.load_qa(name, root=None, human_generated_eval=False)[source]

Load the QA dataset.

Parameters:
  • name (str) – Name of the dataset. One of ‘amazon’, ‘prime’, or ‘mag’.

  • root (Union[str, None]) – Root directory to store the dataset. If not provided, the default Hugging Face cache path is used.

  • human_generated_eval (bool) – Whether to use human-generated evaluation data. Default is False.

Returns:

The loaded STaRK dataset.

Return type:

STaRKDataset

Raises:

ValueError – If the dataset name is not registered.
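The documented behavior (name-based dispatch, a cache-directory fallback for `root`, and a ValueError for unregistered names) can be sketched as follows. This is an illustration only: `load_qa_sketch`, `QA_REGISTRY`, and the returned dictionary are placeholders, not the library's internals, and the fallback path here merely stands in for the real Hugging Face cache path.

```python
import os

# Placeholder registry of the documented dataset names.
QA_REGISTRY = {"amazon", "prime", "mag"}


def load_qa_sketch(name, root=None, human_generated_eval=False):
    """Sketch of load_qa's documented dispatch and defaults."""
    # An unregistered name raises ValueError, as documented.
    if name not in QA_REGISTRY:
        raise ValueError(f"Unknown dataset {name!r}; expected one of {sorted(QA_REGISTRY)}")
    # With no root given, fall back to a cache directory (the real function
    # uses the default Hugging Face cache path).
    root = root or os.path.join(os.path.expanduser("~"), ".cache", "stark_qa")
    # Stand-in for constructing and returning a STaRKDataset instance.
    return {"name": name, "root": root, "human_generated_eval": human_generated_eval}


qa = load_qa_sketch("prime")
print(qa["name"])  # → prime
```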

stark_qa.load_skb

stark_qa.load_skb.load_skb(name, root=None, download_processed=True, **kwargs)[source]

Load the SKB (semi-structured knowledge base) dataset.

Parameters:
  • name (str) – Name of the dataset. One of ‘amazon’, ‘prime’, or ‘mag’.

  • root (Union[str, None]) – Root directory to store the dataset. If None, defaults to the default Hugging Face cache path.

  • download_processed (bool) – Whether to download processed data. Default is True. If False, root must be provided so the locally processed data can be found.

  • **kwargs – Additional keyword arguments for the specific dataset class.

Return type:

SKB

Returns:

An instance of the specified SKB dataset class.

Raises:
  • ValueError – If the dataset name is not recognized.

  • AssertionError – If root is not provided when download_processed is False.
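The argument validation documented above (ValueError for unrecognized names; AssertionError when `download_processed` is False but no `root` is supplied) can be sketched like this. `load_skb_sketch` and `SKB_REGISTRY` are hypothetical stand-ins, not the library's internals.

```python
# Placeholder registry of the documented dataset names.
SKB_REGISTRY = {"amazon", "prime", "mag"}


def load_skb_sketch(name, root=None, download_processed=True, **kwargs):
    """Sketch of load_skb's documented argument validation."""
    # Unrecognized names raise ValueError, as documented.
    if name not in SKB_REGISTRY:
        raise ValueError(f"Unknown dataset {name!r}; expected one of {sorted(SKB_REGISTRY)}")
    # When not downloading processed data, a local root is required.
    assert download_processed or root is not None, \
        "root must be provided when download_processed is False"
    # Stand-in for constructing the dataset-specific SKB subclass,
    # forwarding any extra keyword arguments to it.
    return {"name": name, "root": root,
            "download_processed": download_processed, **kwargs}


skb = load_skb_sketch("mag")  # default: download processed data
print(skb["name"])  # → mag
```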