import os
import gc
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM


# --- Generation / model configuration ----------------------------------------
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200        # cap on tokens generated per reply
TEMPERATURE = 0.5           # sampling temperature (lower = more deterministic)
TOP_K = 50                  # sample only from the 50 most likely next tokens
REPETITION_PENALTY = 1.1    # >1.0 discourages the model from repeating itself

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)

# float16 is only well supported on GPU; on CPU, half-precision LLaMA kernels
# are slow or unsupported, so fall back to float32 there.
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
)

# Bound method alias used by get_response() below.
generator = model.generate
print("✅ Model loaded successfully!\n")
|
|
| |
| |
| |
# System prompt defining ChatDoctor's persona and behavioral rules.  It is
# seeded as the first entry of `history`, so it is prepended to every prompt
# sent to the model.
systemprompt = ("""You are ChatDoctor — an intelligent, empathetic medical AI assistant.
Your role is to carefully gather medical information, reason clinically,
and provide safe, evidence-based guidance.

Follow these instructions strictly:
1. When a patient describes their illness, DO NOT diagnose immediately.
2. Ask relevant, targeted questions to collect all necessary details
such as symptoms, duration, severity, lifestyle habits, medical history,
medications, and any recent tests or changes.
3. Once you have enough information for a preliminary diagnosis, clearly
explain your reasoning and possible causes in simple medical language.
4. Then, provide a clear and structured response that includes:
- **Diagnosis:** probable or confirmed condition(s)
- **Dietary Advice:** foods to include and avoid
- **Lifestyle Advice:** exercise, sleep, stress, and other habits
5. Be concise, empathetic, and professional at all times.
6. Never switch roles or generate “Patient:” responses. Always remain as ChatDoctor.
7. If symptoms suggest a serious or emergency condition, advise the patient
to seek immediate medical attention.""")

# Running conversation transcript, one string per turn.  Starts with the
# system prompt and a canned greeting; get_response() appends alternating
# "Patient: ..." / "ChatDoctor: ..." lines as the dialog progresses.
history = [systemprompt, "ChatDoctor: I am ChatDoctor, what medical questions do you have?"]
|
|
| |
| |
| |
def get_response(user_input):
    """Generate one ChatDoctor reply to *user_input* and extend the dialog.

    Appends the patient turn and the generated doctor turn to the
    module-level ``history`` list, so successive calls share context.

    Args:
        user_input: The patient's message as plain text.

    Returns:
        The model's reply with any role prefixes stripped.
    """
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    history.append(human_invitation + user_input)

    # Render the whole transcript, ending with an open "ChatDoctor: " turn
    # for the model to complete.
    prompt = "\n".join(history) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
        )

    # Decode only the newly generated tokens.  Slicing by token count is
    # robust, unlike slicing the decoded string by len(prompt): decode() does
    # not always reproduce the prompt byte-for-byte (BOS handling, whitespace
    # normalization), which can corrupt the reply boundary.
    new_token_ids = output_ids[0][input_ids.shape[1]:]
    response = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    # The model sometimes hallucinates the next patient turn; keep only the
    # doctor's part of the continuation.
    response = response.split("\nPatient:")[0].strip()
    if response.startswith("Patient:"):
        response = response[len("Patient:"):].strip()

    history.append(doctor_invitation + response)

    # Drop the generation tensors promptly to keep memory headroom between
    # turns; empty_cache() is only meaningful on CUDA.
    del input_ids, output_ids
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    return response
|
|
| |
| |
| |
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! Type your questions. ===\n")
    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("Exiting ChatDoctor. Goodbye!")
                break
            if not user_input:
                # Don't burn a full generation on a blank line.
                continue
            response = get_response(user_input)
            print("ChatDoctor: " + response + "\n")
        except (KeyboardInterrupt, EOFError):
            # Ctrl+C, Ctrl+D, or end of piped stdin: exit cleanly instead of
            # crashing with an uncaught EOFError.
            print("\nExiting ChatDoctor. Goodbye!")
            break