from typing import List, Dict
import torch
from torchmetrics.functional import (
retrieval_hit_rate, retrieval_reciprocal_rank, retrieval_recall,
retrieval_precision, retrieval_average_precision, retrieval_normalized_dcg,
retrieval_r_precision
)
class Evaluator:
    def __init__(self, candidate_ids: List[int]):
        """
        Initializes the evaluator with the given candidate IDs.

        Args:
            candidate_ids (List[int]): List of candidate IDs.
        """
        self.candidate_ids = candidate_ids

    def __call__(self,
                 pred_dict: Dict[int, float],
                 answer_ids: torch.LongTensor,
                 metrics: List[str] = ['mrr', 'hit@3', 'recall@20']) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Dictionary mapping candidate IDs to predicted scores.
            answer_ids (torch.LongTensor): Ground-truth answer IDs.
            metrics (List[str]): List of metrics to be evaluated, including 'mrr', 'rprecision',
                'hit@k', 'recall@k', 'precision@k', 'map@k', and 'ndcg@k'.

        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
        return self.evaluate(pred_dict, answer_ids, metrics)
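
    # Calling an Evaluator instance delegates to `evaluate`. Metrics with a rank
    # cutoff use the '<name>@<k>' form, e.g. 'hit@3' or 'recall@20'.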

    def evaluate(self,
                 pred_dict: Dict[int, float],
                 answer_ids: torch.LongTensor,
                 metrics: List[str] = ['mrr', 'hit@3', 'recall@20']) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Dictionary mapping candidate IDs to predicted scores.
            answer_ids (torch.LongTensor): Ground-truth answer IDs.
            metrics (List[str]): List of metrics to be evaluated, including 'mrr', 'rprecision',
                'hit@k', 'recall@k', 'precision@k', 'map@k', and 'ndcg@k'.

        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
        # Convert the prediction dictionary to index and score tensors
        pred_ids = torch.LongTensor(list(pred_dict.keys())).view(-1)
        pred = torch.FloatTensor(list(pred_dict.values())).view(-1)
        answer_ids = answer_ids.view(-1)

        # Build a dense score vector over all candidate IDs, initializing unscored
        # candidates to a value below the minimum predicted score
        all_pred = torch.ones(max(self.candidate_ids) + 1, dtype=torch.float) * (min(pred) - 1)
        all_pred[pred_ids] = pred
        all_pred = all_pred[self.candidate_ids]

        # Build the ground-truth relevance mask over the same candidate set
        bool_gd = torch.zeros(max(self.candidate_ids) + 1, dtype=torch.bool)
        bool_gd[answer_ids] = True
        bool_gd = bool_gd[self.candidate_ids]
        # Compute each requested metric
        eval_metrics = {}
        for metric in metrics:
            k = int(metric.split('@')[-1]) if '@' in metric else None
            if metric == 'mrr':
                result = retrieval_reciprocal_rank(all_pred, bool_gd)
            elif metric == 'rprecision':
                result = retrieval_r_precision(all_pred, bool_gd)
            elif 'hit' in metric:
                result = retrieval_hit_rate(all_pred, bool_gd, top_k=k)
            elif 'recall' in metric:
                result = retrieval_recall(all_pred, bool_gd, top_k=k)
            elif 'precision' in metric:
                result = retrieval_precision(all_pred, bool_gd, top_k=k)
            elif 'map' in metric:
                result = retrieval_average_precision(all_pred, bool_gd, top_k=k)
            elif 'ndcg' in metric:
                result = retrieval_normalized_dcg(all_pred, bool_gd, top_k=k)
            else:
                raise ValueError(f"Unsupported metric: {metric}")
            eval_metrics[metric] = float(result)
        return eval_metrics
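

# Minimal usage sketch. The candidate IDs, scores, and answer IDs below are
# hypothetical example values, not part of the class above.
if __name__ == "__main__":
    # Candidate pool the evaluator ranks over
    evaluator = Evaluator(candidate_ids=[0, 1, 2, 3, 4])

    # Predicted score for each candidate ID
    pred_dict = {0: 0.1, 1: 0.9, 2: 0.3, 3: 0.8, 4: 0.2}

    # Ground-truth answer IDs for this query
    answer_ids = torch.LongTensor([1, 2])

    # Returns a dict such as {'mrr': ..., 'hit@3': ..., 'recall@20': ...},
    # with each metric mapped to a float
    print(evaluator(pred_dict, answer_ids, metrics=['mrr', 'hit@3', 'recall@20']))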