| """ |
| SmolVLM Webcam Auto Inference (Fine-tuned) |
| 3์ด๋ง๋ค ์๋์ผ๋ก inference ์ํ |
| Fine-tuned on Hair classification & description dataset |
| """ |
|
|
| import torch |
| from PIL import Image |
| from transformers import AutoProcessor, AutoModelForImageTextToText |
| from peft import PeftModel |
| import gradio as gr |
| import numpy as np |
| from datetime import datetime |
| import time |
|
|
| |
| |
| |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| BASE_MODEL_ID = "HuggingFaceTB/SmolVLM-256M-Instruct" |
| FINETUNED_MODEL_PATH = "/root/crying_cv_vlm/checkpoint-105" |
| INFERENCE_INTERVAL = 3 |
|
|
| print(f"๐ง Device: {DEVICE}") |
| print(f"๐ Fine-tuned Model: {FINETUNED_MODEL_PATH}") |
| print("Loading model...") |
|
|
| |
| |
| |
| from transformers import AutoModelForImageTextToText |
| from peft import PeftModel |
|
|
| print("1๏ธโฃ Loading base model...") |
| model = AutoModelForImageTextToText.from_pretrained( |
| BASE_MODEL_ID, |
| dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32, |
| device_map="auto", |
| attn_implementation="eager" |
| ) |
|
|
| print("2๏ธโฃ Loading fine-tuned adapter...") |
| model = PeftModel.from_pretrained( |
| model, |
| FINETUNED_MODEL_PATH, |
| device_map="auto" |
| ) |
|
|
| print("3๏ธโฃ Merging adapter...") |
| model = model.merge_and_unload() |
| model.eval() |
|
|
| print("4๏ธโฃ Loading processor...") |
| processor = AutoProcessor.from_pretrained(FINETUNED_MODEL_PATH) |
|
|
| print("โ
Model loaded!") |
| if torch.cuda.is_available(): |
| print(f"๐พ VRAM: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB") |
|
|
|
|
def inference(image, question):
    """Run one VLM inference on *image* with *question*.

    Args:
        image: Webcam frame as a numpy array or PIL image, or None when
            no frame has been captured yet.
        question: User prompt; falls back to a generic description prompt
            when empty or whitespace-only.

    Returns:
        A ``(response_text, status_message)`` tuple for the Gradio outputs.
        Errors are caught and reported as text rather than raised, so the
        UI keeps running.
    """
    if image is None:
        return "⚠️ 웹캠에서 이미지를 캡처해주세요.", "대기 중"

    if not question or question.strip() == "":
        question = "Describe this image in detail."

    try:
        # Normalize the input into an RGB PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            return "❌ 잘못된 이미지 형식", "에러"
        elif image.mode != 'RGB':
            image = image.convert('RGB')

        # Build the chat-style prompt expected by SmolVLM-Instruct.
        messages = [{
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": question}]
        }]

        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)

        # Remember the prompt length so only newly generated tokens are decoded.
        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
                top_p=0.9
            )

        # Strip the prompt tokens and decode only the model's answer.
        generated_ids = generated_ids[0][input_len:]
        response = processor.decode(generated_ids, skip_special_tokens=True).strip()

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        status = f"✅ {timestamp}"

        return response if response else "(빈 응답)", status

    except Exception as e:
        # UI boundary: surface the full traceback in the output box instead
        # of crashing the Gradio app.
        import traceback
        error_msg = traceback.format_exc()
        return f"❌ 에러: {str(e)}\n\n{error_msg}", "에러 발생"
|
|
|
|
| |
| |
| |
# --- Gradio UI -----------------------------------------------------------
with gr.Blocks(title="SmolVLM Auto Inference") as demo:
    gr.Markdown("""
    # 🎥 SmolVLM 웹캠 자동 추론 (Fine-tuned)

    **3초마다 자동으로 추론을 수행합니다**

    ### 모델 정보:
    - **Base Model**: HuggingFaceTB/SmolVLM-256M-Instruct
    - **Fine-tuned on**: Hair classification & description dataset
    - **Training**: 5 epochs, Final loss: 1.1350

    ### 사용 방법:
    1. 웹캠 허용 및 이미지 캡처
    2. 질문 입력
    3. "🚀 자동 추론 시작" 버튼 클릭
    4. 3초마다 자동으로 추론됩니다
    5. "⏸️ 중지" 버튼으로 멈출 수 있습니다
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Live webcam feed; streaming keeps the frame continuously updated.
            webcam = gr.Image(
                label="📷 웹캠",
                type="numpy",
                sources=["webcam"],
                streaming=True,
                height=400
            )

            question = gr.Textbox(
                label="💬 질문",
                placeholder="이미지에 대해 물어보고 싶은 것을 입력하세요",
                value="Classify the hair length in this image. Possible values: short, mid, long. Output only one word.",
                lines=3
            )

            with gr.Row():
                start_btn = gr.Button("🚀 자동 추론 시작", variant="primary", scale=2)
                stop_btn = gr.Button("⏸️ 중지", variant="stop", scale=1)

        with gr.Column(scale=1):
            # Model response and status read-outs.
            output = gr.Textbox(
                label="🤖 응답",
                lines=15,
                max_lines=20
            )

            status = gr.Textbox(
                label="📊 상태",
                value="대기 중",
                lines=1
            )

            auto_status = gr.Textbox(
                label="🔄 자동 추론 상태",
                value="멈춤",
                lines=1
            )

    gr.Markdown("### 💡 예시 질문:")
    gr.Examples(
        examples=[
            ["Classify the hair length in this image. Possible values: short, mid, long. Output only one word."],
            ["Describe the person's hair style, color, and texture in detail."],
            ["What is the hair length? Answer in one word: short, mid, or long."],
            ["Describe what you see in this image."],
            ["이 사람의 머리 길이를 분류하세요. 가능한 값: short, mid, long"],
        ],
        inputs=[question],
    )

    # Hidden state: whether the auto loop is active, and when the last
    # inference ran (epoch seconds).
    is_auto_running = gr.State(value=False)
    last_inference_time = gr.State(value=0)

    def start_auto_inference():
        """Start the auto-inference loop."""
        # Back-date the last inference time by one interval so the very
        # first timer tick runs inference immediately.
        return True, "▶️ 실행 중 (3초 간격)", gr.Timer(value=0.5, active=True), time.time() - INFERENCE_INTERVAL

    def stop_auto_inference():
        """Stop the auto-inference loop."""
        return False, "⏸️ 멈춤", gr.Timer(value=0.5, active=False)

    def auto_inference_loop(image, question_text, is_running, last_time):
        """Timer callback: run inference every INFERENCE_INTERVAL seconds."""
        if not is_running:
            return gr.update(), gr.update(), last_time

        current_time = time.time()

        if image is None:
            return gr.update(), "⚠️ 웹캠 이미지를 캡처해주세요", last_time

        if current_time - last_time >= INFERENCE_INTERVAL:
            result, status_msg = inference(image, question_text)
            return result, status_msg, current_time
        else:
            # Not yet due: show the countdown without touching the output box.
            remaining = INFERENCE_INTERVAL - (current_time - last_time)
            return gr.update(), f"⏱️ 다음 추론까지 {remaining:.1f}초", last_time

    # Timer ticks every 0.5 s; the callback itself enforces the 3 s interval
    # so the countdown display stays responsive between inferences.
    timer = gr.Timer(value=0.5, active=False)

    start_btn.click(
        fn=start_auto_inference,
        inputs=[],
        outputs=[is_auto_running, auto_status, timer, last_inference_time]
    )

    stop_btn.click(
        fn=stop_auto_inference,
        inputs=[],
        outputs=[is_auto_running, auto_status, timer]
    )

    timer.tick(
        fn=auto_inference_loop,
        inputs=[webcam, question, is_auto_running, last_inference_time],
        outputs=[output, status, last_inference_time]
    )
|
|
|
|
if __name__ == "__main__":
    print("\n" + "="*70)
    # Banner now matches the actual port passed to demo.launch() below
    # (it previously advertised 7860 while the server bound to 8085).
    print("🚀 Launching at http://0.0.0.0:8085")
    print("="*70 + "\n")

    demo.launch(
        server_name="0.0.0.0",
        server_port=8085,
        share=False,
        show_error=True
    )
|
|
|
|