File size: 4,310 Bytes
fa9878d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_model(model_path="./"):
    """
    Load the Llama 3 Dementia Care model together with its tokenizer.

    The tokenizer is loaded first, then the causal-LM weights are loaded
    in half precision with automatic device placement.

    Args:
        model_path (str): Path to the model directory

    Returns:
        tuple: (model, tokenizer)
    """
    # NOTE(review): trust_remote_code=True lets Transformers run custom
    # modeling code shipped with the checkpoint -- confirm that model_path
    # always points at a trusted source.
    model_options = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "trust_remote_code": True,
    }

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_options)
    return model, tokenizer

def generate_response(model, tokenizer, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9, top_k=50):
    """
    Generate a response using the dementia care model.

    The prompt is wrapped in a two-message conversation (fixed system prompt
    plus the user's text), rendered through the tokenizer's chat template,
    and passed to ``model.generate``. Only the newly generated tokens are
    decoded and returned.

    Args:
        model: The loaded model (must expose ``device`` and ``generate``)
        tokenizer: The loaded tokenizer
        prompt (str): The user's question or prompt
        max_new_tokens (int): Maximum number of new tokens to generate
        temperature (float): Sampling temperature. A value of 0 (or below)
            selects deterministic greedy decoding instead of sampling.
        top_p (float): Nucleus sampling parameter (ignored when greedy)
        top_k (int): Top-k sampling parameter (ignored when greedy)

    Returns:
        str: The model's response, stripped of surrounding whitespace
    """
    # Prepare the conversation with system prompt
    messages = [
        {
            "role": "system",
            "content": "You are a specialized assistant for dementia and memory care. Provide compassionate, accurate, and helpful information about dementia, Alzheimer's disease, caregiving strategies, and support resources. Always be empathetic and practical in your responses."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]

    # Apply chat template; return_dict=True also yields the attention mask
    encoded = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True
    )

    # The model may have been placed on an accelerator (load_model uses
    # device_map="auto"), so the inputs must live on the same device.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is None:
        # Single un-padded prompt: every position is a real token.
        attention_mask = torch.ones_like(input_ids)
    attention_mask = attention_mask.to(model.device)

    # An explicit attention mask is required because pad_token_id is set to
    # the EOS id, which prevents generate() from inferring the mask itself.
    generation_kwargs = {
        "attention_mask": attention_mask,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is None or temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
    else:
        # Sampling needs a strictly positive temperature; treat 0 as a
        # request for deterministic (greedy) decoding.
        generation_kwargs["do_sample"] = False

    # Generate response
    with torch.no_grad():
        outputs = model.generate(input_ids, **generation_kwargs)

    # Decode only the tokens produced after the prompt
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response.strip()

def interactive_demo():
    """
    Run an interactive demo of the dementia care model.

    Loads the model once, prints a short banner, then repeatedly reads a
    question from standard input and prints the model's answer. The loop
    ends when the user types 'quit', 'exit' or 'bye' (case-insensitive),
    when standard input is closed, or when the user presses Ctrl+C at the
    prompt. Empty input lines are ignored.
    """
    print("Loading Llama 3 Dementia Care Assistant...")
    model, tokenizer = load_model()
    print("Model loaded successfully!\n")

    print("Llama 3 Dementia Care Assistant")
    print("=" * 40)
    print("This model provides specialized guidance for dementia and memory care.")
    print("Ask questions about caregiving, communication, safety, or support resources.")
    print("Type 'quit' to exit.\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin was closed (e.g. piped input ran out) or the user hit
            # Ctrl+C at the prompt: finish cleanly instead of a traceback.
            print("\nThank you for using the Dementia Care Assistant. Take care!")
            break

        if user_input.lower() in ['quit', 'exit', 'bye']:
            print("Thank you for using the Dementia Care Assistant. Take care!")
            break

        if not user_input:
            continue

        print("\nAssistant: ", end="")
        response = generate_response(model, tokenizer, user_input)
        print(response)
        print("\n" + "-" * 60 + "\n")

def example_usage():
    """
    Print the model's answers to a fixed set of sample caregiving questions.

    The model is loaded once, then each question is echoed with its number,
    followed by the generated answer and a separator line.
    """
    print("Loading model for examples...")
    model, tokenizer = load_model()

    divider = "-" * 60
    sample_questions = (
        "What are some effective strategies for helping someone with dementia maintain their daily routine?",
        "How should I communicate with my mother who has Alzheimer's disease when she becomes confused?",
        "What safety modifications should I make to my home for someone with dementia?",
        "How can I handle agitation and restlessness in dementia patients?",
    )

    print("Example responses from the Dementia Care Assistant:")
    print("=" * 60)

    for number, question in enumerate(sample_questions, start=1):
        # Echo the question before generating, so it is visible while the
        # model is still working on the answer.
        print(f"\n{number}. Question: {question}")
        answer = generate_response(model, tokenizer, question)
        print(f"   Answer: {answer}")
        print(divider)

if __name__ == "__main__":
    import sys

    # "examples" as the first command-line argument runs the canned
    # questions; anything else (or no argument) starts the interactive chat.
    wants_examples = sys.argv[1:2] == ["examples"]
    entry_point = example_usage if wants_examples else interactive_demo
    entry_point()