carbonx committed on
Commit
9f45f5e
·
verified ·
1 Parent(s): 2e7e0eb

Add vision_llm.py

Browse files
Files changed (1) hide show
  1. vision_llm.py +135 -0
vision_llm.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multimodal Vision-Language Model (Qwen2.5-VL) wrapper."""
2
+ import os
3
+ import torch
4
+ from PIL import Image
5
+ from typing import Optional, Union, List
6
+
7
+ from transformers import (
8
+ Qwen2_5_VLForConditionalGeneration,
9
+ AutoProcessor,
10
+ AutoModelForSpeechSeq2Seq,
11
+ AutoProcessor as WhisperProcessor,
12
+ pipeline,
13
+ )
14
+ from qwen_vl_utils import process_vision_info
15
+
16
+
17
class MultimodalAssistant:
    """Vision-language assistant with speech-to-text support.

    Combines:
    - Qwen2.5-VL-7B for vision+language understanding
    - Whisper for STT

    Both models are loaded eagerly in ``__init__``; afterwards
    ``transcribe_audio`` turns WAV bytes into text and ``ask_with_image``
    answers a text prompt about a PIL image.
    """

    def __init__(
        self,
        vlm_model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct",
        whisper_model_id: str = "openai/whisper-large-v3",
        device: str = "auto",
    ):
        """Load the vision-language model and the Whisper STT pipeline.

        Args:
            vlm_model_id: Model id passed to ``from_pretrained`` for the
                Qwen2.5-VL model and its processor.
            whisper_model_id: Model id passed to ``from_pretrained`` for the
                speech-to-text model and its processor.
            device: Used as ``device_map`` for the VLM. For Whisper, ``"auto"``
                resolves to ``"cuda"`` when CUDA is available and ``"cpu"``
                otherwise; any other value is used as the target device as-is.
        """
        self.device = device

        # Resolve one concrete device for Whisper and reuse it for the model,
        # the pipeline and the dtype choice, so that an explicit device
        # (e.g. "cuda:1" or "cpu") is not silently overridden by the pipeline.
        if device != "auto":
            stt_device = device
        else:
            stt_device = "cuda" if torch.cuda.is_available() else "cpu"

        # Half precision only when the STT model actually runs on CUDA.
        self.torch_dtype = (
            torch.float16 if str(stt_device).startswith("cuda") else torch.float32
        )

        print("[assistant] Loading VLM: %s ..." % vlm_model_id)
        self.vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            vlm_model_id,
            torch_dtype="auto",
            device_map=device,
            trust_remote_code=True,
        )
        self.processor = AutoProcessor.from_pretrained(
            vlm_model_id,
            trust_remote_code=True,
        )
        print("[assistant] VLM loaded.")

        print("[assistant] Loading STT: %s ..." % whisper_model_id)
        stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            whisper_model_id,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        stt_model.to(stt_device)
        stt_processor = WhisperProcessor.from_pretrained(whisper_model_id)
        self.stt_pipe = pipeline(
            "automatic-speech-recognition",
            model=stt_model,
            tokenizer=stt_processor.tokenizer,
            feature_extractor=stt_processor.feature_extractor,
            torch_dtype=self.torch_dtype,
            device=stt_device,
        )
        print("[assistant] STT loaded.")

    def transcribe_audio(self, audio_bytes: bytes) -> str:
        """Transcribe WAV bytes to Norwegian text.

        The bytes are written to a temporary ``.wav`` file which is handed to
        the STT pipeline by path. The file is removed afterwards, including
        when the pipeline raises.

        Args:
            audio_bytes: Raw contents of a WAV file.

        Returns:
            The transcription with surrounding whitespace stripped.
        """
        import tempfile

        # delete=False so the file can be closed before the pipeline reopens
        # it by name; cleanup is done explicitly below.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(audio_bytes)
            tmp_path = f.name

        try:
            result = self.stt_pipe(
                tmp_path,
                generate_kwargs={"language": "no", "task": "transcribe"},
            )
        finally:
            # Always remove the temp file so failed transcriptions do not
            # leave audio files behind.
            os.remove(tmp_path)

        text = result["text"].strip()
        print("[stt] Transcribed: %s" % text)
        return text

    def ask_with_image(
        self,
        image: Image.Image,
        text: str,
        max_new_tokens: int = 512,
    ) -> str:
        """Send a screenshot + text prompt to Qwen2.5-VL and return response.

        Args:
            image: The image (screenshot) the question is about.
            text: The user's question or instruction.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded model answer with surrounding whitespace stripped.
        """
        system_prompt = (
            "Du er en hjelpsom, norsk AI-assistent som ser brukerens skjermbilde. "
            "Svar konsist, presist og på norsk. Hvis spørsmålet er på engelsk, svar på engelsk."
        )

        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                        # 64*28*28 and 640*28*28; presumably pixel-count
                        # bounds used by qwen_vl_utils when resizing the
                        # image - confirm against its documentation.
                        "min_pixels": 50176,
                        "max_pixels": 501760,
                    },
                    {"type": "text", "text": text},
                ],
            },
        ]

        # Render the chat as a prompt string and extract the vision inputs
        # referenced by the messages.
        text_input = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text_input],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.vlm.device)

        generated_ids = self.vlm.generate(**inputs, max_new_tokens=max_new_tokens)
        # Drop the leading prompt-length prefix from each output sequence so
        # only the continuation is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        print("[vlm] Response: %s..." % response[:120])
        return response.strip()