import time
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Cache of already-loaded Hugging Face (model, tokenizer) pairs, keyed by the
# Hub repo id, so repeated calls do not reload the weights.
loaded_hf_models = {}

# Routing prefix callers put in front of the actual Hugging Face Hub repo id.
_HF_PREFIX = "huggingface/"


def _truncate_at_stop(text, stop_sequences):
    """Return `text` cut just before the earliest occurrence of any stop sequence."""
    cut = len(text)
    for stop in stop_sequences or ():
        if not stop:
            # An empty string matches at index 0 and would erase the whole text.
            continue
        index = text.find(stop)
        if index != -1:
            cut = min(cut, index)
    return text[:cut]


def complete_text_hf(message: str,
                     model: str = "huggingface/codellama/CodeLlama-7b-hf",
                     max_tokens: int = 2000,
                     temperature: float = 0.5,
                     json_object: bool = False,
                     max_retry: int = 1,
                     sleep_time: int = 0,
                     stop_sequences: Optional[list] = None,
                     **kwargs) -> str:
    """
    Generate text completion using a specified Hugging Face model.

    Args:
        message (str): The input text message for completion.
        model (str): The Hugging Face model to use, optionally prefixed with
            "huggingface/". Default is "huggingface/codellama/CodeLlama-7b-hf".
        max_tokens (int): The maximum number of tokens to generate. Default is 2000.
        temperature (float): Sampling temperature for generation. Default is 0.5.
        json_object (bool): Whether to prepend a JSON-output instruction to the
            message. Default is False.
        max_retry (int): Maximum number of generation attempts. Default is 1.
        sleep_time (int): Sleep time between attempts in seconds. Default is 0.
        stop_sequences (list): Strings at which the completion is cut off. The
            completion is truncated just before the earliest match in the
            decoded text; generation itself is not interrupted. Default is
            None (no truncation).
        **kwargs: Additional keyword arguments for the `generate` function.

    Returns:
        str: The generated text completion (prompt tokens excluded).

    Raises:
        RuntimeError: If every generation attempt failed; the last underlying
            error is attached as the cause.
    """
    if json_object:
        message = "You are a helpful assistant designed to output in JSON format." + message

    # Run on the GPU if available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Strip only the routing prefix: the remainder is the Hub repo id that
    # `from_pretrained` expects. Ids without the prefix are used unchanged.
    model_name = model[len(_HF_PREFIX):] if model.startswith(_HF_PREFIX) else model

    # Load the model and tokenizer only on a cache miss.
    try:
        hf_model, tokenizer = loaded_hf_models[model_name]
    except KeyError:
        hf_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_hf_models[model_name] = (hf_model, tokenizer)

    # Encode the input message once; it is identical for every attempt.
    encoded_input = tokenizer(message, return_tensors="pt", return_token_type_ids=False).to(device)
    prompt_length = len(encoded_input.input_ids[0])

    last_error = None
    for attempt in range(max_retry):
        try:
            output = hf_model.generate(
                **encoded_input,
                temperature=temperature,
                max_new_tokens=max_tokens,
                do_sample=True,
                return_dict_in_generate=True,
                output_scores=True,
                **kwargs,
            )
            # The returned sequences start with the prompt tokens; drop them
            # so that only the newly generated text is decoded.
            new_tokens = [sequence[prompt_length:] for sequence in output.sequences]
            completion = tokenizer.batch_decode(new_tokens)[0]
            return _truncate_at_stop(completion, stop_sequences)
        except Exception as e:
            # Deliberately broad: any failure of an attempt triggers a retry.
            last_error = e
            print(f"Retry {attempt}: {e}")
            # Do not sleep after the final attempt; nothing follows it.
            if attempt + 1 < max_retry:
                time.sleep(sleep_time)

    raise RuntimeError("Failed to generate text completion after max retries") from last_error