# buddy-desktop / vision_llm.py
# carbonx's picture
# Add configurable models with auto VRAM detection
# b28ee4c verified
"""Multimodal vision-language model wrapper with a configurable model."""
import os
import torch
from PIL import Image
from transformers import (
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
AutoProcessor,
AutoModelForSpeechSeq2Seq,
AutoProcessor as WhisperProcessor,
pipeline,
)
from qwen_vl_utils import process_vision_info
# Model map: friendly name -> config dict with the keys
#   "id"              -- model id passed to from_pretrained
#   "min_vram"        -- minimum GPU VRAM in GB; compared against the detected VRAM
#   "class"           -- transformers model class used to load the checkpoint
#   "supports_pixels" -- whether min_pixels/max_pixels are attached to image messages
MODEL_REGISTRY = {
    "qwen2.5-vl-3b": {
        "id": "Qwen/Qwen2.5-VL-3B-Instruct",
        "min_vram": 7,
        "class": Qwen2_5_VLForConditionalGeneration,
        "supports_pixels": True, # Qwen2.5 supports min/max_pixels
    },
    "qwen2.5-vl-7b": {
        "id": "Qwen/Qwen2.5-VL-7B-Instruct",
        "min_vram": 16,
        "class": Qwen2_5_VLForConditionalGeneration,
        "supports_pixels": True,
    },
    "qwen2-vl-2b": {
        "id": "Qwen/Qwen2-VL-2B-Instruct",
        "min_vram": 5,
        "class": Qwen2VLForConditionalGeneration,
        # Original author's note: Qwen2 uses a different syntax.
        # NOTE(review): with this flag off, no pixel budget is attached to the
        # image message for this model -- confirm against qwen_vl_utils that
        # this is intended for the smallest-VRAM entry.
        "supports_pixels": False,
    },
}
def _detect_gpu_vram() -> int:
    """Return the total memory of CUDA device 0 in whole GiB, or 0 without CUDA.

    The value is floored, so a card reporting slightly less than N GiB
    yields N - 1.
    """
    if torch.cuda.is_available():
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        bytes_per_gib = 1024 ** 3
        return total_bytes // bytes_per_gib
    return 0
def _pick_model(preferred: str = None) -> dict:
    """Choose a model config from MODEL_REGISTRY based on VRAM and preference.

    Args:
        preferred: Friendly registry name (e.g. "qwen2.5-vl-3b"), or None/empty
            to auto-select. An unknown name, or a model that needs more VRAM
            than was detected, produces a warning and falls back to
            auto-selection.

    Returns:
        The registry entry (the dict stored in MODEL_REGISTRY, not a copy) of
        the chosen model. When nothing fits the detected VRAM, the entry with
        the smallest "min_vram" is returned.
    """
    vram = _detect_gpu_vram()
    print("[config] Oppdaget VRAM: %d GB" % vram)

    if preferred:
        model = MODEL_REGISTRY.get(preferred)
        if model is None:
            # Previously an unknown name (e.g. a typo in BUDDY_VLM_MODEL) was
            # ignored without any message.
            print(
                "[config] Advarsel: ukjent modell %r, gyldige valg: %s"
                % (preferred, ", ".join(MODEL_REGISTRY))
            )
        elif vram >= model["min_vram"]:
            print("[config] Bruker foretrukken modell: %s" % preferred)
            return model
        else:
            print(
                "[config] Advarsel: %s trenger %dGB, har bare %dGB"
                % (preferred, model["min_vram"], vram)
            )

    # Auto-select the most demanding model that still fits in VRAM. The order
    # is derived from the registry so new entries are considered automatically.
    candidates = sorted(
        MODEL_REGISTRY.items(),
        key=lambda item: item[1]["min_vram"],
        reverse=True,
    )
    for name, model in candidates:
        if vram >= model["min_vram"]:
            print("[config] Auto-valgte modell: %s (%s)" % (name, model["id"]))
            return model

    # Nothing fits (e.g. no GPU): use the entry with the lowest VRAM requirement.
    fallback_name, fallback_model = candidates[-1]
    print("[config] Fallback: %s (minst VRAM-krav)" % fallback_name)
    return fallback_model
class MultimodalAssistant:
    """Configurable multimodal assistant: a vision-language model plus speech-to-text.

    Constructor arguments take precedence over these environment variables:
        BUDDY_VLM_MODEL -- qwen2.5-vl-3b | qwen2.5-vl-7b | qwen2-vl-2b
        BUDDY_STT_MODEL -- openai/whisper-large-v3 | openai/whisper-medium | openai/whisper-small
        BUDDY_DEVICE    -- auto | cuda | cpu
    """

    def __init__(
        self,
        vlm_model: str = None,
        whisper_model: str = None,
        device: str = None,
    ):
        """Load the VLM and the speech-to-text pipeline.

        Args:
            vlm_model: Friendly VLM name handed to ``_pick_model``; falls back
                to BUDDY_VLM_MODEL, then to auto-selection.
            whisper_model: STT model id; falls back to BUDDY_STT_MODEL, then
                to "openai/whisper-large-v3".
            device: "auto", or an explicit device string such as "cuda" or
                "cpu"; falls back to BUDDY_DEVICE, then to "auto".
        """
        # --- Configure ---
        # `or` chaining treats an empty environment variable as unset.
        vlm_cfg = _pick_model(vlm_model or os.environ.get("BUDDY_VLM_MODEL"))
        whisper_id = (
            whisper_model
            or os.environ.get("BUDDY_STT_MODEL")
            or "openai/whisper-large-v3"
        )
        device = device or os.environ.get("BUDDY_DEVICE") or "auto"

        self._vlm_model_id = vlm_cfg["id"]
        self._vlm_class = vlm_cfg["class"]
        self._supports_pixels = vlm_cfg["supports_pixels"]
        self.device = device

        # Resolve "auto" once so the STT model, its dtype and the pipeline all
        # agree. Previously the pipeline was always placed on GPU 0 when CUDA
        # was available, overriding an explicit "cpu" or "cuda:N" request.
        if device == "auto":
            stt_device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            stt_device = device
        # Half precision only when the STT model actually runs on a CUDA device.
        on_cuda = str(stt_device).startswith("cuda")
        self.torch_dtype = torch.float16 if on_cuda else torch.float32

        # --- Load VLM ---
        print("[assistant] Laster VLM: %s ..." % self._vlm_model_id)
        # NOTE(review): trust_remote_code=True permits executing code shipped
        # with the checkpoint -- confirm it is needed for these model ids.
        self.vlm = self._vlm_class.from_pretrained(
            self._vlm_model_id,
            torch_dtype="auto",
            device_map=device,
            trust_remote_code=True,
        )
        self.processor = AutoProcessor.from_pretrained(
            self._vlm_model_id,
            trust_remote_code=True,
        )
        print("[assistant] VLM lastet.")

        # --- Load STT ---
        print("[assistant] Laster STT: %s ..." % whisper_id)
        stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            whisper_id,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        stt_model.to(stt_device)
        stt_processor = WhisperProcessor.from_pretrained(whisper_id)
        self.stt_pipe = pipeline(
            "automatic-speech-recognition",
            model=stt_model,
            tokenizer=stt_processor.tokenizer,
            feature_extractor=stt_processor.feature_extractor,
            torch_dtype=self.torch_dtype,
            device=stt_device,
        )
        print("[assistant] STT lastet.")

    def transcribe_audio(self, audio_bytes: bytes) -> str:
        """Transcribe WAV bytes to Norwegian text.

        The bytes are written to a temporary ``.wav`` file that is handed to
        the STT pipeline by path; the file is removed afterwards even if
        writing or transcription fails.

        Args:
            audio_bytes: Raw contents of a WAV file.

        Returns:
            The recognized text with surrounding whitespace stripped.
        """
        import tempfile

        # delete=False: the file is closed before the pipeline opens it by
        # path, so it has to be removed manually.
        tmp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp_path = tmp_file.name
        try:
            with tmp_file:
                tmp_file.write(audio_bytes)
            result = self.stt_pipe(
                tmp_path,
                generate_kwargs={"language": "no", "task": "transcribe"},
            )
        finally:
            # Previously the temp file leaked whenever the pipeline raised.
            os.remove(tmp_path)

        text = result["text"].strip()
        print("[stt] Transkribert: %s" % text)
        return text

    def ask_with_image(
        self,
        image: "Image.Image",
        text: str,
        max_new_tokens: int = 512,
    ) -> str:
        """Send a screenshot plus a text prompt to the VLM and return its answer.

        Args:
            image: PIL image (the screenshot) shown to the model.
            text: The user's question or instruction.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded model response with surrounding whitespace stripped.
        """
        system_prompt = (
            "Du er en hjelpsom, norsk AI-assistent som ser brukerens skjermbilde. "
            "Svar konsist, presist og på norsk. Hvis spørsmålet er på engelsk, svar på engelsk."
        )

        # Build the image element; the pixel bounds are attached only for
        # registry entries flagged with supports_pixels.
        image_elem = {"type": "image", "image": image}
        if self._supports_pixels:
            # 50176 == 64 * 28 * 28 and 501760 == 640 * 28 * 28.
            image_elem["min_pixels"] = 50176
            image_elem["max_pixels"] = 501760

        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    image_elem,
                    {"type": "text", "text": text},
                ],
            },
        ]

        text_input = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text_input],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.vlm.device)

        generated_ids = self.vlm.generate(**inputs, max_new_tokens=max_new_tokens)
        # Slice off the first len(in_ids) ids (the prompt) so that only the
        # remaining ids are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        print("[vlm] Svar: %s..." % response[:120])
        return response.strip()