import os
import os.path as osp
from typing import Union
from stark_qa.skb import SKB, AmazonSKB, PrimeSKB, MagSKB
def load_skb(name: str,
             root: Union[str, None] = None,
             download_processed: bool = True,
             **kwargs) -> "SKB":
    """
    Load the SKB dataset.

    Args:
        name (str): Name of the dataset. One of 'amazon', 'prime', or 'mag'.
        root (Union[str, None]): Root directory to store the dataset. When
            given, it is made absolute and the dataset lives in the
            ``<root>/<name>`` sub-directory. When None, None is forwarded to
            the dataset class, which then picks its own location (documented
            upstream as the HF cache path).
        download_processed (bool): Whether to download processed data.
            Default is True. If False, `root` must be provided.
        **kwargs: Additional keyword arguments for the specific dataset class.

    Returns:
        An instance of the specified SKB dataset class.

    Raises:
        ValueError: If the dataset name is not recognized.
        AssertionError: If `root` is not provided when `download_processed` is False.
    """
    # Raised explicitly rather than with an `assert` statement so the check
    # still runs under `python -O`; the exception type and message are kept
    # so existing callers that catch AssertionError are unaffected.
    if not download_processed and root is None:
        raise AssertionError(
            "root must be provided if download_processed is False"
        )

    # Each dataset gets its own sub-directory under the absolute root.
    if root is None:
        data_root = None
    else:
        data_root = osp.join(osp.abspath(root), name)

    if name == 'amazon':
        # The Amazon SKB is always built for this single product category.
        categories = ['Sports_and_Outdoors']
        skb = AmazonSKB(root=data_root,
                        categories=categories,
                        download_processed=download_processed,
                        **kwargs)
    elif name == 'prime':
        skb = PrimeSKB(root=data_root,
                       download_processed=download_processed,
                       **kwargs)
    elif name == 'mag':
        skb = MagSKB(root=data_root,
                     download_processed=download_processed,
                     **kwargs)
    else:
        raise ValueError(f"Unknown dataset {name}")
    return skb