import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
def load_model(model_path="./"):
    """
    Load the Llama 3 Dementia Care model and its tokenizer.

    Args:
        model_path (str): Directory containing the model and tokenizer files.

    Returns:
        tuple: (model, tokenizer)
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    lm = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,  # half precision to reduce memory footprint
        device_map="auto",          # let accelerate place weights on available devices
        trust_remote_code=True,     # NOTE(review): executes repo-provided code; confirm the model source is trusted
    )
    return lm, tok
def generate_response(model, tokenizer, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9, top_k=50):
    """
    Generate a response using the dementia care model.

    Args:
        model: The loaded causal-LM model.
        tokenizer: The loaded tokenizer.
        prompt (str): The user's question or prompt.
        max_new_tokens (int): Maximum number of new tokens to generate.
        temperature (float): Sampling temperature.
        top_p (float): Nucleus sampling parameter.
        top_k (int): Top-k sampling parameter.

    Returns:
        str: The model's response, stripped of surrounding whitespace.
    """
    # Prepend the fixed system prompt so every turn stays in the care-assistant persona.
    messages = [
        {
            "role": "system",
            "content": "You are a specialized assistant for dementia and memory care. Provide compassionate, accurate, and helpful information about dementia, Alzheimer's disease, caregiving strategies, and support resources. Always be empathetic and practical in your responses."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]
    # Apply the model's chat template to obtain token ids for the conversation.
    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    )
    # Bug fix: with device_map="auto" the model weights live on an accelerator
    # while apply_chat_template returns CPU tensors; generate() would raise a
    # device-mismatch error. Move the inputs onto the model's device first.
    input_ids = input_ids.to(model.device)
    # Explicit attention mask: pad_token_id is set to eos below, so without a
    # mask the model cannot distinguish padding from real eos tokens (and
    # transformers emits a warning).
    attention_mask = torch.ones_like(input_ids)
    # Generate without tracking gradients (inference only).
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens (skip the prompt portion).
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response.strip()
def interactive_demo():
    """
    Run an interactive console session with the dementia care model.

    Loads the model once, then loops: read a user line, print the model's
    reply, until the user types quit/exit/bye.
    """
    print("Loading Llama 3 Dementia Care Assistant...")
    model, tokenizer = load_model()
    print("Model loaded successfully!\n")
    print("Llama 3 Dementia Care Assistant")
    print("=" * 40)
    print("This model provides specialized guidance for dementia and memory care.")
    print("Ask questions about caregiving, communication, safety, or support resources.")
    print("Type 'quit' to exit.\n")
    while True:
        user_input = input("You: ").strip()
        # Exit keywords end the session; blank lines are silently skipped.
        if user_input.lower() in ('quit', 'exit', 'bye'):
            print("Thank you for using the Dementia Care Assistant. Take care!")
            return
        if not user_input:
            continue
        print("\nAssistant: ", end="")
        print(generate_response(model, tokenizer, user_input))
        print("\n" + "-" * 60 + "\n")
def example_usage():
    """
    Print the model's answers to a fixed set of example caregiving questions.
    """
    print("Loading model for examples...")
    model, tokenizer = load_model()
    # Fixed question set covering routine, communication, safety, and agitation.
    examples = (
        "What are some effective strategies for helping someone with dementia maintain their daily routine?",
        "How should I communicate with my mother who has Alzheimer's disease when she becomes confused?",
        "What safety modifications should I make to my home for someone with dementia?",
        "How can I handle agitation and restlessness in dementia patients?",
    )
    print("Example responses from the Dementia Care Assistant:")
    print("=" * 60)
    for idx, question in enumerate(examples, start=1):
        print(f"\n{idx}. Question: {question}")
        print(f" Answer: {generate_response(model, tokenizer, question)}")
        print("-" * 60)
if __name__ == "__main__":
    import sys

    # An "examples" CLI argument runs the canned demo; anything else (or no
    # argument) starts the interactive session.
    run_examples = len(sys.argv) > 1 and sys.argv[1] == "examples"
    if run_examples:
        example_usage()
    else:
        interactive_demo()