import sys

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def load_model(model_path="./"):
    """
    Load the Llama 3 Dementia Care model and tokenizer.

    Args:
        model_path (str): Path to the model directory

    Returns:
        tuple: (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # NOTE(review): trust_remote_code=True executes arbitrary Python shipped
    # with the checkpoint. Only load models from sources you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",  # shards/places the model automatically (GPU if available)
        trust_remote_code=True,
    )
    return model, tokenizer


def generate_response(model, tokenizer, prompt, max_new_tokens=256,
                      temperature=0.7, top_p=0.9, top_k=50):
    """
    Generate a response using the dementia care model.

    Args:
        model: The loaded model
        tokenizer: The loaded tokenizer
        prompt (str): The user's question or prompt
        max_new_tokens (int): Maximum number of new tokens to generate
        temperature (float): Sampling temperature
        top_p (float): Nucleus sampling parameter
        top_k (int): Top-k sampling parameter

    Returns:
        str: The model's response
    """
    # Prepare the conversation with system prompt
    messages = [
        {
            "role": "system",
            "content": "You are a specialized assistant for dementia and memory care. Provide compassionate, accurate, and helpful information about dementia, Alzheimer's disease, caregiving strategies, and support resources. Always be empathetic and practical in your responses."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]

    # Apply chat template. return_dict=True also yields the attention mask,
    # which generate() needs to distinguish real tokens from padding.
    encoded = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
    )
    # BUG FIX: the encoded tensors come back on CPU, but with
    # device_map="auto" the model usually lives on GPU. Move inputs to the
    # model's device to avoid a device-mismatch error at generate() time.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    # Generate response (no grad: inference only)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (skip the prompt prefix)
    response = tokenizer.decode(
        outputs[0][input_ids.shape[-1]:], skip_special_tokens=True
    )
    return response.strip()


def interactive_demo():
    """
    Run an interactive demo of the dementia care model.
    """
    print("Loading Llama 3 Dementia Care Assistant...")
    model, tokenizer = load_model()
    print("Model loaded successfully!\n")

    print("Llama 3 Dementia Care Assistant")
    print("=" * 40)
    print("This model provides specialized guidance for dementia and memory care.")
    print("Ask questions about caregiving, communication, safety, or support resources.")
    print("Type 'quit' to exit.\n")

    while True:
        # Exit cleanly on Ctrl-C / Ctrl-D instead of dumping a traceback.
        try:
            user_input = input("You: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nThank you for using the Dementia Care Assistant. Take care!")
            break

        if user_input.lower() in ['quit', 'exit', 'bye']:
            print("Thank you for using the Dementia Care Assistant. Take care!")
            break

        if not user_input:
            continue

        print("\nAssistant: ", end="")
        response = generate_response(model, tokenizer, user_input)
        print(response)
        print("\n" + "-" * 60 + "\n")


def example_usage():
    """
    Demonstrate example usage of the model.
    """
    print("Loading model for examples...")
    model, tokenizer = load_model()

    examples = [
        "What are some effective strategies for helping someone with dementia maintain their daily routine?",
        "How should I communicate with my mother who has Alzheimer's disease when she becomes confused?",
        "What safety modifications should I make to my home for someone with dementia?",
        "How can I handle agitation and restlessness in dementia patients?"
    ]

    print("Example responses from the Dementia Care Assistant:")
    print("=" * 60)

    for i, example in enumerate(examples, 1):
        print(f"\n{i}. Question: {example}")
        print(f"   Answer: {generate_response(model, tokenizer, example)}")
        print("-" * 60)


if __name__ == "__main__":
    # "python script.py examples" runs the canned demos; anything else
    # (including no argument) starts the interactive loop.
    if len(sys.argv) > 1 and sys.argv[1] == "examples":
        example_usage()
    else:
        interactive_demo()