import os
import gc

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; expects HUGGINGFACE_TOKEN in the environment.
login(token=os.getenv("HUGGINGFACE_TOKEN"))

# Model and generation settings.
MODEL_PATH = "zl111/ChatDoctor"
MAX_NEW_TOKENS = 200        # cap on tokens generated per reply
TEMPERATURE = 0.5           # lower = more deterministic sampling
TOP_K = 50                  # sample only from the 50 most likely next tokens
REPETITION_PENALTY = 1.1    # mildly discourage repeated phrases

# Pick the device and load the tokenizer and model in half precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",          # let accelerate place the weights automatically
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

print("✅ ChatDoctor model loaded successfully!\n")


class StopOnTokens(StoppingCriteria):
    """Stop generation once the output ends with any of the given token-id sequences."""

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        # A single suffix comparison covers both one-token and multi-token stop
        # sequences, so no special case for length-1 sequences is needed.
        for stop_id_seq in self.stop_ids:
            if len(input_ids[0]) >= len(stop_id_seq):
                if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
                    return True
        return False
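
# Example of wiring the criteria up (a sketch; the concrete token ids depend on
# the tokenizer, and `generated_ids` stands for a hypothetical in-progress batch):
#
#   stop_ids = [tokenizer.encode("Patient:", add_special_tokens=False)]
#   criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
#   criteria(generated_ids, scores=None)  # True once the output ends with "Patient:"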


# Conversation transcript; each entry is one "Speaker: text" line.
history = ["ChatDoctor: I am ChatDoctor, your AI medical assistant. How can I help you today?"]


def get_response(user_input):
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    history.append(human_invitation + user_input)

    # Rebuild the whole dialogue as the prompt and end with the doctor tag.
    prompt = "\n".join(history) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # Stop as soon as the model starts writing the patient's next turn.
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) can misalign when decoding normalizes whitespace or drops
    # special tokens.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Trim anything after a leaked "Patient" turn that slipped past the criteria.
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    history.append(doctor_invitation + response)

    # Release the generation tensors so memory stays flat across turns.
    del input_ids, output_ids
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    return response
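
# Example single turn (hypothetical input):
#
#   reply = get_response("I've had a mild headache for two days.")
#   print(reply)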


if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! ===")
    print("You (the human) = Patient")
    print("AI = ChatDoctor")
    print("Type 'exit' or 'quit' to end the chat.\n")

    print("ChatDoctor: Hi there! How can I help you today?\n")

    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("ChatDoctor: Take care! Goodbye")
                break

            if not user_input:
                continue

            response = get_response(user_input)
            print("ChatDoctor:", response, "\n")

        except KeyboardInterrupt:
            print("\nChatDoctor: Take care! Goodbye")
            break
        except Exception as e:
            print(f"Error: {e}")
            print("Please try again.\n")