# qianwen_rag.py
# NOTE(review): removed stray non-Python artifact lines that broke the file
# (paste/VCS residue: "sd / qianwen_rag.py", "decula", "added cai", "16363a1").
import gradio as gr
import os, gc, copy, torch
from datetime import datetime
from pynvml import *
from duckduckgo_search import DDGS
import re
import asyncio
import subprocess
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Hardware detection & initialization ---
# Probe NVIDIA GPUs via NVML and record the result in module globals that
# evaluate() reads later for VRAM cleanup/monitoring.
HAS_GPU = False  # flipped to True below when NVML reports at least one device
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        print(f"检测到 {GPU_COUNT} 个GPU设备")
        # Log each device's name and total VRAM (bytes -> GiB).
        for i in range(GPU_COUNT):
            handle = nvmlDeviceGetHandleByIndex(i)
            info = nvmlDeviceGetMemoryInfo(handle)
            name = nvmlDeviceGetName(handle)
            print(f"GPU {i}: {name}, 总内存: {info.total / 1024**3:.2f} GB")
except NVMLError as error:
    # No NVIDIA driver / NVML unavailable: run without GPU monitoring.
    # NOTE(review): GPU_COUNT stays undefined on this path; evaluate() only
    # reads it when HAS_GPU is True, so this is currently safe — confirm.
    print(error)
# --- Model loading (specified import style, adapted for dual GPUs) ---
model_id = "Orion-zhen/c4ai-command-r-08-2024-h-novel-exl2"
ctx_limit = 20000
# Fixed typos in the UI title: "ommand" -> "command", missing space before "with".
title = "c4ai command r 08 2024 h novel exl2 with RAG"
print(f"正在并行加载模型 {model_id} 到多块 GPU...")
# NOTE(review): the repo name suggests an EXL2 quantization, which
# AutoModelForCausalLM may not load directly — confirm the checkpoint format.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",          # shard automatically across the available GPUs
    torch_dtype=torch.float16,  # fp16 to keep each 16G card from overflowing
    low_cpu_mem_usage=True,
    trust_remote_code=True
)
print(f"模型加载完成。设备映射: {model.hf_device_map}")
# --- 搜索与辅助功能 ---
async def understanding_question(question: str):
    """Normalize a user question for downstream processing.

    Lowercases the text and strips one leading politeness/interrogative
    phrase (e.g. "can you", "what is") if present.
    """
    normalized = question.lower()
    leading_phrase = (
        r'^(can you|could you|please|tell me about'
        r'|what is|who is|how to|why is|when did)\s+'
    )
    return re.sub(leading_phrase, '', normalized)
def web_search(query, max_results=3):
    """Run a DuckDuckGo text search and return the hits as a formatted block.

    Returns an empty string when there are no results or any error occurs
    (best-effort: failures are printed, never raised).
    """
    user_agent = (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
    )
    try:
        with DDGS(headers={"User-Agent": user_agent}) as client:
            hits = list(client.text(query, max_results=max_results))
        if not hits:
            return ""
        pieces = ["\n\nSearch Results:\n"]
        for rank, hit in enumerate(hits, 1):
            pieces.append(f"[{rank}] {hit['title']}\nSummary: {hit['body']}\n\n")
        return "".join(pieces)
    except Exception as e:
        print(f"Search error: {e}")
        return ""
def extract_search_query(text):
    """Derive a short web-search query from a raw prompt/transcript.

    Lowercases the text, removes chat role prefixes ("user:", "a:"),
    strips one leading interrogative phrase, and truncates to 100 chars.
    """
    text = text.lower()
    # \b anchors the role prefixes to word starts. The previous pattern
    # (r'user:\s*|a:\s*') matched *inside* words, mangling e.g.
    # "data: science" into "dat science".
    text = re.sub(r'\b(?:user|a):\s*', '', text)
    text = re.sub(r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+', '', text)
    # Keep the query short enough for the search backend.
    return text[:100]
# --- 核心推理函数 (适配 GUI 参数与 Chat Template) ---
def evaluate(
    ctx,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty=0.1,
    countPenalty=0.1,
):
    """Generate a model reply for ``ctx``, augmented with web-search results (RAG).

    Args:
        ctx: the raw user prompt from the Gradio textbox.
        token_count: max new tokens to generate.
        temperature: sampling temperature (clamped to >= 0.1 below).
        top_p: nucleus-sampling cutoff.
        presencePenalty: mapped onto HF ``repetition_penalty`` as 1.0 + value.
        countPenalty: currently UNUSED — accepted only so the GUI slider wiring
            keeps working.

    Returns:
        The decoded completion text (prompt tokens stripped).
    """
    # 1. RAG: derive a search query from the prompt and fetch results.
    search_query = extract_search_query(ctx)
    search_results = ""
    # Skip very short queries (<= 5 chars) — too ambiguous to search.
    if len(search_query) > 5:
        search_results = web_search(search_query)
    # 2. Build the chat message, prepending search results when available.
    user_content = ctx
    if search_results:
        user_content = f"参考信息:\n{search_results}\n\n用户问题:{ctx}"
    messages = [
        {"role": "user", "content": user_content},
    ]
    # 3. Tokenize via the model's chat template (adds the generation prompt).
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)  # move tensors to the model's (first) device
    # 4. Map GUI sliders onto HF generate() kwargs.
    gen_kwargs = {
        "max_new_tokens": int(token_count),
        "temperature": max(0.1, float(temperature)),  # floor avoids temperature=0 division issues
        "top_p": float(top_p),
        "repetition_penalty": float(1.0 + presencePenalty),  # GUI penalty -> HF penalty scale
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id
    }
    with torch.no_grad():
        output = model.generate(**inputs, **gen_kwargs)
    # 5. Decode only the newly generated tokens (drop the prompt prefix).
    prompt_len = inputs.input_ids.shape[1]
    result = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    # 6. VRAM cleanup & per-GPU usage logging (original behavior preserved).
    if HAS_GPU:
        torch.cuda.empty_cache()
        for i in range(GPU_COUNT):
            handle = nvmlDeviceGetHandleByIndex(i)
            gpu_info = nvmlDeviceGetMemoryInfo(handle)
            print(f'GPU {i} VRAM: 已用 {gpu_info.used/(1024**3):.2f}GB / 总计 {gpu_info.total/(1024**3):.2f}GB')
    return result
# --- frpc section (kept as-is) ---
port=7860                 # local port the Gradio app listens on
use_frpc=True             # whether to launch the frp client tunnel
frpconfigfile="7680.ini"  # frp client config; NOTE(review): filename says 7680
                          # but the port is 7860 — confirm this isn't a typo
def install_Frpc(port, frpconfigfile, use_frpc):
    """Make the bundled ./frpc binary executable and launch it in the background.

    Does nothing when ``use_frpc`` is falsy. ``port`` is informational
    (logged only); the actual port comes from the frp config file.
    """
    if not use_frpc:
        return
    # Ensure the binary has the execute bit before spawning it.
    subprocess.run(['chmod', '+x', './frpc'], check=True)
    print(f'正在启动frp ,端口{port}')
    # Fire-and-forget: the tunnel runs for the life of this process.
    subprocess.Popen(['./frpc', '-c', frpconfigfile])
install_Frpc('7860', frpconfigfile, use_frpc)
# --- Gradio UI ---
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\"><h1>{title}</h1></div>")
    with gr.Row():
        # Left column: prompt input plus generation hyper-parameter sliders.
        with gr.Column():
            prompt = gr.Textbox(lines=5, label="提示词", value="")
            token_count = gr.Slider(100, 10000, label="最大Token数", step=100, value=500)
            temperature = gr.Slider(0.2, 2.0, label="温度", step=0.1, value=1.0)
            top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.7)
            presence_penalty = gr.Slider(0.0, 1.0, label="存在惩罚", step=0.1, value=0.1)
            # NOTE: evaluate() accepts this value but never uses it (see its
            # countPenalty parameter).
            count_penalty = gr.Slider(0.0, 1.0, label="计数惩罚", step=0.1, value=0.1)
        # Right column: action buttons and the generated output.
        with gr.Column():
            submit = gr.Button("提交", variant="primary")
            output = gr.Textbox(label="输出", lines=15)
            clear = gr.Button("清除", variant="secondary")
    # Slider/textbox values map positionally onto evaluate()'s parameters.
    submit.click(
        evaluate,
        [prompt, token_count, temperature, top_p, presence_penalty, count_penalty],
        [output]
    )
    # Clearing simply blanks the output textbox.
    clear.click(lambda: None, [], [output])
# Start the server; share=False because external access goes through frpc.
demo.launch(share=False)