| |
| import gradio as gr |
| from unsloth import FastLanguageModel |
| import torch |
| from PIL import Image |
| from transformers import TextIteratorStreamer |
| from threading import Thread |
| import os |
|
|
# --- Model configuration ---------------------------------------------------

# Base checkpoint the LoRA adapters were fine-tuned from. Kept for reference.
# NOTE(review): the load below resolves the base model from the adapter repo's
# own config, presumably this same checkpoint — confirm.
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"

# Hugging Face repo containing the fine-tuned LoRA adapters.
PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"

# Maximum sequence length handed to Unsloth when loading the model.
MAX_SEQ_LENGTH = 2048

# --- Model loading ---------------------------------------------------------

# FastLanguageModel.get_peft_model() attaches *new*, untrained LoRA layers for
# training and has no parameter for downloading saved adapters, so calling it
# here never applied the fine-tuned weights. Unsloth's documented way to load
# saved adapters is to pass the adapter repo straight to from_pretrained(),
# which loads the base model and applies the adapters on top of it.
print("Loading base model and LoRA adapters...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=PEFT_MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,  # None lets Unsloth auto-detect the dtype
    load_in_4bit=True,
)

print("Setting up chat template...")
from unsloth.chat_templates import get_chat_template

# NOTE(review): Unsloth also ships a "gemma-3n" chat template; confirm which
# template the adapters were trained with.
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

print("Model and tokenizer loaded successfully!")
|
|
|
|
# --- Inference ---------------------------------------------------------------


def analyze_image(image, prompt):
    """Stream the fine-tuned model's analysis of an uploaded image.

    This is a generator: every yielded value is the full text produced so far,
    so the UI can simply replace the output box contents on each update.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing was uploaded.
        prompt: The user's question about the image.

    Yields:
        The accumulated generated text after each streamed chunk; a request to
        upload an image when ``image`` is ``None``; or an error message if any
        step fails.
    """
    if image is None:
        # A bare ``return value`` inside a generator only sets
        # StopIteration.value, which is never displayed — yield instead.
        yield "Please upload an image."
        return

    try:
        messages = [
            {
                "role": "user",
                "content": [
                    # Pass the PIL image itself (no temp file needed); a
                    # multimodal processor loads it while applying the template.
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Tokenize *through* the chat template so that, when ``tokenizer`` is
        # a multimodal processor, the image becomes model inputs alongside the
        # text tokens. Rendering to a string and re-tokenizing it as plain text
        # would silently drop the image.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=1024)

        # generate() blocks until done, so run it on a worker thread and
        # consume decoded text from the streamer here as it arrives.
        errors = []
        thread = Thread(
            target=_run_generation,
            args=(model.generate, generation_kwargs, streamer, errors),
        )
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

        thread.join()
        if errors:
            # Surface a failure from the worker thread through the same
            # error path as failures on this thread.
            raise errors[0]

    except Exception as e:
        error_msg = f"An error occurred during processing: {str(e)}"
        print(error_msg)
        yield error_msg


def _run_generation(generate, generation_kwargs, streamer, errors):
    """Call ``generate(**generation_kwargs)``, recording any failure.

    On an exception, it is appended to ``errors`` and ``streamer.end()`` is
    called so a consumer iterating the streamer stops instead of waiting
    forever for output that will never arrive.
    """
    try:
        generate(**generation_kwargs)
    except Exception as exc:  # re-raised by the consumer via ``errors``
        errors.append(exc)
        streamer.end()
|
|
|
|
# --- Gradio interface ------------------------------------------------------
# Layout: image upload + question on the left, streamed analysis on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
    gr.Markdown(
        "Upload an image and ask the AI to analyze it for potential mosquito breeding sites."
    )

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")
            prompt_input = gr.Textbox(
                value="Can you analyze this image for mosquito breeding sites and recommend what to do?",
                label="Your Question",
                lines=2,
            )
            submit_btn = gr.Button("Analyze")

        # Right column: read-only result box, filled from analyze_image's yields.
        with gr.Column():
            output_text = gr.Textbox(
                lines=15,
                label="Analysis Result",
                interactive=False,
            )

    # Wire the button: (image, question) in, streamed text out.
    submit_btn.click(
        analyze_image,
        inputs=[image_input, prompt_input],
        outputs=output_text,
    )
|
|
# --- Entry point -------------------------------------------------------------
if __name__ == "__main__":
    # analyze_image streams its output by yielding. Older Gradio releases only
    # allow generator handlers when the request queue is enabled, and newer
    # ones enable it by default, so enabling it explicitly is safe either way.
    demo.queue().launch()