# new-dim / usage_example.py
# splendidcomputer — "Upload folder using huggingface_hub" (commit fa9878d, verified)
import sys

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
def load_model(model_path="./"):
    """Load the Llama 3 Dementia Care model together with its tokenizer.

    Args:
        model_path (str): Directory holding the saved model and tokenizer files.

    Returns:
        tuple: (model, tokenizer) ready for text generation.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    # device_map="auto" lets accelerate place weights on available devices;
    # float16 halves memory versus the default float32 checkpoint load.
    mdl = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    return mdl, tok
def generate_response(model, tokenizer, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9, top_k=50):
    """
    Generate a response using the dementia care model.

    Args:
        model: The loaded model
        tokenizer: The loaded tokenizer
        prompt (str): The user's question or prompt
        max_new_tokens (int): Maximum number of new tokens to generate
        temperature (float): Sampling temperature
        top_p (float): Nucleus sampling parameter
        top_k (int): Top-k sampling parameter

    Returns:
        str: The model's response (prompt tokens stripped, whitespace trimmed)
    """
    # Prepare the conversation with system prompt
    messages = [
        {
            "role": "system",
            "content": "You are a specialized assistant for dementia and memory care. Provide compassionate, accurate, and helpful information about dementia, Alzheimer's disease, caregiving strategies, and support resources. Always be empathetic and practical in your responses."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]
    # Apply chat template
    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    )
    # Bug fix: apply_chat_template returns CPU tensors, but with
    # device_map="auto" the model's embedding layer may live on a GPU.
    # Move inputs to the model's device to avoid a device-mismatch error.
    input_ids = input_ids.to(model.device)
    # Generate response (no_grad: inference only, skip autograd bookkeeping)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens, skipping the echoed prompt
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response.strip()
def interactive_demo():
    """
    Run an interactive REPL-style demo of the dementia care model.

    Reads questions from stdin and prints model answers until the user
    types 'quit'/'exit'/'bye', presses Ctrl+D (EOF), or presses Ctrl+C.
    """
    print("Loading Llama 3 Dementia Care Assistant...")
    model, tokenizer = load_model()
    print("Model loaded successfully!\n")
    print("Llama 3 Dementia Care Assistant")
    print("=" * 40)
    print("This model provides specialized guidance for dementia and memory care.")
    print("Ask questions about caregiving, communication, safety, or support resources.")
    print("Type 'quit' to exit.\n")
    while True:
        # Robustness fix: exit cleanly on Ctrl+D / Ctrl+C instead of
        # crashing with an EOFError / KeyboardInterrupt traceback.
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nThank you for using the Dementia Care Assistant. Take care!")
            break
        if user_input.lower() in ['quit', 'exit', 'bye']:
            print("Thank you for using the Dementia Care Assistant. Take care!")
            break
        if not user_input:
            continue
        print("\nAssistant: ", end="")
        response = generate_response(model, tokenizer, user_input)
        print(response)
        print("\n" + "-" * 60 + "\n")
def example_usage():
    """
    Print the model's answers to a fixed set of sample caregiving questions.
    """
    print("Loading model for examples...")
    model, tokenizer = load_model()
    sample_questions = (
        "What are some effective strategies for helping someone with dementia maintain their daily routine?",
        "How should I communicate with my mother who has Alzheimer's disease when she becomes confused?",
        "What safety modifications should I make to my home for someone with dementia?",
        "How can I handle agitation and restlessness in dementia patients?",
    )
    print("Example responses from the Dementia Care Assistant:")
    print("=" * 60)
    for idx, question in enumerate(sample_questions, start=1):
        answer = generate_response(model, tokenizer, question)
        print(f"\n{idx}. Question: {question}")
        print(f"   Answer: {answer}")
        print("-" * 60)
if __name__ == "__main__":
    # Idiom fix: `import sys` moved to the top-level import block (PEP 8)
    # instead of being imported inside the entry-point guard.
    # Usage: `python usage_example.py examples` runs the canned examples;
    # any other invocation starts the interactive demo.
    if len(sys.argv) > 1 and sys.argv[1] == "examples":
        example_usage()
    else:
        interactive_demo()