import gradio as gr import os, gc, torch from datetime import datetime from pynvml import * from PIL import Image import requests from io import BytesIO from transformers import AutoProcessor, AutoModelForImageTextToText # --- 硬件检测 (保持不变) --- HAS_GPU = False try: nvmlInit() GPU_COUNT = nvmlDeviceGetCount() if GPU_COUNT > 0: HAS_GPU = True gpu_h = nvmlDeviceGetHandleByIndex(0) except Exception as error: print(f"NVML Error: {error}") GPU_COUNT = 0 # --- 模型加载配置 --- model_id = "Qwen/Qwen3.5-9B" device_map = "auto" print(f"正在加载模型: {model_id} ...") processor = AutoProcessor.from_pretrained(model_id) model = AutoModelForImageTextToText.from_pretrained( model_id, device_map=device_map, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, trust_remote_code=True ) title = "Qwen 3.5-9B Multi-Model (Text & Image)" # --- 推理逻辑 --- def evaluate( image, text_input, max_new_tokens=200, temperature=1.0, top_p=0.7, ): if not text_input and image is None: return "请输入文字或上传图片。" # --- 核心修改:动态构造消息结构 --- if image is not None: # 有图片模式 messages = [ { "role": "user", "content": [ {"type": "image"}, {"type": "text", "text": text_input} ] }, ] prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device) else: # 纯文字模式 messages = [ { "role": "user", "content": [ {"type": "text", "text": text_input} ] }, ] prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) # 纯文字推理时不传 images 参数 inputs = processor(text=[prompt], return_tensors="pt").to(model.device) # --- 生成 --- with torch.no_grad(): generated_ids = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=True if temperature > 0 else False, temperature=temperature, top_p=top_p, ) # 剪切掉输入部分的 tokens generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] # GPU 显存清理 if HAS_GPU: try: gpu_info = nvmlDeviceGetMemoryInfo(gpu_h) print(f'VRAM Used: {gpu_info.used / 1024**2:.0f}MB') torch.cuda.empty_cache() except: pass gc.collect() return output_text # --- Gradio 界面 (保持结构,微调默认值) --- with gr.Blocks(title=title) as demo: gr.HTML(f"