decula committed
Commit 9c07e66 · 1 Parent(s): 9b0e8a8

add qwen3_9b_dual

Files changed (1):
  1. qwen3_9b_dual.py +146 -0
qwen3_9b_dual.py ADDED
@@ -0,0 +1,146 @@
import os, subprocess
import torch
import gradio as gr
from pynvml import *
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login

# --- 1. Configuration and authentication ---
model_id = "Qwen/Qwen3.5-9B"
lora_repo = "decula/sd"
port = 7860
use_frpc = True
frpconfigfile = "7680.ini"  # make sure this file has been uploaded to the Kaggle working directory

try:
    user_secrets = UserSecretsClient()
    hf_token = user_secrets.get_secret("DE_HF")
    if hf_token:
        login(token=hf_token)
except Exception:
    print("Failed to fetch the HF token; will attempt anonymous access")

24
+ # --- 2. 显存监控初始化 ---
25
+ try:
26
+ nvmlInit()
27
+ GPU_COUNT = nvmlDeviceGetCount()
28
+ gpu_h0 = nvmlDeviceGetHandleByIndex(0)
29
+ gpu_h1 = nvmlDeviceGetHandleByIndex(1) if GPU_COUNT > 1 else None
30
+ except Exception as e:
31
+ print(f"NVML 初始化失败: {e}")
32
+ GPU_COUNT = 0
33
+
# --- 3. Load the model (keeping the logic that passed testing) ---
print(f"Deploying the model across both GPUs: {model_id}...")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
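
# Rough sizing note (back-of-envelope, not measured here): NF4 stores weights
# at 4 bits/param, and double quantization trims most of the quantization-
# constant overhead on top of that, so ~9B params come to roughly 5 GB of
# weights. That is what lets the model fit across the two T4s with headroom.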

# Cap each GPU at 11GB, leaving ~4GB for the KV cache and the Gradio process
max_memory = {0: "11GiB", 1: "11GiB", "cpu": "20GiB"}

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    max_memory=max_memory,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
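
# Optional sanity check: with device_map="auto", accelerate records the layer
# placement in base_model.hf_device_map. Uncommenting the lines below prints
# which modules landed on GPU 0 vs GPU 1 (or spilled to "cpu"):
#
#   for name, dev in base_model.hf_device_map.items():
#       print(f"{name} -> {dev}")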

try:
    model = PeftModel.from_pretrained(base_model, lora_repo)
    model.eval()
    print("✅ Adapter loaded successfully")
except Exception as e:
    print(f"❌ Adapter load failed: {e}")
    model = base_model

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # no dedicated pad token, so reuse EOS

# --- 4. frpc launcher ---
def start_frpc(port, config_file, enabled):
    if enabled:
        if os.path.exists('./frpc'):
            subprocess.run(['chmod', '+x', './frpc'], check=True)
            print(f'Starting frpc to tunnel port {port}...')
            subprocess.Popen(['./frpc', '-c', config_file])
        else:
            print("Error: frpc executable not found in the current directory")

start_frpc(port, frpconfigfile, use_frpc)
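
# For reference, a minimal frpc ini for this setup might look like the sketch
# below (illustrative placeholders only; server address, port and token must
# match your own frps server):
#
#   [common]
#   server_addr = <your-frps-host>
#   server_port = 7000
#   token = <your-token>
#
#   [gradio]
#   type = tcp
#   local_ip = 127.0.0.1
#   local_port = 7860
#   remote_port = <public-port>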
79
+
80
+ # --- 5. 推理评估逻辑 ---
81
+ def evaluate(
82
+ prompt,
83
+ max_tokens=512,
84
+ temperature=0.7,
85
+ top_p=0.8,
86
+ repetition_penalty=1.1
87
+ ):
88
+ # 构建对话模板
89
+ messages = [{"role": "user", "content": prompt}]
90
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
91
+ inputs = tokenizer(text, return_tensors="pt").to(model.device)
92
+
93
+ with torch.no_grad():
94
+ # 流式生成的简化模拟(Transformers 直接生成,此处 yield 最终结果)
95
+ output_ids = model.generate(
96
+ **inputs,
97
+ max_new_tokens=int(max_tokens),
98
+ do_sample=True,
99
+ temperature=float(temperature),
100
+ top_p=float(top_p),
101
+ repetition_penalty=float(repetition_penalty),
102
+ pad_token_id=tokenizer.pad_token_id,
103
+ eos_token_id=tokenizer.eos_token_id
104
+ )
105
+
106
+ response = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
107
+
108
+ # 打印显存状态
109
+ if GPU_COUNT > 0:
110
+ info0 = nvmlDeviceGetMemoryInfo(gpu_h0)
111
+ print(f"GPU0: {info0.used/1024**2:.0f}MB / GPU1: {nvmlDeviceGetMemoryInfo(gpu_h1).used/1024**2:.0f}MB" if gpu_h1 else f"GPU0: {info0.used/1024**2:.0f}MB")
112
+
113
+ return response
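
# The generate() call above blocks until the full reply is ready. If true
# token-by-token streaming is wanted later, transformers' TextIteratorStreamer
# is the usual route: generation runs in a worker thread while the streamer is
# consumed as an iterator. Untested sketch under that assumption:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   def evaluate_stream(prompt, max_tokens=512):
#       messages = [{"role": "user", "content": prompt}]
#       text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#       inputs = tokenizer(text, return_tensors="pt").to(model.device)
#       streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#       Thread(target=model.generate,
#              kwargs=dict(**inputs, streamer=streamer, max_new_tokens=int(max_tokens))).start()
#       partial = ""
#       for chunk in streamer:
#           partial += chunk
#           yield partial  # Gradio would render each partial string as it arrives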

# --- 6. Gradio UI ---
with gr.Blocks(title="Qwen3.5-9B Dual-GPU GUI") as demo:
    gr.HTML("<div style='text-align: center;'><h1>Qwen 3.5 9B + LoRA (Dual T4)</h1></div>")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=4, label="Question", placeholder="Type what you want to ask...")
            with gr.Row():
                btn_submit = gr.Button("Send", variant="primary")
                btn_clear = gr.Button("Reset")

            with gr.Accordion("Generation parameters", open=False):
                tk_count = gr.Slider(128, 2048, label="Max new tokens", step=128, value=512)
                temp = gr.Slider(0.1, 1.5, label="Temperature", step=0.1, value=0.7)
                tp = gr.Slider(0.1, 1.0, label="Top P", step=0.05, value=0.8)
                rep_p = gr.Slider(1.0, 1.5, label="Repetition penalty", step=0.05, value=1.1)

        with gr.Column():
            output_text = gr.Textbox(lines=12, label="AI response", interactive=False)

    # Wire up events
    btn_submit.click(
        evaluate,
        inputs=[input_text, tk_count, temp, tp, rep_p],
        outputs=[output_text]
    )
    btn_clear.click(lambda: ("", ""), outputs=[input_text, output_text])

# --- 7. Launch ---
if __name__ == "__main__":
    # share=False because we tunnel through our own frpc instead
    demo.launch(server_port=port, share=False)
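
# If several users may hit the UI at once, Gradio's request queue serializes
# GPU work so concurrent generate() calls don't fight over VRAM, e.g.
# (assuming a recent Gradio version):
#
#   demo.queue(max_size=8).launch(server_port=port, share=False)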