Source code for stark_qa.tools.api_lib.huggingface

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Dictionary to cache loaded Hugging Face models and tokenizers
loaded_hf_models = {}


def complete_text_hf(message: str,
                     model: str = "huggingface/codellama/CodeLlama-7b-hf",
                     max_tokens: int = 2000,
                     temperature: float = 0.5,
                     json_object: bool = False,
                     max_retry: int = 1,
                     sleep_time: int = 0,
                     stop_sequences: list = [],
                     **kwargs) -> str:
    """
    Generate a text completion using a specified Hugging Face model.

    Args:
        message (str): The input text message for completion.
        model (str): The Hugging Face model to use. Default is "huggingface/codellama/CodeLlama-7b-hf".
        max_tokens (int): The maximum number of tokens to generate. Default is 2000.
        temperature (float): Sampling temperature for generation. Default is 0.5.
        json_object (bool): Whether to format the message for JSON output. Default is False.
        max_retry (int): Maximum number of retries in case of an error. Default is 1.
        sleep_time (int): Sleep time between retries in seconds. Default is 0.
        stop_sequences (list): List of stop sequences to halt the generation.
        **kwargs: Additional keyword arguments for the `generate` function.

    Returns:
        str: The generated text completion.
    """
    if json_object:
        message = "You are a helpful assistant designed to output in JSON format." + message

    # Determine the device to run the model on (GPU if available, otherwise CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Strip the "huggingface/" provider prefix to obtain the Hub repository id
    model_name = model.split("/", 1)[1]

    # Load the model and tokenizer if not already loaded, caching them for reuse
    if model_name in loaded_hf_models:
        hf_model, tokenizer = loaded_hf_models[model_name]
    else:
        hf_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_hf_models[model_name] = (hf_model, tokenizer)

    # Encode the input message
    encoded_input = tokenizer(message, return_tensors="pt", return_token_type_ids=False).to(device)

    for cnt in range(max_retry):
        try:
            # Generate the text completion
            output = hf_model.generate(
                **encoded_input,
                temperature=temperature,
                max_new_tokens=max_tokens,
                do_sample=True,
                return_dict_in_generate=True,
                output_scores=True,
                **kwargs,
            )
            # Keep only the newly generated tokens (drop the prompt) before decoding
            sequences = output.sequences
            sequences = [sequence[len(encoded_input.input_ids[0]):] for sequence in sequences]
            all_decoded_text = tokenizer.batch_decode(sequences)
            completion = all_decoded_text[0]
            return completion
        except Exception as e:
            print(f"Retry {cnt}: {e}")
            time.sleep(sleep_time)

    raise RuntimeError("Failed to generate text completion after max retries")
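
A minimal usage sketch of the function above; the prompt and generation settings are illustrative, and it assumes the CodeLlama-7b-hf weights can be fetched from the Hugging Face Hub and fit on the available device:

from stark_qa.tools.api_lib.huggingface import complete_text_hf

# Illustrative prompt; the "huggingface/" provider prefix is stripped internally
completion = complete_text_hf(
    "Write a Python function that reverses a string.",
    model="huggingface/codellama/CodeLlama-7b-hf",
    max_tokens=256,
    temperature=0.2,
)
print(completion)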